Image 如何将.jpg读入tensorflow数据集中并使用会话显示图像

Image 如何将.jpg读入tensorflow数据集中并使用会话显示图像,image,tensorflow,dataset,Image,Tensorflow,Dataset,我正在尝试通过tf.data.dataset从文件加载图像,然后使用matplotlib显示图像。一旦我了解了加载单个图像的工作原理,我还有几个文件要扩展。我不明白这里出了什么问题。是什么导致了下面的错误。如何更正此代码以便显示图像 我使用的是tensorflow 1.14 import tensorflow as tf import matplotlib.pyplot as plt filename = tf.constant(['D:/Datasets/The Oxford-IIIT Pe

我正在尝试通过tf.data.dataset从文件加载图像,然后使用matplotlib显示图像。一旦我了解了加载单个图像的工作原理,我还有几个文件要扩展。我不明白这里出了什么问题。是什么导致了下面的错误。如何更正此代码以便显示图像

我使用的是tensorflow 1.14

import tensorflow as tf
import matplotlib.pyplot as plt

filename = tf.constant(['D:/Datasets/The Oxford-IIIT Pet Dataset (Segmentation)/images/Abyssinian_1.jpg'])

dataset = tf.data.Dataset.from_tensor_slices(filename)

def format_image(image_dir):
    image = tf.read_file(image_dir)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = tf.image.resize_image_with_pad(image, 256, 256, align_corners=True)
    return image

dataset = dataset.map(format_image)
dataset = dataset.batch(1)

iterator = dataset.make_initializable_iterator()
image = iterator.get_next()

with tf.Session() as sess:
    sess.run([iterator.initializer])
    decoded_image = sess.run([image])
    plt.imshow(decoded_image)
    plt.show()
我得到一个错误:

Traceback (most recent call last):
  File "C:/Users/g/Deeplab_custom/readinganimage.py", line 24, in <module>
    plt.imshow(decoded_image)
  File "C:\Users\s\AppData\Local\Programs\Python\Python36\lib\site-packages\matplotlib\pyplot.py", line 2699, in imshow
    None else {}), **kwargs)
  File "C:\Users\s\AppData\Local\Programs\Python\Python36\lib\site-packages\matplotlib\__init__.py", line 1810, in inner
    return func(ax, *args, **kwargs)
  File "C:\Users\s\AppData\Local\Programs\Python\Python36\lib\site-packages\matplotlib\axes\_axes.py", line 5494, in imshow
    im.set_data(X)
  File "C:\Users\s\AppData\Local\Programs\Python\Python36\lib\site-packages\matplotlib\image.py", line 638, in set_data
    raise TypeError("Invalid dimensions for image data")
TypeError: Invalid dimensions for image data

按照您当前拥有代码的方式,它返回[1,1256256,3]输出。这些尺寸[通过在sess.run、批次尺寸、高度、宽度、通道中使用方括号引入]。matplotlib不理解这一点。matplotlib需要[height,width,channels]数组

所以在你的情况下,你能做的是,如下

with tf.Session() as sess:
    sess.run([iterator.initializer])
    decoded_image = sess.run([image])
    decoded_image = sess.run(image)
    plt.imshow(decoded_image[0][0])
with tf.Session() as sess:
    sess.run([iterator.initializer])
    decoded_image = sess.run(image)
    plt.imshow(decoded_image[0])
    plt.show()
但是使用sess.run[image]而不是使用sess.runimage会引入不必要的维度。以及以下内容

with tf.Session() as sess:
    sess.run([iterator.initializer])
    decoded_image = sess.run([image])
    decoded_image = sess.run(image)
    plt.imshow(decoded_image[0][0])
with tf.Session() as sess:
    sess.run([iterator.initializer])
    decoded_image = sess.run(image)
    plt.imshow(decoded_image[0])
    plt.show()

谢谢你!我删除了批处理操作和两个方括号,以获得matplotlib将采用的256、256、3的形状。