Haskell에서 신경망 아키텍처를 구현하고 MNIST에서 사용하려고합니다.
hmatrix
선형 대수 패키지를 사용하고 있습니다. 내 교육 프레임 워크는 pipes
패키지를 사용하여 빌드 됩니다.
내 코드가 컴파일되고 충돌하지 않습니다. 그러나 문제는 레이어 크기 (예 : 1000), 미니 배치 크기 및 학습률의 특정 조합 NaN
이 계산 값을 생성한다는 것입니다. 몇 가지 검사 후 극히 작은 값 (순서 1e-100
)이 결국 활성화에 나타납니다. 그러나 그것이 일어나지 않더라도 훈련은 여전히 작동하지 않습니다. 손실이나 정확성에 대한 개선은 없습니다.
내 코드를 확인하고 다시 확인했는데 문제의 원인이 무엇인지 알 수 없었습니다.
다음은 각 계층에 대한 델타를 계산하는 역 전파 훈련입니다.
backward lf n (out,tar) das = do
let δout = tr (derivate lf (tar, out)) -- dE/dy
deltas = scanr (\(l, a') δ ->
let w = weights l
in (tr a') * (w <> δ)) δout (zip (tail $ toList n) das)
return (deltas)
lf
손실 함수이고, n
네트워크 (인 weight
행렬과 bias
각 계층 벡터), out
및 tar
네트워크와의 실제 출력 target
(요구)의 출력 및 das
각 층의 활성 유도체이다.
배치 모드에서 out
, tar
행렬 (행 벡터 출력 임)이며, das
행렬의 목록이다.
실제 그래디언트 계산은 다음과 같습니다.
grad lf (n, (i,t)) = do
-- Forward propagation: compute layers outputs and activation derivatives
let (as, as') = unzip $ runLayers n i
(out) = last as
(ds) <- backward lf n (out, t) (init as') -- Compute deltas with backpropagation
let r = fromIntegral $ rows i -- Size of minibatch
let gs = zipWith (\δ a -> tr (δ <> a)) ds (i:init as) -- Gradients for weights
return $ GradBatch ((recip r .*) <$> gs, (recip r .*) <$> squeeze <$> ds)
여기서, lf
그리고 n
, 상기와 동일하다 i
입력하고, t
(행렬로 모두 배치 형태) 목표 출력한다.
squeeze
각 행을 합산하여 행렬을 벡터로 변환합니다. 즉, ds
각 열이 미니 배치 행의 델타에 해당하는 델타 행렬 목록입니다. 따라서 편향에 대한 기울기는 모든 미니 배치에 대한 델타의 평균입니다. gs
가중치에 대한 그래디언트에 해당하는 에서도 동일 합니다.
실제 업데이트 코드는 다음과 같습니다.
move lr (n, (i,t)) (GradBatch (gs, ds)) = do
-- Update function
let update = (\(FC w b af) g δ -> FC (w + (lr).*g) (b + (lr).*δ) af)
n' = Network.fromList $ zipWith3 update (Network.toList n) gs ds
return (n', (i,t))
lr
학습률입니다. FC
레이어 생성자이며 해당 레이어 af
의 활성화 함수입니다.
경사 하강 법 알고리즘은 학습률에 대해 음수 값을 전달해야합니다. 경사 하강 법의 실제 코드 는 매개 변수화 된 정지 조건이있는 grad
및 의 구성을 둘러싼 단순한 루프 move
입니다.
마지막으로 평균 제곱 오차 손실 함수에 대한 코드는 다음과 같습니다.
mse :: (Floating a) => LossFunction a a
mse = let f (y,y') = let gamma = y'-y in gamma**2 / 2
f' (y,y') = (y'-y)
in Evaluator f f'
Evaluator
손실 함수와 그 파생물 (출력 레이어의 델타 계산 용)을 번들로 묶습니다.
나머지 코드는 GitHub : NeuralNetwork에 있습니다.
따라서 누군가가 문제에 대한 통찰력을 가지고 있거나 알고리즘을 올바르게 구현하고 있는지 확인하는 경우 감사하겠습니다.