Pelo que recolhi até agora, existem várias maneiras diferentes de despejar um gráfico do TensorFlow em um arquivo e, em seguida, carregá-lo em outro programa, mas não consegui encontrar exemplos / informações claras sobre como eles funcionam. O que eu já sei é o seguinte:
- Salve as variáveis do modelo em um arquivo de checkpoint (.ckpt) usando um
tf.train.Saver()
e restaure-as mais tarde ( fonte ) - Salve um modelo em um arquivo .pb e carregue-o de volta usando
tf.train.write_graph()
etf.import_graph_def()
( fonte ) - Carregue um modelo de um arquivo .pb, treine-o novamente e despeje-o em um novo arquivo .pb usando o Bazel ( fonte )
- Congele o gráfico para salvar o gráfico e os pesos juntos ( fonte )
- Use
as_graph_def()
para salvar o modelo, e para pesos / variáveis, mapeie-os em constantes ( fonte )
No entanto, não consegui esclarecer várias dúvidas sobre esses métodos diferentes:
- Com relação aos arquivos de checkpoint, eles salvam apenas os pesos treinados de um modelo? Os arquivos de checkpoint podem ser carregados em um novo programa e usados para executar o modelo, ou eles simplesmente servem como formas de salvar os pesos em um modelo em um determinado momento / estágio?
- Em relação
tf.train.write_graph()
, os pesos / variáveis também são salvos? - Em relação ao Bazel, ele só pode salvar / carregar arquivos .pb para treinamento? Existe um comando simples do Bazel apenas para despejar um gráfico em um .pb?
- Com relação ao congelamento, um gráfico congelado pode ser carregado usando
tf.import_graph_def()
? - A demonstração do Android para TensorFlow é carregada no modelo Inception do Google a partir de um arquivo .pb. Se eu quisesse substituir meu próprio arquivo .pb, como faria para fazer isso? Eu precisaria alterar algum código / método nativo?
- Em geral, qual é exatamente a diferença entre todos esses métodos? Ou mais amplamente, qual é a diferença entre
as_graph_def()
/.ckpt/.pb?
Em suma, o que estou procurando é um método para salvar um gráfico (como em, as várias operações e outros) e seus pesos / variáveis em um arquivo, que pode então ser usado para carregar o gráfico e os pesos em outro programa , para uso (não necessariamente continuando / retreinando).
A documentação sobre este tópico não é muito direta, portanto, quaisquer respostas / informações serão muito apreciadas.
fonte
Respostas:
Existem muitas maneiras de abordar o problema de salvar um modelo no TensorFlow, o que pode torná-lo um pouco confuso. Tomando cada uma de suas subquestões sucessivamente:
Os arquivos de ponto de verificação (produzidos por exemplo, chamando
saver.save()
umtf.train.Saver
objeto) contêm apenas os pesos e quaisquer outras variáveis definidas no mesmo programa. Para usá-los em outro programa, você deve recriar a estrutura de gráfico associada (por exemplo, executando o código para criá-lo novamente ou chamandotf.import_graph_def()
), que informa ao TensorFlow o que fazer com esses pesos. Observe que a chamadasaver.save()
também produz um arquivo contendo umMetaGraphDef
, que contém um gráfico e detalhes de como associar os pesos de um ponto de verificação a esse gráfico. Veja o tutorial para mais detalhes.tf.train.write_graph()
escreve apenas a estrutura do gráfico; não os pesos.O Bazel não está relacionado à leitura ou gravação de gráficos do TensorFlow. (Talvez eu não tenha entendido sua pergunta: sinta-se à vontade para esclarecê-la em um comentário.)
Um gráfico congelado pode ser carregado usando
tf.import_graph_def()
. Nesse caso, os pesos são (normalmente) embutidos no gráfico, então você não precisa carregar um checkpoint separado.A principal mudança seria atualizar os nomes dos tensores que são alimentados no modelo e os nomes dos tensores que são buscados no modelo. Na demonstração do TensorFlow Android, isso corresponderia às strings
inputName
eoutputName
que são passadas paraTensorFlowClassifier.initializeTensorFlow()
.A
GraphDef
é a estrutura do programa, que normalmente não se altera durante o processo de formação. O ponto de verificação é um instantâneo do estado de um processo de treinamento, que normalmente muda a cada etapa do processo de treinamento. Como resultado, o TensorFlow usa diferentes formatos de armazenamento para esses tipos de dados, e a API de baixo nível oferece diferentes maneiras de salvá-los e carregá-los. Bibliotecas de nível superior, como asMetaGraphDef
bibliotecas, Keras e skflow, se baseiam nesses mecanismos para fornecer maneiras mais convenientes de salvar e restaurar um modelo inteiro.fonte
tf.train.write_graph()
e executá-lo?GraphDef
salvo portf.train.write_graph()
, você também precisa lembrar os nomes dos tensores que deseja alimentar e buscar ao executar o gráfico (item 5 acima).Você pode tentar o seguinte código:
fonte