How to convert a dataset of shape [:, 28, 28] to [:, 28, 28, 3] in TensorFlow 2.0+?

How can I convert a dataset of shape [:, 28, 28] into [:, 28, 28, 3] in TensorFlow 2.0+?

import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np

(train_data, test_data) = tfds.load("mnist", split=[tfds.Split.TRAIN, tfds.Split.TEST], as_supervised=True)

def my_transform(data, label):
    # data [28, 28]
    # below is the NumPy way; how do I do the same inside the tf.data pipeline?
    # data = np.expand_dims(data / 255.0, axis=-1)      # [28, 28, 1]
    # data = np.insert(train_data, [1,1], 1, axis=3)    # [28, 28, 3]
    return data, label

train_data = train_data.map(my_transform)

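You can wrap those commented-out NumPy operations in tf.py_function (with my_transform as the wrapped function), and then pass the whole thing to .map. The wrapped function runs eagerly as regular Python, so it is typically slower than native TensorFlow ops. You can read more here. For a working example, you can follow this one.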
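A minimal sketch of the tf.py_function approach described above, under stated assumptions: a small synthetic uint8 array of shape [8, 28, 28, 1] stands in for tfds.load('mnist') so the snippet runs offline, and the helper names numpy_to_rgb and to_rgb_py are hypothetical. The NumPy step uses np.repeat to copy the gray channel three times, which is one way to reach [28, 28, 3]; it is not the asker's np.insert call.

```python
import numpy as np
import tensorflow as tf

# Assumption: deterministic synthetic stand-in for TFDS MNIST
# (uint8 images of shape [28, 28, 1] plus int64 labels).
images = (np.arange(8 * 28 * 28) % 256).astype(np.uint8).reshape(8, 28, 28, 1)
labels = np.arange(8, dtype=np.int64)
ds = tf.data.Dataset.from_tensor_slices((images, labels))

def numpy_to_rgb(image):
    # Hypothetical helper. Inside tf.py_function this runs eagerly,
    # so `image` is an EagerTensor and .numpy() returns an ndarray.
    arr = image.numpy().astype(np.float32) / 255.0  # [28, 28, 1], scaled to [0, 1]
    return np.repeat(arr, 3, axis=-1)                # [28, 28, 3], gray copied to each channel

def to_rgb_py(image, label):
    # Hypothetical map function: hand the NumPy code to tf.py_function.
    rgb = tf.py_function(numpy_to_rgb, inp=[image], Tout=tf.float32)
    # py_function drops static shape information; restore it for later layers.
    rgb.set_shape([28, 28, 3])
    return rgb, label

rgb_ds = ds.map(to_rgb_py).batch(4)
batch_images, batch_labels = next(iter(rgb_ds))
print(batch_images.shape)
```

The set_shape call matters in practice: without it, Keras layers downstream of the dataset see an unknown image shape.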

I solved the problem with the following approach:

import tensorflow_datasets as tfds
import tensorflow as tf

(train_data, test_data) = tfds.load('mnist', split=['train', 'test'], as_supervised=True)

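# tf.image.resize to the same 28x28 size effectively just casts the uint8 image to float32 (values stay in 0..255);
# grayscale_to_rgb then copies the single channel into three: [28, 28, 1] -> [28, 28, 3]
train_data = train_data.map(lambda data, label: (tf.image.grayscale_to_rgb(tf.image.resize(data, [28, 28])), label)).batch(10)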
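Because TFDS MNIST images already carry a channel axis of size 1, the resize step above is not needed to reach [28, 28, 3]. A minimal sketch that uses only native ops, assuming the same synthetic stand-in for MNIST as before and a hypothetical map function named to_rgb: tf.image.convert_image_dtype scales uint8 [0, 255] to float32 [0, 1], and tf.image.grayscale_to_rgb replicates the channel.

```python
import numpy as np
import tensorflow as tf

# Assumption: synthetic stand-in for TFDS MNIST, uint8 [28, 28, 1] images.
images = (np.arange(8 * 28 * 28) % 256).astype(np.uint8).reshape(8, 28, 28, 1)
labels = np.arange(8, dtype=np.int64)
ds = tf.data.Dataset.from_tensor_slices((images, labels))

def to_rgb(image, label):
    # uint8 [0, 255] -> float32 [0, 1]
    image = tf.image.convert_image_dtype(image, tf.float32)
    # [28, 28, 1] -> [28, 28, 3]; requires the last dimension to be 1
    image = tf.image.grayscale_to_rgb(image)
    return image, label

rgb_ds = ds.map(to_rgb).batch(4)
batch_images, batch_labels = next(iter(rgb_ds))
```

Staying with native ops keeps the whole map function in the TensorFlow graph, which avoids the Python overhead of tf.py_function.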

Other solutions:

import tensorflow_datasets as tfds
import tensorflow as tf
import matplotlib.pyplot as plt

(train_data, test_data) = tfds.load('mnist', split=['train', 'test'], as_supervised=True)

# method 1
# train_data = train_data.map(lambda data, label: (tf.image.grayscale_to_rgb(data/255), label)).batch(10)

# method 2: append two all-zero channels (only the first channel carries the image)
# train_data = train_data.map(lambda data, label: (tf.concat([data/255,tf.zeros(data.shape),tf.zeros(data.shape)],-1), label)).batch(10)

# method 3: same zero-padding as method 2, written as a named function
def my_transform(data, label):
    #data = tf.expand_dims(data/255, -1)
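    data = data / 255                             # uint8 [28, 28, 1] -> float32 in [0, 1]
    ex_col = tf.zeros_like(data)                  # all-zero channel with the same shape and dtype
    data = tf.concat([data, ex_col, ex_col], -1)  # [28, 28, 3]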
    return data, label

train_data = train_data.map(my_transform).batch(10)

for datas, labels in train_data:
    for i in range(len(datas)):
        print(datas[i], labels[i])
        plt.imshow(datas[i])
        plt.show()
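Methods 2 and 3 fill the second and third channels with zeros, so the result displays as a red-tinted image rather than a gray one. If three identical channels are wanted (what grayscale_to_rgb in method 1 produces), concatenating the gray channel with itself is the native-op equivalent. A minimal sketch comparing the two, assuming the same synthetic stand-in for MNIST; replicate_channels and zero_pad_channels are hypothetical names.

```python
import numpy as np
import tensorflow as tf

# Assumption: synthetic stand-in for TFDS MNIST, uint8 [28, 28, 1] images.
images = (np.arange(8 * 28 * 28) % 256).astype(np.uint8).reshape(8, 28, 28, 1)
labels = np.arange(8, dtype=np.int64)
ds = tf.data.Dataset.from_tensor_slices((images, labels))

def replicate_channels(image, label):
    gray = tf.cast(image, tf.float32) / 255.0      # [28, 28, 1] in [0, 1]
    rgb = tf.concat([gray, gray, gray], axis=-1)   # [28, 28, 3], identical channels
    return rgb, label

def zero_pad_channels(image, label):
    # Same idea as methods 2/3 above: only the first channel carries data.
    gray = tf.cast(image, tf.float32) / 255.0
    zeros = tf.zeros_like(gray)
    return tf.concat([gray, zeros, zeros], axis=-1), label

rep = next(iter(ds.map(replicate_channels).batch(8)))[0]
pad = next(iter(ds.map(zero_pad_channels).batch(8)))[0]
```

Which variant is right depends on the model: networks pretrained on RGB photos generally expect all three channels to carry the image, which favors replication.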