이전 Convolutional Layer의 델타 항과 가중치를 고려하여 Convolutional Layer의 델타 항을 어떻게 계산합니까?


10

두 개의 회선 레이어 (c1, c2)와 두 개의 숨겨진 레이어 (c1, c2)로 인공 신경망을 훈련하려고합니다. 표준 역 전파 접근법을 사용하고 있습니다. 역방향 패스에서 이전 레이어의 오류, 이전 레이어의 가중치 및 현재 레이어의 활성화 기능에 대한 활성화의 기울기를 기반으로 레이어의 델타 항을 계산합니다. 보다 구체적으로 l 계층의 델타는 다음과 같습니다.

delta(l) = (w(l+1)' * delta(l+1)) * grad_f_a(l)

일반 레이어에 연결되는 c2의 그라디언트를 계산할 수 있습니다. h1의 가중치에 델타를 곱합니다. 그런 다음 해당 행렬을 c2 출력 형식으로 재구성하고 활성화 함수의 기울기로 곱하면 완료됩니다.

이제 c2의 델타 항이 있습니다-이것은 크기의 4D 행렬입니다 (featureMapSize, featureMapSize, filterNum, patternNum). 또한 c2의 가중치를 가지며 크기는 3D 매트릭스 (filterSize, filterSize, filterNum)입니다.

이 두 항과 c1 활성화의 기울기로 c1의 델타를 계산하고 싶습니다.

간단히 말해 :

이전 컨볼 루션 레이어의 델타 항과 해당 레이어의 가중치가 주어지면 컨 벌루 셔널 레이어의 델타 항은 어떻게 계산합니까?

답변:


6

먼저 다차원으로 쉽게 전송할 수있는 1 차원 배열 (입력)의 단순화를 위해 아래의 컨볼 루션 레이어에 대한 오류를 도출합니다.

여기서는 길이 의 이 번째 전환 의 입력이라고 가정합니다 . 층은, 가중치의 커널 사이즈 각 가중치를 나타내는 출력은 . 따라서 우리는 다음과 같이 쓸 수있다 (0으로부터의 합산에 주목) : 여기서 및 활성화 기능 (예 : 시그 모이 드). 이를 통해 이제 우리는 의해 주어진 일부 오류 함수 와 컨 볼루 셔널 레이어 (이전 레이어 중 하나)의 오류 함수를 고려할 수 있습니다yl1Nl1mwwixl

xil=a=0m1waya+il1
yil=f(xil)fEE/yil. 이제 이전 계층의 가중치 중 하나에서 오류의 종속성을 확인하려고합니다. 여기서 가 발생 하는 모든 표현식에 대해 있습니다. 또한 마지막 항은 첫 번째 방정식에서 볼 수있는 이라는 사실에서 비롯된 것입니다. 그래디언트를 계산하려면 첫 번째 항을 알아야합니다.
Ewa=a=0NmExilxilwa=a=0NmEwayi+al1

waNmxilwa=yi+al1
Exil=Eyilyilxil=Eyilxilf(xil)
다시 첫 번째 항은 이전 층에서의 에러 인 비선형 활성 기능.f

필요한 모든 엔티티가 있으면 오류를 계산하고이를 귀중한 계층으로 효율적으로 다시 전파 할 수 있습니다. 마지막 단계는 -s를 사용 하여 -s를 기록 할 때 쉽게 이해할 수 있습니다 . 전치 무게 maxtrix (를 말합니다 ).

δal1=Eyil1=a=0m1Exialxialyil1=a=0m1Exialwaflipped
xilyil1flippedT

따라서 다음 레이어에서 오류를 계산할 수 있습니다 (현재 벡터 표기법으로).

δl=(wl)Tδl+1f(xl)

이것은 컨 벌루 셔널 및 서브 샘플링 계층이된다 : 여기서 작업은 최대 풀링 계층을 통해 오류를 전파합니다.

δl=upsample((wl)Tδl+1)f(xl)
upsample

저를 추가하거나 수정하십시오!

참조를 위해 :

http://ufldl.stanford.edu/tutorial/supervised/ConvolutionalNeuralNetwork/ http://andrew.gibiansky.com/blog/machine-learning/convolutional-neural-networks/

C ++ 구현 (설치 요구 사항 없음) : https://github.com/nyanp/tiny-cnn#supported-networks

당사 사이트를 사용함과 동시에 당사의 쿠키 정책개인정보 보호정책을 읽고 이해하였음을 인정하는 것으로 간주합니다.
Licensed under cc by-sa 3.0 with attribution required.