关于TensorFlow mnist数据重塑问题

关于TensorFlow mnist数据重塑问题,tensorflow,Tensorflow,现在我正在学习TensorFlow,我想知道为什么需要numpy.swapax(0,3) 我知道结果是(1,14,14,5)意味着[15元素[145元素[145元素[5元素]]] 在颠簸之后。交换(3,0)->(5,14,14,1)和5张图片 下面是我的代码,请保存我的问题。多谢各位 #load mnist data from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_se

现在我正在学习TensorFlow,我想知道为什么需要numpy.swapax(0,3)

我知道结果是(1,14,14,5)意味着[15元素[145元素[145元素[5元素]]]

在颠簸之后。交换(3,0)->(5,14,14,1)和5张图片

下面是我的代码,请保存我的问题。多谢各位

#load mnist data
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

#get only 1 image & reshape it
img = mnist.train.images[0].reshape(28,28)
plt.imshow(img, cmap='gray')

sess = tf.InteractiveSession()

#reshape image to get color = 1
img = img.reshape(-1,28,28,1)

#filter 3X3, count = 5
W1 = tf.Variable(tf.random_normal([3, 3, 1, 5], stddev=0.01))

                                                 #zero-padded USE
conv2d = tf.nn.conv2d(img, W1, strides=[1, 2, 2, 1], padding='SAME')
print(conv2d)

sess.run(tf.global_variables_initializer())

#make convoultion data
conv2d_img = conv2d.eval()

#print converted images
conv2d_img = np.swapaxes(conv2d_img, 0, 3)
for i, one_img in enumerate(conv2d_img):
plt.subplot(1,5,i+1), plt.imshow(one_img.reshape(14,14), cmap='gray')

#pooling
pool = tf.nn.max_pool(conv2d, ksize=[1, 2, 2, 1], strides=[
                    1, 2, 2, 1], padding='SAME')
print(pool)
sess.run(tf.global_variables_initializer())
pool_img = pool.eval()

#print pooling image
pool_img = np.swapaxes(pool_img, 0, 3)
for i, one_img in enumerate(pool_img):
plt.subplot(1,5,i+1), plt.imshow(one_img.reshape(7, 7), cmap='gray')

交换是必要的,因为它会更改图像通道的顺序

默认情况下,TensorFlow使用NHWC,其中C=1,因为我们有一个灰度图像

因此,您需要在数据的最后一个轴上设置通道数(灰度图像为1,RGB为3)

在代码中,您可以看到NHWC关系保持不变(5表示图像数量==批处理大小,14表示高度,14表示宽度,1表示图像通道)