단위 가우스로 KL 손실


10

VAE를 구현하고 있으며 단순화 된 단 변량 가우시안 KL 발산에 대해 온라인에서 두 가지 다른 구현을 발견했습니다. 당 원래 발산 여기가 있다 이전의 단위 가우스 인 것으로 가정하면 (예 : 및 ) 단순화됩니다. 여기에 혼란이 있습니다. 위의 구현으로 몇 가지 모호한 github repos를 찾았지만 더 일반적으로 사용되는 것은 다음과 같습니다.

케이영형에스에스=로그(σ2σ1)+σ12+(μ1μ2)22σ2212
μ2=0σ2=1
케이영형에스에스=로그(σ1)+σ12+μ12212
케이영형에스에스=12(2로그(σ1)σ12μ12+1)

=12(로그(σ1)σ1μ12+1)
예를 들어 공식 Keras 자동 인코더 자습서에서 . 내 질문은 그렇다면이 두 가지 사이에 무엇을 놓치고 있습니까? 주요 차이점은 로그 항에 대해 2의 요소를 제거하고 분산을 제곱하지 않는 것입니다. 분석적으로 나는 후자를 성공으로 사용하여 그 가치를 평가했습니다. 도움을 주셔서 감사합니다.

답변:


7

교체하여 σ1σ12 마지막 방정식에서 이전 (즉, 로그(σ1)σ12로그(σ1)σ12). 첫 번째 경우에는 인코더가 분산을 예측하는 데 사용되고 두 번째 경우에는 표준 편차를 예측하는 데 사용된다고 생각합니다.

두 제제는 동일하며 목표는 변하지 않습니다.


나는 이것이 동등한 경우가 아니라고 생각합니다. 예, 둘 다 0 일 때 최소화됩니다.μ 그리고 단위 σ. 그러나 원래 방정식 (분산을 특징으로 함)에서 이동에 대한 페널티σ단결에서 멀어지면 두 번째 방정식보다 훨씬 큽니다 (표준 편차 기준). 변화에 대한 형벌μ 둘 다 동일하고 재구성 오류는 동일하므로 두 번째 버전을 사용하면 출발의 상대적 중요성이 크게 변경됩니다. σ화합에서. 내가 무엇을 놓치고 있습니까?
TheBamf

0

나는 대답이 더 간단하다고 생각합니다. VAE에서 사람들은 일반적으로 공분산 행렬이있는 다변량 정규 분포를 사용합니다.Σ 분산 대신 σ2. 코드에서 혼란스러워 보이지만 원하는 형식이 있습니다.

다변량 정규 분포에 대한 KL 분기의 유도를 찾을 수 있습니다. VAE에 대한 KL 분기 손실 도출

당사 사이트를 사용함과 동시에 당사의 쿠키 정책개인정보 보호정책을 읽고 이해하였음을 인정하는 것으로 간주합니다.
Licensed under cc by-sa 3.0 with attribution required.