KL Perda com uma unidade Gaussiana

10

Estou implementando um VAE e notei duas implementações diferentes on-line da divergência simplificada univariada de KL gaussiana. A divergência original conforme aqui é Se assumirmos que nosso prior é uma unidade gaussiana, ou seja, e , isso simplifica para E aqui é onde está minha confusão. Embora eu tenha encontrado alguns repositórios obscuros do github com a implementação acima, o que eu acho mais usado é: μ2=0σ2=1KLloss=-log(σ1)+σ 2 1 +μ 2 1

KLloss=log(σ2σ1)+σ12+(μ1μ2)22σ2212
μ2=0σ2=1
KLloss=log(σ1)+σ12+μ12212
KLloss=12(2log(σ1)σ12μ12+1)

=12(log(σ1)σ1μ12+1)
Por exemplo, no tutorial oficial do autoencoder do Keras . Minha pergunta é então, o que estou perdendo entre esses dois? A principal diferença é descartar o fator 2 no termo do log e não esquadrar a variação. Analiticamente, usei o último com sucesso, pelo que vale a pena. Agradecemos antecipadamente por qualquer ajuda!
groovyDragon
fonte

Respostas:

7

Observe que, substituindo por na última equação, você recupera o anterior (por exemplo, ). Me levando a pensar que no primeiro caso o codificador é usado para prever a variação, enquanto no segundo é usado para prever o desvio padrão.σ1σ12log(σ1)σ12log(σ1)σ12

Ambas as formulações são equivalentes e o objetivo é inalterado.

F. Evlangeli
fonte
Eu não acho que isso possa ser equivalente. Sim, ambos são minimizados quando para zero e unit . No entanto, na equação original (com a variação), a penalidade por afastar da unidade é muito maior do que na segunda equação (com base no desvio padrão). A penalidade para variações em é a mesma para ambos, e o erro de reconstrução seria o mesmo; portanto, o uso da segunda versão altera drasticamente a importância relativa das partidas de da unidade. o que estou perdendo? μσσμσ
TheBamf 13/01
0

Eu acredito que a resposta é mais simples. No VAE, as pessoas geralmente usam uma distribuição normal multivariada, que possui matriz de covariância vez de variância . Isso parece confuso em um pedaço de código, mas tem a forma desejada.Σσ2

Aqui você pode encontrar a derivação de uma divergência de KL para distribuições normais multivariadas: Derivando a perda de divergência de KL para VAEs

Dmitry Grebenyuk
fonte