TensorFlow, por que há três arquivos após salvar o modelo?

113

Depois de ler os documentos , salvei um modelo no TensorFlow, aqui está meu código de demonstração:

# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, initialize the variables, do some work, save the
# variables to disk.
with tf.Session() as sess:
  sess.run(init_op)
  # Do some work with the model.
  ..
  # Save the variables to disk.
  save_path = saver.save(sess, "/tmp/model.ckpt")
  print("Model saved in file: %s" % save_path)

mas depois disso, descobri que há 3 arquivos

model.ckpt.data-00000-of-00001
model.ckpt.index
model.ckpt.meta

E não posso restaurar o modelo restaurando o model.ckptarquivo, uma vez que tal arquivo não existe. Aqui está meu código

with tf.Session() as sess:
  # Restore variables from disk.
  saver.restore(sess, "/tmp/model.ckpt")

Então, por que existem 3 arquivos?

Indo à minha maneira
fonte
2
Você descobriu como resolver isso? Como posso carregar o modelo novamente (usando Keras)?
rajkiran

Respostas:

116

Experimente isto:

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('/tmp/model.ckpt.meta')
    saver.restore(sess, "/tmp/model.ckpt")

O método de salvamento do TensorFlow salva três tipos de arquivos porque armazena a estrutura do gráfico separadamente dos valores das variáveis . O .metaarquivo descreve a estrutura do gráfico salva, então você precisa importá-lo antes de restaurar o ponto de verificação (caso contrário, ele não sabe a quais variáveis ​​os valores do ponto de verificação salvos correspondem).

Como alternativa, você pode fazer o seguinte:

# Recreate the EXACT SAME variables
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")

...

# Now load the checkpoint variable values
with tf.Session() as sess:
    saver = tf.train.Saver()
    saver.restore(sess, "/tmp/model.ckpt")

Mesmo que não haja nenhum arquivo nomeado model.ckpt, você ainda se refere ao ponto de verificação salvo com esse nome ao restaurá-lo. Do saver.pycódigo-fonte :

Os usuários só precisam interagir com o prefixo especificado pelo usuário ... em vez de qualquer nome de caminho físico.

TK Bartel
fonte
1
então o .index e o .data não são usados? Quando esses 2 arquivos são usados, então?
ajfbiw.s
26
@ ajfbiw.s .meta armazena a estrutura do gráfico, .data armazena os valores de cada variável no gráfico, .index identifica o ponto de verificação. Portanto, no exemplo acima: import_meta_graph usa .meta, e saver.restore usa .data e .index
TK Bartel
Ah eu vejo. Obrigado.
ajfbiw.s
1
Alguma chance de você salvar o modelo com uma versão diferente do TensorFlow da que está usando para carregá-lo? ( github.com/tensorflow/tensorflow/issues/5639 )
TK Bartel
5
Alguém sabe o que isso 00000e os 00001números significam? em variables.data-?????-of-?????arquivo
Ivan Talalaev
55
  • meta arquivo : descreve a estrutura gráfico guardado, inclui GraphDef, SaverDef, e assim por diante; em seguida tf.train.import_meta_graph('/tmp/model.ckpt.meta'), aplique , irá restaurar Savere Graph.

  • arquivo de índice : é uma tabela imutável string-string (tensorflow :: table :: Table). Cada chave é o nome de um tensor e seu valor é um BundleEntryProto serializado. Cada BundleEntryProto descreve os metadados de um tensor: qual dos arquivos de "dados" contém o conteúdo de um tensor, o deslocamento para esse arquivo, soma de verificação, alguns dados auxiliares, etc.

  • arquivo de dados : é a coleção TensorBundle, salve os valores de todas as variáveis.

Guangcong Liu
fonte
Eu tenho o arquivo pb que tenho para classificação de imagens. Posso usá-lo para classificação de vídeo em tempo real?
Você pode me informar, usando o Keras 2, como carrego o modelo se ele está salvo como 3 arquivos?
rajkiran
5

Estou restaurando embeddings de palavras treinados de Word2Vec treinadas do tutorial Word2Vec.

Caso você tenha criado vários pontos de verificação:

por exemplo, os arquivos criados se parecem com este

model.ckpt-55695.data-00000-of-00001

model.ckpt-55695.index

model.ckpt-55695.meta

tente isso

def restore_session(self, session):
   saver = tf.train.import_meta_graph('./tmp/model.ckpt-55695.meta')
   saver.restore(session, './tmp/model.ckpt-55695')

ao chamar restore_session ():

def test_word2vec():
   opts = Options()    
   with tf.Graph().as_default(), tf.Session() as session:
       with tf.device("/cpu:0"):            
           model = Word2Vec(opts, session)
           model.restore_session(session)
           model.get_embedding("assistance")
Steven Wong
fonte
O que significa "00000-of-00001" em "model.ckpt-55695.data-00000-of-00001"?
hafiz031
0

Se você treinou um CNN com abandono, por exemplo, você poderia fazer isto:

def predict(image, model_name):
    """
    image -> single image, (width, height, channels)
    model_name -> model file that was saved without any extensions
    """
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph('./' + model_name + '.meta')
        saver.restore(sess, './' + model_name)
        # Substitute 'logits' with your model
        prediction = tf.argmax(logits, 1)
        # 'x' is what you defined it to be. In my case it is a batch of RGB images, that's why I add the extra dimension
        return prediction.eval(feed_dict={x: image[np.newaxis,:,:,:], keep_prob_dnn: 1.0})
happy_sisyphus
fonte