Como passar recursos extraídos usando a CNN para a RNN?

6

Tenho imagens de palavras como abaixo:
Word Image

Digamos que é uma 256x64imagem. Meu objetivo é extrair o texto da imagem, 73791096754314441539que é basicamente o que um OCR faz.
Estou tentando construir um modelo que possa reconhecer palavras de imagens.
Quando estou dizendo palavra, pode ser uma das seguintes:

  1. Qualquer palavra do dicionário, palavra que não seja do dicionário
  2. az, AZ, caracteres especiais, incluindo spaces

Criei um modelo (snippet por causa das políticas da empresa) no tensorflow, conforme abaixo:

inputs = tf.placeholder(tf.float32, [common.BATCH_SIZE, common.OUTPUT_SHAPE[1], common.OUTPUT_SHAPE[0], 1])
# Here we use sparse_placeholder that will generate a
# SparseTensor required by ctc_loss op.
targets = tf.sparse_placeholder(tf.int32)

# 1d array of size [batch_size]
seq_len = tf.placeholder(tf.int32, [common.BATCH_SIZE])

model = tf.layers.conv2d(inputs, 64, (3,3),strides=(1, 1), padding='same', name='c1')
model = tf.layers.max_pooling2d(model, (3,3), strides=(2,2), padding='same', name='m1')
model = tf.layers.conv2d(model, 128,(3,3), strides=(1, 1), padding='same', name='c2')
model = tf.layers.max_pooling2d(model, (3,3),strides=(2,2), padding='same', name='m2')
model = tf.transpose(model, [3,0,1,2])
shape = model.get_shape().as_list()
model = tf.reshape(model, [shape[0],-1,shape[2]*shape[3]])

cell = tf.nn.rnn_cell.LSTMCell(common.num_hidden, state_is_tuple=True)
cell = tf.nn.rnn_cell.DropoutWrapper(cell, input_keep_prob=0.5, output_keep_prob=0.5)
stack = tf.nn.rnn_cell.MultiRNNCell([cell]*common.num_layers, state_is_tuple=True)

outputs, _ = tf.nn.dynamic_rnn(cell, model, seq_len, dtype=tf.float32,time_major=True)

Minha abordagem atual é usar a entrada de entrada de uma imagem de palavra, passá-la por um CNNrecurso de extração de imagem de alto nível, converter os recursos de imagem em dados sequenciais, como abaixo, em
[[a1,b1,c1],[a2,b2,c2],[a3,b3,c3]] -> [[a1,a2,a3],[b1,b2,b3],[c1,c2,c3]]
seguida, passá-lo através de um RNN (LSTM ou BiLSTM) e, em seguida, usar CTC(Connectionist Temporal Loss) para encontre a rede de perda e trem.
Não estou obtendo resultados conforme o esperado, queria saber se:

  1. Existe outra maneira melhor de executar esta tarefa
  2. Se estiver convertendo recursos para sequenciar corretamente
  3. Qualquer trabalho de pesquisa onde algo assim seja feito.
lordzuko
fonte

Respostas:

3

1 e 2. Você está na direção certa, precisa extrair os recursos usando uma CNN e, em vez de prever a classe, deseja remodelar a última camada de recursos e alimentá-la diretamente na RNN.

Algumas coisas para prestar atenção:

  • Com uma CNN bastante superficial, você não está aproveitando a extração de recursos de nível superior que essas arquiteturas podem oferecer. Se todas as suas imagens forem tão simples quanto o exemplo que você mostrou, são adequadas.
  • Se você está considerando uma CNN maior, junto com a RNN, há um número substancial de parâmetros a serem treinados. Para isso, você precisa de muitos dados e muitos recursos computacionais (GPUs muito fortes ou tempo).
  • Para que você obtenha o melhor dos dois, sugiro incorporar uma CNN pré-treinada em seu modelo (e apenas ajustar as últimas camadas). Esse modelo pré-treinado pode até ser treinado em imagens genéricas (por exemplo, ImageNet) e aumentará substancialmente o desempenho da CNN sem custo computacional. Você pode treinar as últimas camadas desta CNN em conjunto com a RNN.

3. Este é um bom exemplo do que você está tentando fazer. Eles basicamente tentam reconhecer texto de fotografias de rua, entre outras coisas, com a mesma metodologia que você descreve.

Metodologias semelhantes podem ser encontradas em outros domínios de pesquisa, como classificação de imagens com vários rótulos , rotulagem de sequências , reconhecimento de expressões faciais etc.

Djib2011
fonte
Um problema / gargalo com a abordagem que estou usando é que, como os modelos da CNN têm tamanho fixo de imagem de entrada, no caso de eu ter uma palavra mais longa, tenho que diminuir font_sizeou o image dimensionque afetará minha precisão. O que você acha ? Mas se eu só uso um BiLSTM como em ocropy, posso alimentar imagens de várias dimensões (provavelmente ainda em fase de experimentação). Qual a sua opinião sobre isso?
Lordzuko
2
Sim, o formato da entrada deve permanecer o mesmo, o que significa que é necessário diminuir a resolução, o que pode resultar em uma redução da precisão. Há outra abordagem que você poderia considerar, mas não tenho certeza se funcionaria: reconhecimento de objetos . Você pode tentar ter um sistema que faça propostas por região (de 1 número cada) e ter uma segunda CNN simples treinada no MNIST para tentar classificá-las. Em seguida, concatene todas as saídas em um único número.
Djib2011