O LSTM foi inventado especificamente para evitar o problema do gradiente de fuga. Supõe-se que isso seja feito com o Constant Error Carousel (CEC), que no diagrama abaixo (de Greff et al. ) Corresponde ao loop em torno da célula .
(fonte: deeplearning4j.org )
E eu entendo que essa parte pode ser vista como uma espécie de função de identidade, então a derivada é uma e o gradiente permanece constante.
O que eu não entendo é como ele não desaparece devido a outras funções de ativação? Os portões de entrada, saída e esquecimento usam um sigmóide, cuja derivada é no máximo 0,25, e g e h eram tradicionalmente tanh . Como a retropropagação daqueles que não fazem o gradiente desaparecer?
neural-networks
lstm
TheWalkingCube
fonte
fonte
Respostas:
O gradiente de fuga é melhor explicado no caso unidimensional. A multidimensional é mais complicada, mas essencialmente análoga. Você pode revisá-lo neste excelente artigo [1].
Suponha que temos um estado oculto no momento t . Se simplificarmos as coisas e removermos vieses e entradas, teremos h t = σ ( w h t - 1 ) . Então você pode mostrar queht t
O fatorado marcado com !!! é o crucial. Se o peso não for igual a 1, ele decairá para zero exponencialmente rápido emt′-tou crescerá exponencialmente rápido.
Nos LSTMs, você tem o estado da célula . O derivado não é da forma ∂ s t 'st
Aquivté a entrada para o gate de esquecer. Como você pode ver, não há fator de decomposição exponencialmente rápido envolvido. Conseqüentemente, há pelo menos um caminho em que o gradiente não desaparece. Para a derivação completa, consulte [2].
[1] Pascanu, Razvan, Tomas Mikolov e Yoshua Bengio. "Sobre a dificuldade de treinar redes neurais recorrentes." ICML (3) 28 (2013): 1310-1318.
[2] Bayer, Justin Simon. Representações da sequência de aprendizado. Diss. München, Technische Universität München, Diss., 2015, 2015.
fonte
A imagem do bloco LSTM de Greff et al. (2015) descreve uma variante que os autores chamam de LSTM de baunilha . É um pouco diferente da definição original de Hochreiter e Schmidhuber (1997). A definição original não incluía as conexões do gate e do olho mágico.
O termo Carrossel com erro constante foi usado no artigo original para indicar a conexão recorrente do estado da célula. Considere a definição original em que o estado da célula é alterado apenas por adição, quando a porta de entrada é aberta. O gradiente do estado da célula em relação ao estado da célula em uma etapa anterior é zero.
O erro ainda pode entrar no CEC através da porta de saída e da função de ativação. A função de ativação reduz um pouco a magnitude do erro antes de ser adicionado ao CEC. O CEC é o único local em que o erro pode fluir inalterado. Novamente, quando a porta de entrada é aberta, o erro sai através da porta de entrada, função de ativação e transformação afim, reduzindo a magnitude do erro.
Portanto, o erro é reduzido quando é retropropagado por meio de uma camada LSTM, mas somente quando entra e sai do CEC. O importante é que ele não mude no CEC, independentemente da distância percorrida. Isso resolve o problema na RNN básica de que cada etapa do tempo aplica uma transformação afim e não linearidade, significando que quanto maior a distância do tempo entre a entrada e a saída, menor o erro.
fonte
http://www.felixgers.de/papers/phd.pdf Consulte as seções 2.2 e 3.2.2, onde é explicada a parte do erro truncado. Eles não propagam o erro se vazar na memória da célula (ou seja, se houver uma porta de entrada fechada / ativada), mas eles atualizam os pesos da porta com base no erro apenas naquele instante. Mais tarde, é zerado durante a propagação posterior. Isso é meio que um hack, mas o motivo é que o fluxo de erros ao longo dos portões diminui com o tempo.
fonte
Gostaria de acrescentar alguns detalhes à resposta aceita, porque acho que é um pouco mais sutil e a nuance pode não ser óbvia para alguém que está aprendendo sobre RNNs pela primeira vez.
For the vanilla RNN, there is no set of weights which can be learned such thatwσ′(wht′−k)≈1
e.g. In the 1D case, supposeht′−k=1 . The function wσ′(w∗1) achieves a maximum of 0.224 at w=1.5434 . This means the gradient will decay as, (0.224)t′−t
fonte