Estou tentando fazer o tensorflow
equivalente a torch.transforms.Resize(TRAIN_IMAGE_SIZE)
, que redimensiona a menor dimensão da imagem TRAIN_IMAGE_SIZE
. Algo assim
def transforms(filename):
parts = tf.strings.split(filename, '/')
label = parts[-2]
image = tf.io.read_file(filename)
image = tf.image.decode_jpeg(image)
image = tf.image.convert_image_dtype(image, tf.float32)
# this doesn't work with Dataset.map() because image.shape=(None,None,3) from Dataset.map()
image = largest_sq_crop(image)
image = tf.image.resize(image, (256,256))
return image, label
list_ds = tf.data.Dataset.list_files('{}/*/*'.format(DATASET_PATH))
images_ds = list_ds.map(transforms).batch(4)
A resposta simples está aqui: Tensorflow: Corte a maior região quadrada central da imagem
Mas quando uso o método tf.data.Dataset.map(transforms)
, recebo shape=(None,None,3)
de dentro largest_sq_crop(image)
. O método funciona bem quando eu o chamo normalmente.
python
tensorflow2.0
Michael
fonte
fonte
EagerTensors
não estarem disponíveis,Dataset.map()
portanto a forma é desconhecida. existe uma solução alternativa?largest_sq_crop
?Respostas:
Eu encontrei a resposta. Tinha a ver com o fato de que meu método de redimensionamento funcionava bem com uma execução ágil, por exemplo,
tf.executing_eagerly()==True
mas falhava quando usado dentrodataset.map()
. Aparentemente, naquele ambiente de execuçãotf.executing_eagerly()==False
,.Meu erro foi na maneira como eu estava descompactando o formato da imagem para obter dimensões para o dimensionamento. A execução do gráfico de fluxo tensor parece não suportar o acesso à
tensor.shape
tupla.Eu estava usando dimensões de forma a jusante na minha
dataset.map()
função e lançou a seguinte exceção porque estava obtendo emNone
vez de um valor.Quando mudei para descompactar manualmente a forma
tf.shape()
, tudo funcionou bem.fonte