Python 3.x Keras上采样2D不一致行为

Python 3.x Keras上采样2D不一致行为,python-3.x,tensorflow,machine-learning,neural-network,keras,Python 3.x,Tensorflow,Machine Learning,Neural Network,Keras,这是我的模型: filters = 256 kernel_size = 3 strides = 1 factor = 4 # the factor of upscaling inputLayer = Input(shape=(img_height//factor, img_width//factor, img_depth)) conv1 = Conv2D(filters, kernel_size, strides=strides, padding='same')(inputLayer) r

这是我的模型:

filters = 256
kernel_size = 3
strides = 1
factor = 4  # the factor of upscaling

inputLayer = Input(shape=(img_height//factor, img_width//factor, img_depth))
conv1 = Conv2D(filters, kernel_size, strides=strides, padding='same')(inputLayer)

res = Conv2D(filters, kernel_size, strides=strides, padding='same')(conv1)
act = ReLU()(res)
res = Conv2D(filters, kernel_size, strides=strides, padding='same')(act)
res_rec = Add()([conv1, res])

for i in range(15):  # 16-1
    res1 = Conv2D(filters, kernel_size, strides=strides, padding='same')(res_rec)
    act = ReLU()(res1)
    res2 = Conv2D(filters, kernel_size, strides=strides, padding='same')(act)
    res_rec = Add()([res_rec, res2])

conv2 = Conv2D(filters, kernel_size, strides=strides, padding='same')(res_rec)
a = Add()([conv1, conv2])
up = UpSampling2D(size=4)(a)
outputLayer = Conv2D(filters=3,
                     kernel_size=1,
                     strides=1,
                     padding='same')(up)

model = Model(inputs=inputLayer, outputs=outputLayer)
model.summary()
显示:

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, 350, 350, 3)  0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 350, 350, 256 7168        input_1[0][0]                    
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 350, 350, 256 590080      conv2d_1[0][0]                   
__________________________________________________________________________________________________
re_lu_1 (ReLU)                  (None, 350, 350, 256 0           conv2d_2[0][0]                   
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 350, 350, 256 590080      re_lu_1[0][0]                    
__________________________________________________________________________________________________
add_1 (Add)                     (None, 350, 350, 256 0           conv2d_1[0][0]                   
                                                                 conv2d_3[0][0]                   
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 350, 350, 256 590080      add_1[0][0]                      
__________________________________________________________________________________________________
re_lu_2 (ReLU)                  (None, 350, 350, 256 0           conv2d_4[0][0]                   
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 350, 350, 256 590080      re_lu_2[0][0]                    
__________________________________________________________________________________________________
add_2 (Add)                     (None, 350, 350, 256 0           add_1[0][0]                      
                                                                 conv2d_5[0][0]                   
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 350, 350, 256 590080      add_2[0][0]                      
__________________________________________________________________________________________________
re_lu_3 (ReLU)                  (None, 350, 350, 256 0           conv2d_6[0][0]                   
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 350, 350, 256 590080      re_lu_3[0][0]                    
__________________________________________________________________________________________________
add_3 (Add)                     (None, 350, 350, 256 0           add_2[0][0]                      
                                                                 conv2d_7[0][0]                   

 ...... this goes on for a long time .....



 __________________________________________
add_15 (Add)                    (None, 350, 350, 256 0           add_14[0][0]                     
                                                                 conv2d_31[0][0]                  
__________________________________________________________________________________________________
conv2d_32 (Conv2D)              (None, 350, 350, 256 590080      add_15[0][0]                     
__________________________________________________________________________________________________
re_lu_16 (ReLU)                 (None, 350, 350, 256 0           conv2d_32[0][0]                  
__________________________________________________________________________________________________
conv2d_33 (Conv2D)              (None, 350, 350, 256 590080      re_lu_16[0][0]                   
__________________________________________________________________________________________________
add_16 (Add)                    (None, 350, 350, 256 0           add_15[0][0]                     
                                                                 conv2d_33[0][0]                  
__________________________________________________________________________________________________
conv2d_34 (Conv2D)              (None, 350, 350, 256 590080      add_16[0][0]                     
__________________________________________________________________________________________________
add_17 (Add)                    (None, 350, 350, 256 0           conv2d_1[0][0]                   
                                                                 conv2d_34[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_1 (UpSampling2D)  (None, 1400, 1400, 2 0           add_17[0][0]                     
__________________________________________________________________________________________________
conv2d_35 (Conv2D)              (None, 1400, 1400, 3 771         up_sampling2d_1[0][0]            
==================================================================================================
Total params: 19,480,579
Trainable params: 19,480,579
Non-trainable params: 0
__________________________________________________________________________________________________
None
重要的部分就在末尾,靠近输出:

__________________________________________________________________________________________________
add_17 (Add)                    (None, 350, 350, 256 0           conv2d_1[0][0]                   
                                                                 conv2d_34[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_1 (UpSampling2D)  (None, 1400, 1400, 2 0           add_17[0][0]                     
__________________________________________________________________________________________________
conv2d_35 (Conv2D)              (None, 1400, 1400, 3 771         up_sampling2d_1[0][0]            
==================================================================================================
现在,看看我在运行网络时遇到的错误:

Traceback (most recent call last):
  File "C:/Users/payne/PycharmProjects/PixelEnhancer/trainTest.py", line 280, in <module>
    setUpImages()
  File "C:/Users/payne/PycharmProjects/PixelEnhancer/trainTest.py", line 96, in setUpImages
    setUpData(trainData, testData)
  File "C:/Users/payne/PycharmProjects/PixelEnhancer/trainTest.py", line 135, in setUpData
    setUpModel(X_train, Y_train, validateTestData, trainingTestData)
  File "C:/Users/payne/PycharmProjects/PixelEnhancer/trainTest.py", line 176, in setUpModel
    train(model, X_train, Y_train, validateTestData, trainingTestData)
  File "C:/Users/payne/PycharmProjects/PixelEnhancer/trainTest.py", line 192, in train
    batch_size=32)
  File "C:\Users\payne\Anaconda3\envs\ml-gpu\lib\site-packages\keras\engine\training.py", line 950, in fit
    batch_size=batch_size)
  File "C:\Users\payne\Anaconda3\envs\ml-gpu\lib\site-packages\keras\engine\training.py", line 787, in _standardize_user_data
    exception_prefix='target')
  File "C:\Users\payne\Anaconda3\envs\ml-gpu\lib\site-packages\keras\engine\training_utils.py", line 137, in standardize_input_data
    str(data_shape))
ValueError: Error when checking target: expected conv2d_35 to have shape (1400, 1400, 1) but got array with shape (1400, 1400, 3)
回溯(最近一次呼叫最后一次):
文件“C:/Users/payne/PycharmProjects/PixelEnhancer/trainTest.py”,第280行,在
设置图像()
文件“C:/Users/payne/PycharmProjects/PixelEnhancer/trainTest.py”,第96行,在setUpImages中
设置数据(列车数据、测试数据)
文件“C:/Users/payne/PycharmProjects/PixelEnhancer/trainTest.py”,第135行,在setUpData中
设置模型(X_序列、Y_序列、validateTestData、trainingTestData)
文件“C:/Users/payne/PycharmProjects/PixelEnhancer/trainTest.py”,第176行,在setUpModel中
列车(型号、X_列车、Y_列车、validateTestData、trainingTestData)
文件“C:/Users/payne/PycharmProjects/PixelEnhancer/trainTest.py”,第192行,列车中
批次(单位尺寸=32)
文件“C:\Users\payne\Anaconda3\envs\ml gpu\lib\site packages\keras\engine\training.py”,第950行,适合
批次大小=批次大小)
文件“C:\Users\payne\Anaconda3\envs\ml gpu\lib\site packages\keras\engine\training.py”,第787行,在用户数据中
异常(前缀='target')
标准化输入数据中的文件“C:\Users\payne\Anaconda3\envs\ml gpu\lib\site packages\keras\engine\training\u utils.py”,第137行
str(数据形状))
ValueError:检查目标时出错:预期conv2d_35具有形状(1400,1400,1),但获得具有形状(1400,1400,3)的数组
为什么我的上一次卷积期望有一个
(1400,1400,1)
张量,却得到一个
(1400,1400,3)
张量,而摘要说
上采样2D
应该返回一个
(1400,1400,2)
张量


为了澄清一点上下文:这应该是一个网络,它接收350x350x3图像并输出1400x1400x3图像。

因此,显然错误消息与
conv2d\u 35
实体没有具体关系,而是与我的丢失函数链接的网络的最后一个实体

由于我选择了稀疏的、分类的、交叉熵作为损失函数,所以它需要一个一维向量

将损失设置为
mean_squared_error
修复了它