Depois de treinar um modelo no Tensorflow:
- Como você salva o modelo treinado?
- Como você mais tarde restaura esse modelo salvo?
python
tensorflow
machine-learning
model
mathetes
fonte
fonte
Respostas:
Documentos
tutorial completo e útil -> https://www.tensorflow.org/guide/saved_model
Guia detalhado de Keras para salvar modelos -> https://www.tensorflow.org/guide/keras/save_and_serialize
Dos documentos:
Salve
Restaurar
Tensorflow 2
Ainda é beta, por isso não aconselho por enquanto. Se você ainda quiser seguir esse caminho, aqui está o
tf.saved_model
guia de usoTensorflow <2
simple_save
Muitas boas respostas, para completar, adicionarei meus 2 centavos: simple_save . Também um exemplo de código independente usando a
tf.data.Dataset
API.Python 3; Tensorflow 1.14
Restaurando:
Exemplo autônomo
Postagem original do blog
O código a seguir gera dados aleatórios para a demonstração.
Dataset
e então éIterator
. Nós obtemos o tensor gerado pelo iterador, chamadoinput_tensor
que servirá como entrada para o nosso modelo.input_tensor
: um RNN bidirecional baseado em GRU, seguido por um classificador denso. Porque porque não?softmax_cross_entropy_with_logits
otimizada comAdam
. Após 2 épocas (de 2 lotes cada), salvamos o modelo "treinado"tf.saved_model.simple_save
. Se você executar o código como está, o modelo será salvo em uma pasta chamadasimple/
no seu diretório de trabalho atual.tf.saved_model.loader.load
. Pegamos os espaços reservados e logits comgraph.get_tensor_by_name
e aIterator
operação de inicialização comgraph.get_operation_by_name
.Código:
Isso imprimirá:
fonte
tf.contrib.layers
?[n.name for n in graph2.as_graph_def().node]
. Como a documentação diz, o save simples visa simplificar a interação com a veiculação do tensorflow, esse é o objetivo dos argumentos; outras variáveis ainda são restauradas, caso contrário, a inferência não aconteceria. Basta pegar suas variáveis de interesse, como fiz no exemplo. Verifique a documentaçãoglobal_step
como argumento, se você parar e tentar retomar o treinamento, ele pensará que você é um passo. Vai estragar suas visualizações tensorboard no mínimoEstou melhorando minha resposta para adicionar mais detalhes para salvar e restaurar modelos.
Na (e depois) versão 0.11 do Tensorflow :
Salve o modelo:
Restaure o modelo:
Este e alguns casos de uso mais avançados foram explicados muito bem aqui.
Um tutorial completo rápido para salvar e restaurar modelos do Tensorflow
fonte
:0
os nomes?Na versão 0.11.0RC1 (e depois) do TensorFlow, você pode salvar e restaurar seu modelo diretamente ligando
tf.train.export_meta_graph
e detf.train.import_meta_graph
acordo com https://www.tensorflow.org/programmers_guide/meta_graph .Salve o modelo
Restaurar o modelo
fonte
<built-in function TF_Run> returned a result with an error set
tf.get_variable_scope().reuse_variables()
seguido porvar = tf.get_variable("varname")
. Isso me dá o erro: "ValueError: variável varname não existe ou não foi criado com tf.get_variable ()." Por quê? Isso não deveria ser possível?Para a versão TensorFlow <0.11.0RC1:
Os pontos de verificação salvos contêm valores para os
Variable
s no seu modelo, não o modelo / gráfico em si, o que significa que o gráfico deve ser o mesmo quando você restaurar o ponto de verificação.Aqui está um exemplo de regressão linear em que há um ciclo de treinamento que salva pontos de verificação variáveis e uma seção de avaliação que restaura as variáveis salvas em uma execução anterior e prediz a computação. Obviamente, você também pode restaurar variáveis e continuar o treinamento, se desejar.
Aqui estão os documentos para
Variable
s, que abrangem salvar e restaurar. E aqui estão os documentos para oSaver
.fonte
batch_x
precisa estar? Binário? Matriz numpy?undefined
. Você pode me dizer qual é def de FLAGS para este código. @RyanSepassiMeu ambiente: Python 3.6, Tensorflow 1.3.0
Embora tenha havido muitas soluções, a maioria delas é baseada
tf.train.Saver
. Quando carregar um.ckpt
salvo porSaver
, temos de redefinir tanto a rede tensorflow ou usar algum nome estranho e hard-lembrado, por exemplo'placehold_0:0'
,'dense/Adam/Weight:0'
. Aqui eu recomendo usartf.saved_model
, um exemplo mais simples dado abaixo, você pode aprender mais sobre Servindo um Modelo TensorFlow :Salve o modelo:
Carregue o modelo:
fonte
Existem duas partes no modelo, a definição do modelo, salva
Supervisor
comograph.pbtxt
no diretório do modelo e os valores numéricos dos tensores, salvos em arquivos de ponto de verificação comomodel.ckpt-1003418
.A definição do modelo pode ser restaurada usando
tf.import_graph_def
e os pesos são restaurados usandoSaver
.No entanto,
Saver
usa uma lista especial de retenção de variáveis anexadas ao modelo Graph, e essa coleção não é inicializada usando import_graph_def; portanto, você não pode usar as duas juntas no momento (está em nosso roteiro para corrigir). Por enquanto, você precisa usar a abordagem de Ryan Sepassi - construa manualmente um gráfico com nomes de nó idênticos e useSaver
para carregar os pesos nele.(Como alternativa, você pode cortá-lo usando
import_graph_def
, criando variáveis manualmente e usandotf.add_to_collection(tf.GraphKeys.VARIABLES, variable)
para cada variável e depois usandoSaver
)fonte
Você também pode seguir esse caminho mais fácil.
Etapa 1: inicialize todas as suas variáveis
Etapa 2: salve a sessão dentro do modelo
Saver
e salve-aEtapa 3: restaurar o modelo
Etapa 4: verifique sua variável
Durante a execução em diferentes instâncias python, use
fonte
Na maioria dos casos, salvar e restaurar do disco usando a
tf.train.Saver
é a sua melhor opção:Você também pode salvar / restaurar a própria estrutura do gráfico (consulte a documentação do MetaGraph para obter detalhes). Por padrão,
Saver
salva a estrutura do gráfico em um.meta
arquivo. Você pode ligarimport_meta_graph()
para restaurá-lo. Restaura a estrutura do gráfico e retorna umSaver
que você pode usar para restaurar o estado do modelo:No entanto, há casos em que você precisa de algo muito mais rápido. Por exemplo, se você implementar uma parada antecipada, deseje salvar os pontos de verificação sempre que o modelo melhorar durante o treinamento (conforme medido no conjunto de validação); se não houver progresso por algum tempo, será necessário reverter para o melhor modelo. Se você salvar o modelo em disco toda vez que ele melhorar, ele reduzirá tremendamente o treinamento. O truque é salvar os estados das variáveis na memória e restaurá-los mais tarde:
Uma explicação rápida: quando você cria uma variável
X
, o TensorFlow cria automaticamente uma operação de atribuiçãoX/Assign
para definir o valor inicial da variável. Em vez de criar espaços reservados e operações extras de atribuição (o que deixaria o gráfico confuso), apenas usamos essas operações existentes. A primeira entrada de cada atribuição op é uma referência à variável que ela deve inicializar, e a segunda entrada (assign_op.inputs[1]
) é o valor inicial. Portanto, para definir qualquer valor que desejarmos (em vez do valor inicial), precisamos usarfeed_dict
ae substituir o valor inicial. Sim, o TensorFlow permite que você alimente um valor para qualquer operação, não apenas para espaços reservados, portanto, isso funciona bem.fonte
Como Yaroslav disse, você pode hackear a restauração de um graph_def e ponto de verificação importando o gráfico, criando manualmente variáveis e, em seguida, usando um Saver.
Eu implementei isso para meu uso pessoal, então eu gostaria de compartilhar o código aqui.
Link: https://gist.github.com/nikitakit/6ef3b72be67b86cb7868
(Obviamente, isso é um hack e não há garantia de que os modelos salvos dessa maneira permanecerão legíveis em versões futuras do TensorFlow.)
fonte
Se for um modelo salvo internamente, basta especificar um restaurador para todas as variáveis como
e use-o para restaurar variáveis em uma sessão atual:
Para o modelo externo, você precisa especificar o mapeamento dos nomes de suas variáveis para seus nomes de variáveis. Você pode visualizar os nomes das variáveis do modelo usando o comando
O script inspect_checkpoint.py pode ser encontrado na pasta './tensorflow/python/tools' da fonte do Tensorflow.
Para especificar o mapeamento, você pode usar o meu Tensorflow-Worklab , que contém um conjunto de classes e scripts para treinar e treinar novamente modelos diferentes. Inclui um exemplo de reciclagem de modelos ResNet, localizado aqui
fonte
all_variables()
agora está obsoletoAqui está minha solução simples para os dois casos básicos que diferem se você deseja carregar o gráfico do arquivo ou compilá-lo durante o tempo de execução.
Esta resposta vale para o Tensorflow 0.12+ (incluindo 1.0).
Reconstruindo o Gráfico no Código
Salvando
Carregando
Carregando também o gráfico de um arquivo
Ao usar esta técnica, verifique se todas as suas camadas / variáveis definiram explicitamente nomes exclusivos.Caso contrário, o Tensorflow tornará os nomes únicos e eles serão diferentes dos nomes armazenados no arquivo. Não é um problema na técnica anterior, porque os nomes são "mutilados" da mesma maneira no carregamento e no salvamento.
Salvando
Carregando
fonte
global_step
variável e as médias móveis da normalização do lote são variáveis não treináveis, mas ambas definitivamente valem a pena ser salvas. Além disso, você deve distinguir mais claramente a construção do gráfico da execução da sessão, por exemploSaver(...).save()
, criará novos nós sempre que você o executar. Provavelmente não é o que você quer. E há mais ...: /Você também pode conferir exemplos no TensorFlow / skflow , que oferece métodos
save
erestore
métodos que podem ajudá-lo a gerenciar facilmente seus modelos. Possui parâmetros que você também pode controlar com que frequência deseja fazer backup do seu modelo.fonte
Se você usar tf.train.MonitoredTrainingSession como a sessão padrão, não precisará adicionar código extra para salvar / restaurar as coisas. Basta passar um nome de dir de ponto de verificação para o construtor MonitoredTrainingSession, ele usará ganchos de sessão para lidar com eles.
fonte
Todas as respostas aqui são ótimas, mas quero acrescentar duas coisas.
Primeiro, para elaborar a resposta de @ user7505159, o "./" pode ser importante para adicionar ao início do nome do arquivo que você está restaurando.
Por exemplo, você pode salvar um gráfico sem "./" no nome do arquivo, assim:
Mas, para restaurar o gráfico, você pode precisar acrescentar um "./" ao nome do arquivo:
Você nem sempre precisará do "./", mas isso pode causar problemas dependendo do ambiente e da versão do TensorFlow.
Também é necessário mencionar que isso
sess.run(tf.global_variables_initializer())
pode ser importante antes de restaurar a sessão.Se você estiver recebendo um erro sobre variáveis não inicializadas ao tentar restaurar uma sessão salva, inclua
sess.run(tf.global_variables_initializer())
antes dasaver.restore(sess, save_file)
linha. Pode poupar uma dor de cabeça.fonte
Conforme descrito na edição 6255 :
ao invés de
fonte
De acordo com a nova versão do Tensorflow,
tf.train.Checkpoint
é a maneira preferível de salvar e restaurar um modelo:Aqui está um exemplo:
Mais informações e exemplo aqui.
fonte
Para o tensorflow 2.0 , é tão simples quanto
Restaurar:
fonte
tf.keras Salvamento do modelo com
TF2.0
Vejo ótimas respostas para salvar modelos usando o TF1.x. Quero fornecer mais alguns indicadores para salvar
tensorflow.keras
modelos, o que é um pouco complicado, pois há muitas maneiras de salvar um modelo.Aqui estou fornecendo um exemplo de salvar um
tensorflow.keras
modelo namodel_path
pasta no diretório atual. Isso funciona bem com o fluxo tensor mais recente (TF2.0). Atualizarei esta descrição se houver alguma alteração no futuro próximo.Salvando e carregando o modelo inteiro
Salvando e carregando apenas pesos do modelo
Se você estiver interessado em salvar apenas pesos do modelo e depois carregar pesos para restaurar o modelo,
Salvando e restaurando usando o retorno de chamada do keras checkpoint
salvando modelo com métricas personalizadas
Salvando o modelo keras com operações personalizadas
Quando temos operações personalizadas, como no caso a seguir (
tf.tile
), precisamos criar uma função e agrupar com uma camada Lambda. Caso contrário, o modelo não pode ser salvo.Acho que abordei algumas das muitas maneiras de salvar o modelo tf.keras. No entanto, existem muitas outras maneiras. Comente abaixo se o seu caso de uso não estiver coberto acima. Obrigado!
fonte
Use tf.train.Saver para salvar um modelo, remerber, você precisará especificar a var_list, se desejar reduzir o tamanho do modelo. A lista val_ pode ser tf.trainable_variables ou tf.global_variables.
fonte
Você pode salvar as variáveis na rede usando
Para restaurar a rede para reutilização mais tarde ou em outro script, use:
Pontos importantes:
sess
deve ser o mesmo entre as execuções iniciais e posteriores (estrutura coerente).saver.restore
precisa do caminho da pasta dos arquivos salvos, não de um caminho de arquivo individual.fonte
Onde você quiser salvar o modelo,
Certifique-se de que todos os seus
tf.Variable
nomes tenham nomes, pois você poderá restaurá-los posteriormente usando os nomes deles. E onde você deseja prever,Verifique se a proteção é executada dentro da sessão correspondente. Lembre-se de que, se você usar o
tf.train.latest_checkpoint('./')
, apenas o ponto de verificação mais recente será usado.fonte
Estou na versão:
Maneira simples é
Salve :
Restaurar:
fonte
Para tensorflow-2.0
é muito simples.
SALVE
RESTAURAR
fonte
Seguindo a resposta de @Vishnuvardhan Janapati, aqui está outra maneira de salvar e recarregar o modelo com camada / métrica / perda personalizada no TensorFlow 2.0.0
Dessa forma, depois de executar esses códigos e salvar seu modelo com
tf.keras.models.save_model
oumodel.save
ou comModelCheckpoint
retorno de chamada, você poderá recarregar seu modelo sem a necessidade de objetos personalizados precisos, tão simples quantofonte
Na nova versão do tensorflow 2.0, o processo de salvar / carregar um modelo é muito mais fácil. Por causa da implementação da API Keras, uma API de alto nível para o TensorFlow.
Para salvar um modelo: Verifique a documentação para referência: https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/models/save_model
Para carregar um modelo:
https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/models/load_model
fonte
Aqui está um exemplo simples usando o formato SavedModel do Tensorflow 2.0 (que é o formato recomendado, de acordo com a documentação ) para um classificador de conjunto de dados MNIST simples, usando a API funcional Keras sem muita imaginação:
O que é
serving_default
?É o nome da definição de assinatura da tag que você selecionou (nesse caso, a
serve
tag padrão foi selecionada). Além disso, aqui explica como encontrar as tags e assinaturas de um modelo usandosaved_model_cli
.Isenções de responsabilidade
Este é apenas um exemplo básico, se você deseja colocá-lo em funcionamento, mas não é de modo algum uma resposta completa - talvez eu possa atualizá-lo no futuro. Eu só queria dar um exemplo simples usando o
SavedModel
TF 2.0, porque eu não vi um, mesmo assim simples, em qualquer lugar.A resposta de @ Tom é um exemplo de SavedModel, mas não funcionará no Tensorflow 2.0, porque infelizmente existem algumas mudanças.
@ A resposta de Janishati em Vishnuvardhan diz TF 2.0, mas não é para o formato SavedModel.
fonte