Python 从图像文件名和相应的字符串标签生成TF数据集

Python 从图像文件名和相应的字符串标签生成TF数据集,python,tensorflow,keras,Python,Tensorflow,Keras,直到几个月前,我用来训练神经网络(图像分类器)的数据都以“keras格式”存储在一个桶中,每个图像都位于一个文件夹中,该文件夹对应于图像类名: top_dir/ class1/ image1.png image2.png class2/ image3.png image4.png 为了构建数据集,我做了以下工作: list_ds = tf.data.Dataset.list_files("top_dir

直到几个月前,我用来训练神经网络(图像分类器)的数据都以“keras格式”存储在一个桶中,每个图像都位于一个文件夹中,该文件夹对应于图像类名:

top_dir/
    class1/
        image1.png
        image2.png

    class2/
        image3.png
        image4.png
为了构建数据集,我做了以下工作:

list_ds = tf.data.Dataset.list_files("top_dir/")

def decode_jpeg_and_label(filename: str):
    
    image_bytes = tf.io.read_file(filename)
    image = tf.image.decode_png(image_bytes, channels=3)
    image = tf.image.encode_jpeg(image, format='rgb', quality=100)
    image = tf.image.decode_jpeg(image, channels=3)
    label = tf.strings.split(tf.expand_dims(filename, axis=-1), sep='/')
    label = label.values[-2]
    return image, label

dataset = list_ds.map(decode_jpeg_and_label)
但是,现在图像存储在一个平面文件夹中,我得到一个API响应,它允许我构建标签数据。其格式为:

[["top_dir/image1.png", "class1"],
["top_dir/image2.png", "class1"],
["top_dir/image3.png", "class2"],
["top_dir/image4.png", "class2"]]

如何将上面的输入转化为与上面的数据集等效的数据集?

如果您的API响应是
x
,这将起作用:

def load(file_path, label):
    img = tf.io.read_file(file_path)
    img = tf.image.decode_png(img, channels=3)
    img = tf.image.convert_image_dtype(img, tf.float32)
    img = tf.image.resize(img, size=(100, 100)) # optional
    label = tf.cast(tf.equal(label, 'class2'), tf.int32)
    return img, label

ds = tf.data.Dataset.from_tensor_slices(x).map(lambda x: load(x[0], x[1]))

next(iter(ds))
(,
)

谢谢Nicolas,这非常有用,我可以用它来做我需要的转换
(<tf.Tensor: shape=(100, 100, 3), dtype=float32, numpy=
 array([[[0.40976474, 0.47250983, 0.56270593],
         [0.4039216 , 0.4666667 , 0.5568628 ],
         [0.41176474, 0.48235297, 0.57254905],
         ...,
         [0.5620584 , 0.5812747 , 0.6775489 ],
         [0.53252923, 0.5579019 , 0.6559411 ],
         [0.5176471 , 0.5568628 , 0.6509804 ]]], dtype=float32)>,
 <tf.Tensor: shape=(), dtype=int32, numpy=0>)