Qual é a saída de um tf.nn.dynamic_rnn ()?

8

Não sei ao certo o que entendi da documentação oficial, que diz:

Retorna: um par (saídas, estado) em que:

outputs: O tensor de saída RNN.

Se time_major == False(default), este será um Tensor em forma: [batch_size, max_time, cell.output_size].

Se time_major == True, este será um Tensor em forma: [max_time, batch_size, cell.output_size].

Observe que, se cell.output_sizehouver uma tupla (possivelmente aninhada) de números inteiros ou objetos TensorShape, as saídas serão uma tupla com a mesma estrutura que cell.output_size, contendo tensores com formas correspondentes aos dados da forma cell.output_size.

state: O estado final. Se cell.state_size for um int, isso será modelado [batch_size, cell.state_size]. Se for um TensorShape, isso será modelado [batch_size] + cell.state_size. Se for uma tupla (possivelmente aninhada) de ints ou TensorShape, será uma tupla com as formas correspondentes. Se as células estiverem LSTMCells, o estado será uma tupla contendo um LSTMStateTuple para cada célula.

output[-1] É sempre (nos três tipos de células, isto é, RNN, GRU, LSTM) igual ao estado (segundo elemento da tupla de retorno)? Eu acho que a literatura em toda parte é liberal demais no uso do termo estado oculto. É o estado oculto nas três células a pontuação que sai (por que é chamado de oculto está além de mim, parece que o estado da célula no LSTM deve ser chamado de estado oculto porque não é exposto)?

MiloMinderbinder
fonte

Respostas:

10

Sim, a saída da célula é igual ao estado oculto. No caso de LSTM, é a parte de curto prazo da tupla (segundo elemento de LSTMStateTuple), como pode ser visto nesta figura:

LSTM

Mas tf.nn.dynamic_rnn, para , o estado retornado pode ser diferente quando a sequência é mais curta ( sequence_lengthargumento). Veja este exemplo:

n_steps = 2
n_inputs = 3
n_neurons = 5

X = tf.placeholder(dtype=tf.float32, shape=[None, n_steps, n_inputs])
seq_length = tf.placeholder(tf.int32, [None])

basic_cell = tf.nn.rnn_cell.BasicRNNCell(num_units=n_neurons)
outputs, states = tf.nn.dynamic_rnn(basic_cell, X, sequence_length=seq_length, dtype=tf.float32)

X_batch = np.array([
  # t = 0      t = 1
  [[0, 1, 2], [9, 8, 7]], # instance 0
  [[3, 4, 5], [0, 0, 0]], # instance 1
  [[6, 7, 8], [6, 5, 4]], # instance 2
  [[9, 0, 1], [3, 2, 1]], # instance 3
])
seq_length_batch = np.array([2, 1, 2, 2])

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  outputs_val, states_val = sess.run([outputs, states], 
                                     feed_dict={X: X_batch, seq_length: seq_length_batch})

  print(outputs_val)
  print()
  print(states_val)

Aqui, o lote de entrada contém 4 seqüências e uma delas é curta e preenchida com zeros. Ao executar, você deve algo como isto:

[[[ 0.2315362  -0.37939444 -0.625332   -0.80235624  0.2288385 ]
  [ 0.9999524   0.99987394  0.33580178 -0.9981791   0.99975705]]

 [[ 0.97374666  0.8373545  -0.7455188  -0.98751736  0.9658986 ]
  [ 0.          0.          0.          0.          0.        ]]

 [[ 0.9994331   0.9929737  -0.8311569  -0.99928087  0.9990415 ]
  [ 0.9984355   0.9936006   0.3662448  -0.87244385  0.993848  ]]

 [[ 0.9962312   0.99659646  0.98880637  0.99548346  0.9997809 ]
  [ 0.9915743   0.9936939   0.4348318   0.8798458   0.95265496]]]

[[ 0.9999524   0.99987394  0.33580178 -0.9981791   0.99975705]
 [ 0.97374666  0.8373545  -0.7455188  -0.98751736  0.9658986 ]
 [ 0.9984355   0.9936006   0.3662448  -0.87244385  0.993848  ]
 [ 0.9915743   0.9936939   0.4348318   0.8798458   0.95265496]]

... o que realmente mostra isso state == output[1]para seqüências completas e state == output[0]para a curta. Também output[1]é um vetor zero para esta sequência. O mesmo vale para células LSTM e GRU.

Portanto, stateé um tensor conveniente que mantém o último estado real da RNN, ignorando os zeros. O outputtensor mantém as saídas de todas as células, para não ignorar os zeros. Essa é a razão para retornar os dois.

Máxima
fonte
2

Cópia possível de /programming/36817596/get-last-output-of-dynamic-rnn-in-tensorflow/49705930#49705930

De qualquer forma, vamos em frente com a resposta.

Esse trecho de código pode ajudar a entender o que realmente está sendo retornado pela dynamic_rnncamada

=> Tupla de (saídas, final_output_state) .

Portanto, para uma entrada com duração máxima de sequência de passos T, as saídas são da forma [Batch_size, T, num_inputs](dado time_major= Falso; valor padrão) e contém o estado da saída em cada passo de tempo h1, h2.....hT.

E final_output_state é da forma [Batch_size,num_inputs]e possui o estado final da célula cTe o estado hTde saída de cada sequência de lotes.

Mas como o dynamic_rnnestá sendo usado, acho que os comprimentos das sequências variam para cada lote.

    import tensorflow as tf
    import numpy as np
    from tensorflow.contrib import rnn
    tf.reset_default_graph()

    # Create input data
    X = np.random.randn(2, 10, 8)

    # The second example is of length 6 
    X[1,6:] = 0
    X_lengths = [10, 6]

    cell = tf.nn.rnn_cell.LSTMCell(num_units=64, state_is_tuple=True)

    outputs, states  = tf.nn.dynamic_rnn(cell=cell,
                                         dtype=tf.float64,
                                         sequence_length=X_lengths,
                                         inputs=X)

    result = tf.contrib.learn.run_n({"outputs": outputs, "states":states},
                                    n=1,
                                    feed_dict=None)
    assert result[0]["outputs"].shape == (2, 10, 64)
    print result[0]["outputs"].shape
    print result[0]["states"].h.shape
    # the final outputs state and states returned must be equal for each      
    # sequence
    assert(result[0]["outputs"][0][-1]==result[0]["states"].h[0]).all()
    assert(result[0]["outputs"][-1][5]==result[0]["states"].h[-1]).all()
    assert(result[0]["outputs"][-1][-1]==result[0]["states"].h[-1]).all()

A afirmação final falhará, pois o estado final da 2ª sequência está no 6º passo, ie. o índice 5 e o restante das saídas de [6: 9] são todos os 0s no segundo timestep

Bhaskar Arun
fonte