Python 3.x Keras上采样2D不一致行为
这是我的模型:Python 3.x Keras上采样2D不一致行为,python-3.x,tensorflow,machine-learning,neural-network,keras,Python 3.x,Tensorflow,Machine Learning,Neural Network,Keras,这是我的模型: filters = 256 kernel_size = 3 strides = 1 factor = 4 # the factor of upscaling inputLayer = Input(shape=(img_height//factor, img_width//factor, img_depth)) conv1 = Conv2D(filters, kernel_size, strides=strides, padding='same')(inputLayer) r
filters = 256
kernel_size = 3
strides = 1
factor = 4 # the factor of upscaling
inputLayer = Input(shape=(img_height//factor, img_width//factor, img_depth))
conv1 = Conv2D(filters, kernel_size, strides=strides, padding='same')(inputLayer)
res = Conv2D(filters, kernel_size, strides=strides, padding='same')(conv1)
act = ReLU()(res)
res = Conv2D(filters, kernel_size, strides=strides, padding='same')(act)
res_rec = Add()([conv1, res])
for i in range(15): # 16-1
res1 = Conv2D(filters, kernel_size, strides=strides, padding='same')(res_rec)
act = ReLU()(res1)
res2 = Conv2D(filters, kernel_size, strides=strides, padding='same')(act)
res_rec = Add()([res_rec, res2])
conv2 = Conv2D(filters, kernel_size, strides=strides, padding='same')(res_rec)
a = Add()([conv1, conv2])
up = UpSampling2D(size=4)(a)
outputLayer = Conv2D(filters=3,
kernel_size=1,
strides=1,
padding='same')(up)
model = Model(inputs=inputLayer, outputs=outputLayer)
model.summary()
显示:
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) (None, 350, 350, 3) 0
__________________________________________________________________________________________________
conv2d_1 (Conv2D) (None, 350, 350, 256 7168 input_1[0][0]
__________________________________________________________________________________________________
conv2d_2 (Conv2D) (None, 350, 350, 256 590080 conv2d_1[0][0]
__________________________________________________________________________________________________
re_lu_1 (ReLU) (None, 350, 350, 256 0 conv2d_2[0][0]
__________________________________________________________________________________________________
conv2d_3 (Conv2D) (None, 350, 350, 256 590080 re_lu_1[0][0]
__________________________________________________________________________________________________
add_1 (Add) (None, 350, 350, 256 0 conv2d_1[0][0]
conv2d_3[0][0]
__________________________________________________________________________________________________
conv2d_4 (Conv2D) (None, 350, 350, 256 590080 add_1[0][0]
__________________________________________________________________________________________________
re_lu_2 (ReLU) (None, 350, 350, 256 0 conv2d_4[0][0]
__________________________________________________________________________________________________
conv2d_5 (Conv2D) (None, 350, 350, 256 590080 re_lu_2[0][0]
__________________________________________________________________________________________________
add_2 (Add) (None, 350, 350, 256 0 add_1[0][0]
conv2d_5[0][0]
__________________________________________________________________________________________________
conv2d_6 (Conv2D) (None, 350, 350, 256 590080 add_2[0][0]
__________________________________________________________________________________________________
re_lu_3 (ReLU) (None, 350, 350, 256 0 conv2d_6[0][0]
__________________________________________________________________________________________________
conv2d_7 (Conv2D) (None, 350, 350, 256 590080 re_lu_3[0][0]
__________________________________________________________________________________________________
add_3 (Add) (None, 350, 350, 256 0 add_2[0][0]
conv2d_7[0][0]
...... this goes on for a long time .....
__________________________________________
add_15 (Add) (None, 350, 350, 256 0 add_14[0][0]
conv2d_31[0][0]
__________________________________________________________________________________________________
conv2d_32 (Conv2D) (None, 350, 350, 256 590080 add_15[0][0]
__________________________________________________________________________________________________
re_lu_16 (ReLU) (None, 350, 350, 256 0 conv2d_32[0][0]
__________________________________________________________________________________________________
conv2d_33 (Conv2D) (None, 350, 350, 256 590080 re_lu_16[0][0]
__________________________________________________________________________________________________
add_16 (Add) (None, 350, 350, 256 0 add_15[0][0]
conv2d_33[0][0]
__________________________________________________________________________________________________
conv2d_34 (Conv2D) (None, 350, 350, 256 590080 add_16[0][0]
__________________________________________________________________________________________________
add_17 (Add) (None, 350, 350, 256 0 conv2d_1[0][0]
conv2d_34[0][0]
__________________________________________________________________________________________________
up_sampling2d_1 (UpSampling2D) (None, 1400, 1400, 2 0 add_17[0][0]
__________________________________________________________________________________________________
conv2d_35 (Conv2D) (None, 1400, 1400, 3 771 up_sampling2d_1[0][0]
==================================================================================================
Total params: 19,480,579
Trainable params: 19,480,579
Non-trainable params: 0
__________________________________________________________________________________________________
None
重要的部分就在末尾,靠近输出:
__________________________________________________________________________________________________
add_17 (Add) (None, 350, 350, 256 0 conv2d_1[0][0]
conv2d_34[0][0]
__________________________________________________________________________________________________
up_sampling2d_1 (UpSampling2D) (None, 1400, 1400, 2 0 add_17[0][0]
__________________________________________________________________________________________________
conv2d_35 (Conv2D) (None, 1400, 1400, 3 771 up_sampling2d_1[0][0]
==================================================================================================
现在,看看我在运行网络时遇到的错误:
Traceback (most recent call last):
File "C:/Users/payne/PycharmProjects/PixelEnhancer/trainTest.py", line 280, in <module>
setUpImages()
File "C:/Users/payne/PycharmProjects/PixelEnhancer/trainTest.py", line 96, in setUpImages
setUpData(trainData, testData)
File "C:/Users/payne/PycharmProjects/PixelEnhancer/trainTest.py", line 135, in setUpData
setUpModel(X_train, Y_train, validateTestData, trainingTestData)
File "C:/Users/payne/PycharmProjects/PixelEnhancer/trainTest.py", line 176, in setUpModel
train(model, X_train, Y_train, validateTestData, trainingTestData)
File "C:/Users/payne/PycharmProjects/PixelEnhancer/trainTest.py", line 192, in train
batch_size=32)
File "C:\Users\payne\Anaconda3\envs\ml-gpu\lib\site-packages\keras\engine\training.py", line 950, in fit
batch_size=batch_size)
File "C:\Users\payne\Anaconda3\envs\ml-gpu\lib\site-packages\keras\engine\training.py", line 787, in _standardize_user_data
exception_prefix='target')
File "C:\Users\payne\Anaconda3\envs\ml-gpu\lib\site-packages\keras\engine\training_utils.py", line 137, in standardize_input_data
str(data_shape))
ValueError: Error when checking target: expected conv2d_35 to have shape (1400, 1400, 1) but got array with shape (1400, 1400, 3)
回溯(最近一次呼叫最后一次):
文件“C:/Users/payne/PycharmProjects/PixelEnhancer/trainTest.py”,第280行,在
设置图像()
文件“C:/Users/payne/PycharmProjects/PixelEnhancer/trainTest.py”,第96行,在setUpImages中
设置数据(列车数据、测试数据)
文件“C:/Users/payne/PycharmProjects/PixelEnhancer/trainTest.py”,第135行,在setUpData中
设置模型(X_序列、Y_序列、validateTestData、trainingTestData)
文件“C:/Users/payne/PycharmProjects/PixelEnhancer/trainTest.py”,第176行,在setUpModel中
列车(型号、X_列车、Y_列车、validateTestData、trainingTestData)
文件“C:/Users/payne/PycharmProjects/PixelEnhancer/trainTest.py”,第192行,列车中
批次(单位尺寸=32)
文件“C:\Users\payne\Anaconda3\envs\ml gpu\lib\site packages\keras\engine\training.py”,第950行,适合
批次大小=批次大小)
文件“C:\Users\payne\Anaconda3\envs\ml gpu\lib\site packages\keras\engine\training.py”,第787行,在用户数据中
异常(前缀='target')
标准化输入数据中的文件“C:\Users\payne\Anaconda3\envs\ml gpu\lib\site packages\keras\engine\training\u utils.py”,第137行
str(数据形状))
ValueError:检查目标时出错:预期conv2d_35具有形状(1400,1400,1),但获得具有形状(1400,1400,3)的数组
为什么我的上一次卷积期望有一个(1400,1400,1)
张量,却得到一个(1400,1400,3)
张量,而摘要说上采样2D
应该返回一个(1400,1400,2)
张量
为了澄清一点上下文:这应该是一个网络,它接收350x350x3图像并输出1400x1400x3图像。因此,显然错误消息与
conv2d\u 35
实体没有具体关系,而是与我的丢失函数链接的网络的最后一个实体
由于我选择了稀疏的、分类的、交叉熵作为损失函数,所以它需要一个一维向量
将损失设置为mean_squared_error
修复了它