Warning: file_get_contents(/data/phpspider/zhask/data//catemap/2/tensorflow/5.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
Python 为什么我的图像加载到数据集中时都是白色的?_Python_Tensorflow_Tf.keras - Fatal编程技术网

Python 为什么我的图像加载到数据集中时都是白色的?

Python 为什么我的图像加载到数据集中时都是白色的?,python,tensorflow,tf.keras,Python,Tensorflow,Tf.keras,我使用图像创建了一个数据集: import tensorflow as tf import matplotlib.pyplot as plt import numpy as np dataset = tf.keras.preprocessing.image_dataset_from_directory( <directory>, label_mode=None, seed=1, subset='training', validation_sp

我使用图像创建了一个数据集:

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np

dataset = tf.keras.preprocessing.image_dataset_from_directory(
    <directory>,
    label_mode=None,
    seed=1,
    subset='training',
    validation_split=0.1,
    image_size=(900, 900))

images = next(iter(dataset))
print(tf.shape(images))
作为输出,我得到: 使用RGB数据将输入数据剪裁到imshow的有效范围([0..1]表示浮点数,[0..255]表示整数)

以及作为输出的纯白色图像


我确信加载到数据集中的图像不是纯白色的。有人能帮我吗?

您的
数据集是一个
tf.data.dataset
,因此您可以使用此可视化功能。

以这种方式实施:

import tensorflow_datasets as tfds

ds, info = tfds.load(<your dataset name>, split='train', with_info=True)
fig = tfds.show_examples(ds, info)
将tensorflow_数据集导入为TFD
ds,info=tfds.load(,split='train',with_info=True)
图=tfds。显示示例(ds、info)

这不是对您问题的直接回答,但franky我更喜欢使用ImageDataGenerator.flow\u from\u目录,它允许您重新缩放图像,增强图像,并且生成器的输出易于使用。文档适用于您的应用程序,代码为:

data_gen=tf.keras.preprocessing.image.ImageDataGenerator(rescale=1/255,
    validation_split=.01)
train_gen=data_gen.flow_from_directory(
    directory= r'c:\your_directory',
    target_size=(900,900),
    class_mode="categorical",
    batch_size=32,
    shuffle=True,
    seed=123,    
    subset='training')
valid_gen=data_gen.flow_from_directory(
    directory= r'c:\your_directory',
    target_size=(900,900),
    class_mode="categorical",
    batch_size=32,
    shuffle=False,
    seed=123,    
    subset='validation')
tr_images, tr_labels=next (train_gen) # generate a batch of 32 training images, and labels
val_images, val_labels=next(valid_gen)
print('tr_images.shape = ', tr_images.shape)
# result will be tr_images.shape =  (32, 900, 900, 3)
image1=tr_images[0]
plt.imshow(image1)
plt.show()
# result will be showing the image
# other useful things available
class_dict=train_gen.class_indices # a dictionary where key is the text class name and value is integer label of the class
print (class_dict)
labels= train_gen.labels # a sequential list of all the generator labels
file_names= test_gen.filenames  # a sequential list of all the generator file names
# hope this helps

将它们转移到numpy,并转换到uint8。 这是我用来检查输入和输出的函数,我使用32的批量大小,但只打印其中的8个

def generate_images(test_input):
  prediction = decoder(encoder(test_input), training=True)
  plt.figure(figsize=(15, 15))
  for i in range(8):
    plt.subplot(4, 8, i*2+1)
    plt.imshow(test_input[i].numpy().astype("uint8") )
    plt.subplot(4, 8, i*2+2)
    plt.imshow(prediction[i] )
  plt.axis('off')
  plt.show() 
要从数据集调用此

for imgs, labels in train_ds.take(1):
  generate_images( imgs)

我试过使用上面提到的方法,但没有一种方法能达到我想要的效果。您可以将
插值更改为
“最近的”
默认值为
“双线性”

试试这个:

ds=tf.keras.preprocessing.image\u数据集\u来自\u目录(
目录=目录,#图像路径
....
color_mode=“rgb”,
插值=“最近”,
follow_links=False,
)
然后,要显示图像,请执行以下操作:

对于图像,在ds中添加标签:
通过
plt.imshow(图像[0])

谢谢。我试过了,但我的数据集“找不到”。另外,我真的只是想知道为什么我的图像显示为纯白色,以及如何修复它。这可能是因为matplotlib无法将tf.tensors作为图像加载。这就是为什么您应该使用tensorflow可视化功能。查看该链接上提供的教程以了解更多信息谢谢,问题是我想从图像中提取补丁,因此使用map函数,如()所示。有没有办法用发电机做到这一点?
def generate_images(test_input):
  prediction = decoder(encoder(test_input), training=True)
  plt.figure(figsize=(15, 15))
  for i in range(8):
    plt.subplot(4, 8, i*2+1)
    plt.imshow(test_input[i].numpy().astype("uint8") )
    plt.subplot(4, 8, i*2+2)
    plt.imshow(prediction[i] )
  plt.axis('off')
  plt.show() 
for imgs, labels in train_ds.take(1):
  generate_images( imgs)