Python 如何在tensorflow 2.0+中将数据集[:,28,28]转换为[:,28,28,3]?
如何在 TensorFlow 2.0+ 中将形状为 [:,28,28] 的数据集转换为 [:,28,28,3]?(标签:python, tensorflow)
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
# Load MNIST train/test splits as (image, label) pairs.
train_data, test_data = tfds.load(
    "mnist",
    split=[tfds.Split.TRAIN, tfds.Split.TEST],
    as_supervised=True,
)
def my_transform(data, label):
    """Placeholder map function: hand back ``(data, label)`` untouched.

    This is where the 3-channel conversion should happen. A working numpy
    sketch of the intended conversion (note it operates on ``data``, not on the
    whole dataset, and uses the last axis) would be:

        data = np.expand_dims(data / 255.0, axis=-1)    # [28, 28] -> [28, 28, 1]
        data = np.insert(data, [1, 1], 1, axis=-1)      # [28, 28, 1] -> [28, 28, 3]

    NOTE(review): the images may already carry a trailing channel axis when
    loaded through tfds -- confirm the incoming shape before adding expand_dims.
    """
    pair = (data, label)
    return pair
# Apply the per-example transform to every (image, label) pair.
train_data = train_data.map(map_func=my_transform)
您可以把这些被注释掉的 numpy 操作写进 `my_transform`,用 `tf.py_function(my_transform)` 包装它,再把结果放进 `.map` 中。`tf.py_function` 会在急切(eager)模式下以普通 Python 代码执行这些操作,因此速度会比纯 TensorFlow 操作慢。更多说明和可运行的示例见原答案中的链接(链接在抓取时已丢失)。

我用以下方法解决了这个问题:
import tensorflow_datasets as tfds
import tensorflow as tf
# Load MNIST train/test splits as (image, label) pairs.
(train_data, test_data) = tfds.load('mnist', split=['train', 'test'], as_supervised=True)

# Convert each single-channel image to 3 channels and batch by 10.
# An explicit cast to float32 is used instead of tf.image.resize(data, [28, 28]):
# resizing 28x28 images to 28x28 only acted as a uint8 -> float32 conversion.
# Pixel values stay in [0, 255]; divide by 255.0 here if [0, 1] is wanted.
train_data = train_data.map(
    lambda data, label: (tf.image.grayscale_to_rgb(tf.cast(data, tf.float32)), label)
).batch(10)
其他解决办法
import tensorflow_datasets as tfds
import tensorflow as tf
import matplotlib.pyplot as plt
# Load MNIST train/test splits as (image, label) pairs.
train_data, test_data = tfds.load(
    'mnist',
    split=['train', 'test'],
    as_supervised=True,
)

# Alternative ways to turn each single-channel image into 3 channels;
# enable exactly one of them.
#
# Method 1: replicate the gray channel into all three channels.
# train_data = train_data.map(lambda data, label: (tf.image.grayscale_to_rgb(data/255), label)).batch(10)
#
# Method 2: keep the pixels in the first channel and fill the other two with zeros.
# train_data = train_data.map(lambda data, label: (tf.concat([data/255,tf.zeros(data.shape),tf.zeros(data.shape)],-1), label)).batch(10)
#
# Method 3: the my_transform function defined next (same idea as method 2).
def my_transform(data, label):
    """Scale an image to [0, 1] and pad it with two zero channels.

    Args:
        data: image tensor with a trailing channel axis -- presumably the
            uint8 (28, 28, 1) MNIST image yielded by tfds with
            ``as_supervised=True``; TODO confirm.
        label: class label, passed through unchanged.

    Returns:
        ``(image, label)``. ``image`` has the input's spatial shape and three
        times its channel count (3 for a single-channel input). The first block
        holds the scaled pixels and the rest are zeros, so matplotlib shows the
        digit in red. Use ``tf.image.grayscale_to_rgb`` instead if the gray value
        should be copied into every channel.
    """
    data = data / 255
    # zeros_like (rather than tf.zeros(data.shape)) always matches data's dtype
    # and does not require a fully known static shape, so the concat below cannot
    # fail on a dtype mismatch (e.g. int32 input -> float64 vs float32 zeros).
    ex_col = tf.zeros_like(data)
    data = tf.concat([data, ex_col, ex_col], -1)
    return data, label
train_data = train_data.map(my_transform).batch(10)

# Preview only the first batch. Iterating the whole dataset would open one
# blocking plt.show() window for every image in the split.
for images, labels in train_data.take(1):
    for image, label in zip(images, labels):
        print(image, label)
        plt.imshow(image)
        plt.show()