Quais são exatamente os mecanismos de atenção?

23

Mecanismos de atenção têm sido utilizados em vários trabalhos de Deep Learning nos últimos anos. Ilya Sutskever, chefe de pesquisa da Open AI, elogiou-os com entusiasmo: https://towardsdatascience.com/the-fall-of-rnn-lstm-2d1594c74ce0

Eugenio Culurciello, da Universidade de Purdue, afirmou que RNNs e LSTMs deveriam ser abandonados em favor de redes neurais puramente baseadas na atenção:

https://towardsdatascience.com/the-fall-of-rnn-lstm-2d1594c74ce0

Isso parece um exagero, mas é inegável que modelos puramente baseados em atenção se saíram muito bem em tarefas de modelagem de sequência: todos sabemos sobre o artigo apropriadamente chamado do Google, é tudo que você precisa de atenção

No entanto, o que exatamente são modelos baseados em atenção? Ainda não encontrei uma explicação clara de tais modelos. Suponha que eu queira prever os novos valores de uma série temporal multivariada, dados seus valores históricos. É bastante claro como fazer isso com um RNN com células LSTM. Como eu faria o mesmo com um modelo baseado em atenção?

DeltaIV
fonte

Respostas:

20

Atenção é um método para agregar um conjunto de vetores vi em apenas um vetor, geralmente por meio do vetor de pesquisa u . Geralmente, vi são as entradas para o modelo ou os estados ocultos das etapas de tempo anteriores ou os estados ocultos um nível abaixo (no caso de LSTMs empilhados).

O resultado costuma ser chamado de vetor de contexto c , pois contém o contexto relevante para o atual intervalo de tempo.

Esse vetor de contexto adicional c é alimentado no RNN / LSTM (pode ser simplesmente concatenado com a entrada original). Portanto, o contexto pode ser usado para ajudar na previsão.

A maneira mais simples de fazer isso é a computação probabilidade vector p=softmax(VTu) e c=ipivi , onde V é a concatenação de todos os anteriores vi . Um vetor de pesquisa comumu é o estado oculto atualht .

Existem muitas variações nisso, e você pode tornar as coisas tão complicadas quanto desejar. Por exemplo, em vez de usar viTu como os logits, pode-se escolher f(vi,u) em vez disso, ondef é uma rede neural arbitrária.

Um mecanismo de atenção comum para modelos de sequência a sequência usa p=softmax(qTtanh(W1vi+W2ht)) , onde v são os estados ocultos do codificador e ht é o estado oculto atual do decodificador. q e ambos os W s são parâmetros.

Alguns trabalhos que mostram diferentes variações na idéia de atenção:

As redes de ponteiros prestam atenção às entradas de referência para resolver problemas de otimização combinatória.

As redes de entidades recorrentes mantêm estados de memória separados para diferentes entidades (pessoas / objetos) durante a leitura de texto e atualizam o estado correto da memória usando atenção.

Os modelos de transformadores também fazem uso extensivo de atenção. A sua formulação de atenção é ligeiramente mais geral e também envolve vectores chave ki : os pesos atenção p são, na verdade, calculado entre as chaves e a pesquisa, e o contexto é então construído com a vi .


Aqui está uma rápida implementação de uma forma de atenção, embora eu não possa garantir a correção além do fato de ter passado em alguns testes simples.

RNN básico:

def rnn(inputs_split):
    bias = tf.get_variable('bias', shape = [hidden_dim, 1])
    weight_hidden = tf.tile(tf.get_variable('hidden', shape = [1, hidden_dim, hidden_dim]), [batch, 1, 1])
    weight_input = tf.tile(tf.get_variable('input', shape = [1, hidden_dim, in_dim]), [batch, 1, 1])

    hidden_states = [tf.zeros((batch, hidden_dim, 1), tf.float32)]
    for i, input in enumerate(inputs_split):
        input = tf.reshape(input, (batch, in_dim, 1))
        last_state = hidden_states[-1]
        hidden = tf.nn.tanh( tf.matmul(weight_input, input) + tf.matmul(weight_hidden, last_state) + bias )
        hidden_states.append(hidden)
    return hidden_states[-1]

Com atenção, adicionamos apenas algumas linhas antes que o novo estado oculto seja calculado:

        if len(hidden_states) > 1:
            logits = tf.transpose(tf.reduce_mean(last_state * hidden_states[:-1], axis = [2, 3]))
            probs = tf.nn.softmax(logits)
            probs = tf.reshape(probs, (batch, -1, 1, 1))
            context = tf.add_n([v * prob for (v, prob) in zip(hidden_states[:-1], tf.unstack(probs, axis = 1))])
        else:
            context = tf.zeros_like(last_state)

        last_state = tf.concat([last_state, context], axis = 1)

        hidden = tf.nn.tanh( tf.matmul(weight_input, input) + tf.matmul(weight_hidden, last_state) + bias )

o código completo

shimao
fonte
p=softmax(VTu)ic=ipivipiVTvVTv
1
zi=viTup=softmax(z)pi=eizjejz
ppi
1
sim, isso é o que eu quis dizer
Shimao
@shimao Criei uma sala de bate - papo , deixe-me saber se você está interessado em conversar (não sobre esta questão) #
DeltaIV