Keras Unet: multi-class image segmentation

I recently started learning image segmentation and U-Net. I am trying to do multi-class image segmentation with 7 classes, where the input is a (256, 256, 3) RGB image and the output is a (256, 256, 1) grayscale image in which each intensity value corresponds to a class. I am applying a pixel-wise softmax, and I use sparse categorical cross-entropy to avoid one-hot encoding the masks:

from keras.models import Model
from keras.layers import (Input, Conv2D, Conv2DTranspose, MaxPooling2D,
                          Dropout, BatchNormalization, Activation, Reshape,
                          concatenate)
import keras

def soft1(x):
    # pixel-wise softmax over the last (class) axis
    return keras.activations.softmax(x, axis = -1)

def conv2d_block(input_tensor, n_filters, kernel_size = 3, batchnorm = True):

    # first convolution
    x = Conv2D(filters = n_filters, kernel_size = (kernel_size, kernel_size),
               kernel_initializer = 'he_normal', padding = 'same')(input_tensor)
    if batchnorm:
        x = BatchNormalization()(x)
    x = Activation('relu')(x)

    # second convolution (note: this is applied to input_tensor again, not to x,
    # so only one convolution per block actually ends up in the graph)
    x = Conv2D(filters = n_filters, kernel_size = (kernel_size, kernel_size),
               kernel_initializer = 'he_normal', padding = 'same')(input_tensor)
    if batchnorm:
        x = BatchNormalization()(x)
    x = Activation('relu')(x)

    return x

def get_unet(input_img, n_classes, n_filters = 16, dropout = 0.1, batchnorm = True):
    # Contracting Path
    c1 = conv2d_block(input_img, n_filters * 1, kernel_size = 3, batchnorm = batchnorm)
    p1 = MaxPooling2D((2, 2))(c1)
    p1 = Dropout(dropout)(p1)

    c2 = conv2d_block(p1, n_filters * 2, kernel_size = 3, batchnorm = batchnorm)
    p2 = MaxPooling2D((2, 2))(c2)
    p2 = Dropout(dropout)(p2)

    c3 = conv2d_block(p2, n_filters * 4, kernel_size = 3, batchnorm = batchnorm)
    p3 = MaxPooling2D((2, 2))(c3)
    p3 = Dropout(dropout)(p3)

    c4 = conv2d_block(p3, n_filters * 8, kernel_size = 3, batchnorm = batchnorm)
    p4 = MaxPooling2D((2, 2))(c4)
    p4 = Dropout(dropout)(p4)

    c5 = conv2d_block(p4, n_filters = n_filters * 16, kernel_size = 3, batchnorm = batchnorm)

    # Expansive Path
    u6 = Conv2DTranspose(n_filters * 8, (3, 3), strides = (2, 2), padding = 'same')(c5)
    u6 = concatenate([u6, c4])
    u6 = Dropout(dropout)(u6)
    c6 = conv2d_block(u6, n_filters * 8, kernel_size = 3, batchnorm = batchnorm)

    u7 = Conv2DTranspose(n_filters * 4, (3, 3), strides = (2, 2), padding = 'same')(c6)
    u7 = concatenate([u7, c3])
    u7 = Dropout(dropout)(u7)
    c7 = conv2d_block(u7, n_filters * 4, kernel_size = 3, batchnorm = batchnorm)

    u8 = Conv2DTranspose(n_filters * 2, (3, 3), strides = (2, 2), padding = 'same')(c7)
    u8 = concatenate([u8, c2])
    u8 = Dropout(dropout)(u8)
    c8 = conv2d_block(u8, n_filters * 2, kernel_size = 3, batchnorm = batchnorm)

    u9 = Conv2DTranspose(n_filters * 1, (3, 3), strides = (2, 2), padding = 'same')(c8)
    u9 = concatenate([u9, c1])
    u9 = Dropout(dropout)(u9)
    c9 = conv2d_block(u9, n_filters * 1, kernel_size = 3, batchnorm = batchnorm)

    # 1x1 convolution to n_classes channels, then flatten the spatial dimensions
    # (image_height and image_width are assumed to be defined globally as 256)
    # and apply the pixel-wise softmax over the class axis
    outputs = Conv2D(n_classes, (1, 1))(c9)
    outputs = Reshape((image_height*image_width, 1, n_classes), input_shape = (image_height, image_width, n_classes))(outputs)
    outputs = Activation(soft1)(outputs)

    model = Model(inputs=[input_img], outputs=[outputs])
    print(outputs.shape)

    return model
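
A minimal sketch of how this model might be built and compiled with sparse categorical cross-entropy, assuming the masks are integer class maps of shape (256, 256, 1); X_train, y_train and the mask reshape are illustrative and not part of the original code:

from keras.layers import Input
from keras.optimizers import Adam

image_height, image_width, n_classes = 256, 256, 7

input_img = Input((image_height, image_width, 3))
model = get_unet(input_img, n_classes, n_filters = 16, dropout = 0.1, batchnorm = True)

# with an output of (None, 65536, 1, 7), the sparse targets need one fewer
# dimension than the predictions, i.e. integer labels of shape (None, 65536, 1)
model.compile(optimizer = Adam(), loss = 'sparse_categorical_crossentropy',
              metrics = ['accuracy'])

# y_flat = y_train.reshape((-1, image_height * image_width, 1))   # illustrative
# model.fit(X_train, y_flat, batch_size = 8, epochs = 30, validation_split = 0.1)
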
My model summary is:

Model: "model_2"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_12 (InputLayer)           (None, 256, 256, 3)  0                                            
__________________________________________________________________________________________________
conv2d_211 (Conv2D)             (None, 256, 256, 16) 448         input_12[0][0]                   
__________________________________________________________________________________________________
batch_normalization_200 (BatchN (None, 256, 256, 16) 64          conv2d_211[0][0]                 
__________________________________________________________________________________________________
activation_204 (Activation)     (None, 256, 256, 16) 0           batch_normalization_200[0][0]    
__________________________________________________________________________________________________
max_pooling2d_45 (MaxPooling2D) (None, 128, 128, 16) 0           activation_204[0][0]             
__________________________________________________________________________________________________
dropout_89 (Dropout)            (None, 128, 128, 16) 0           max_pooling2d_45[0][0]           
__________________________________________________________________________________________________
conv2d_213 (Conv2D)             (None, 128, 128, 32) 4640        dropout_89[0][0]                 
__________________________________________________________________________________________________
batch_normalization_202 (BatchN (None, 128, 128, 32) 128         conv2d_213[0][0]                 
__________________________________________________________________________________________________
activation_206 (Activation)     (None, 128, 128, 32) 0           batch_normalization_202[0][0]    
__________________________________________________________________________________________________
max_pooling2d_46 (MaxPooling2D) (None, 64, 64, 32)   0           activation_206[0][0]             
__________________________________________________________________________________________________
dropout_90 (Dropout)            (None, 64, 64, 32)   0           max_pooling2d_46[0][0]           
__________________________________________________________________________________________________
conv2d_215 (Conv2D)             (None, 64, 64, 64)   18496       dropout_90[0][0]                 
__________________________________________________________________________________________________
batch_normalization_204 (BatchN (None, 64, 64, 64)   256         conv2d_215[0][0]                 
__________________________________________________________________________________________________
activation_208 (Activation)     (None, 64, 64, 64)   0           batch_normalization_204[0][0]    
__________________________________________________________________________________________________
max_pooling2d_47 (MaxPooling2D) (None, 32, 32, 64)   0           activation_208[0][0]             
__________________________________________________________________________________________________
dropout_91 (Dropout)            (None, 32, 32, 64)   0           max_pooling2d_47[0][0]           
__________________________________________________________________________________________________
conv2d_217 (Conv2D)             (None, 32, 32, 128)  73856       dropout_91[0][0]                 
__________________________________________________________________________________________________
batch_normalization_206 (BatchN (None, 32, 32, 128)  512         conv2d_217[0][0]                 
__________________________________________________________________________________________________
activation_210 (Activation)     (None, 32, 32, 128)  0           batch_normalization_206[0][0]    
__________________________________________________________________________________________________
max_pooling2d_48 (MaxPooling2D) (None, 16, 16, 128)  0           activation_210[0][0]             
__________________________________________________________________________________________________
dropout_92 (Dropout)            (None, 16, 16, 128)  0           max_pooling2d_48[0][0]           
__________________________________________________________________________________________________
conv2d_219 (Conv2D)             (None, 16, 16, 256)  295168      dropout_92[0][0]                 
__________________________________________________________________________________________________
batch_normalization_208 (BatchN (None, 16, 16, 256)  1024        conv2d_219[0][0]                 
__________________________________________________________________________________________________
activation_212 (Activation)     (None, 16, 16, 256)  0           batch_normalization_208[0][0]    
__________________________________________________________________________________________________
conv2d_transpose_45 (Conv2DTran (None, 32, 32, 128)  295040      activation_212[0][0]             
__________________________________________________________________________________________________
concatenate_45 (Concatenate)    (None, 32, 32, 256)  0           conv2d_transpose_45[0][0]        
                                                                 activation_210[0][0]             
__________________________________________________________________________________________________
dropout_93 (Dropout)            (None, 32, 32, 256)  0           concatenate_45[0][0]             
__________________________________________________________________________________________________
conv2d_221 (Conv2D)             (None, 32, 32, 128)  295040      dropout_93[0][0]                 
__________________________________________________________________________________________________
batch_normalization_210 (BatchN (None, 32, 32, 128)  512         conv2d_221[0][0]                 
__________________________________________________________________________________________________
activation_214 (Activation)     (None, 32, 32, 128)  0           batch_normalization_210[0][0]    
__________________________________________________________________________________________________
conv2d_transpose_46 (Conv2DTran (None, 64, 64, 64)   73792       activation_214[0][0]             
__________________________________________________________________________________________________
concatenate_46 (Concatenate)    (None, 64, 64, 128)  0           conv2d_transpose_46[0][0]        
                                                                 activation_208[0][0]             
__________________________________________________________________________________________________
dropout_94 (Dropout)            (None, 64, 64, 128)  0           concatenate_46[0][0]             
__________________________________________________________________________________________________
conv2d_223 (Conv2D)             (None, 64, 64, 64)   73792       dropout_94[0][0]                 
__________________________________________________________________________________________________
batch_normalization_212 (BatchN (None, 64, 64, 64)   256         conv2d_223[0][0]                 
__________________________________________________________________________________________________
activation_216 (Activation)     (None, 64, 64, 64)   0           batch_normalization_212[0][0]    
__________________________________________________________________________________________________
conv2d_transpose_47 (Conv2DTran (None, 128, 128, 32) 18464       activation_216[0][0]             
__________________________________________________________________________________________________
concatenate_47 (Concatenate)    (None, 128, 128, 64) 0           conv2d_transpose_47[0][0]        
                                                                 activation_206[0][0]             
__________________________________________________________________________________________________
dropout_95 (Dropout)            (None, 128, 128, 64) 0           concatenate_47[0][0]             
__________________________________________________________________________________________________
conv2d_225 (Conv2D)             (None, 128, 128, 32) 18464       dropout_95[0][0]                 
__________________________________________________________________________________________________
batch_normalization_214 (BatchN (None, 128, 128, 32) 128         conv2d_225[0][0]                 
__________________________________________________________________________________________________
activation_218 (Activation)     (None, 128, 128, 32) 0           batch_normalization_214[0][0]    
__________________________________________________________________________________________________
conv2d_transpose_48 (Conv2DTran (None, 256, 256, 16) 4624        activation_218[0][0]             
__________________________________________________________________________________________________
concatenate_48 (Concatenate)    (None, 256, 256, 32) 0           conv2d_transpose_48[0][0]        
                                                                 activation_204[0][0]             
__________________________________________________________________________________________________
dropout_96 (Dropout)            (None, 256, 256, 32) 0           concatenate_48[0][0]             
__________________________________________________________________________________________________
conv2d_227 (Conv2D)             (None, 256, 256, 16) 4624        dropout_96[0][0]                 
__________________________________________________________________________________________________
batch_normalization_216 (BatchN (None, 256, 256, 16) 64          conv2d_227[0][0]                 
__________________________________________________________________________________________________
activation_220 (Activation)     (None, 256, 256, 16) 0           batch_normalization_216[0][0]    
__________________________________________________________________________________________________
conv2d_228 (Conv2D)             (None, 256, 256, 7)  119         activation_220[0][0]             
__________________________________________________________________________________________________
reshape_12 (Reshape)            (None, 65536, 1, 7)  0           conv2d_228[0][0]                 
__________________________________________________________________________________________________
activation_221 (Activation)     (None, 65536, 1, 7)  0           reshape_12[0][0]                 
==================================================================================================
Total params: 1,179,511
Trainable params: 1,178,039
Non-trainable params: 1,472
__________________________________________________________________________________________________
Is my model right? Shouldn't the final output be (65536, 1, 1) since I am using softmax?
The code compiles, but the Dice coefficient is very low.
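
For reference, a small NumPy sketch of what the pixel-wise softmax actually returns: it keeps one probability per class, and it only collapses to a single channel after an argmax over the class axis (the random preds array below is just a stand-in for one prediction of the model):

import numpy as np

preds = np.random.rand(65536, 1, 7)
preds = preds / preds.sum(axis = -1, keepdims = True)    # what softmax(axis = -1) guarantees

print(preds.shape)                        # (65536, 1, 7): 7 probabilities per pixel
class_map = np.argmax(preds, axis = -1)   # (65536, 1): one class index per pixel
print(class_map.shape)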

Your output actually represents the pixels of the image. For each pixel you get a 1x7 output, and since it goes through a softmax this representation takes values between 0 and 1. The channel of the desired class is thus activated for each pixel, which gives you the segmentation. If it were (65536, 1, 1) you would not be using a categorical representation but a dense one.

Your model should end with an output of shape (256*256, 7), that is, 7 classes per pixel, and this has to line up with the target image, which is (256*256, 1). This only works with 'sparse_categorical_crossentropy' or a custom loss.

So up to conv2d_228 the model looks fine (I have not checked it in detail, though). Nothing else is needed after that convolution.

You can put the softmax directly inside conv2d_228, or immediately after it.

y_train should be (256*256, 1).
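
A minimal, self-contained sketch of the ending the answer suggests: softmax applied directly in the 1x1 convolution, followed by a flatten to (65536, 7). The helper name segmentation_head and the dummy 16-channel input standing in for c9 are illustrative assumptions:

from keras.models import Model
from keras.layers import Input, Conv2D, Reshape

image_height, image_width, n_classes = 256, 256, 7

def segmentation_head(features, n_classes):
    # 1x1 convolution with softmax gives n_classes probabilities per pixel ...
    out = Conv2D(n_classes, (1, 1), activation = 'softmax')(features)    # (256, 256, 7)
    # ... and the flatten matches the (256*256, 7) shape described above
    return Reshape((image_height * image_width, n_classes))(out)

feat = Input((image_height, image_width, 16))           # stand-in for c9
head = Model(feat, segmentation_head(feat, n_classes))
head.summary()                                          # final output: (None, 65536, 7)

# with this ending, 'sparse_categorical_crossentropy' expects integer masks of
# shape (256*256, 1) per sample, i.e. y_train reshaped from (256, 256, 1)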

Softmax squashes your outputs between [0, 1], so to me this looks perfectly correct. I don't understand the (None, 65536, 1, 7) rather than (None, 65536, 7), though. Also, what do you mean by the Dice coefficient being low?

@ChrisTosh The Dice coefficient is 65%, which I think is due to class imbalance. Can you explain why it should be (None, 65536, 7)?

Is the Dice coefficient the statistic you compute performance with? What about the classification performance, or have you checked the confusion matrix of the results?

@ChrisTosh But what about my other question?

In (None, 65536, 1, 7) the first dimension is the number of samples, the second is the number of pixels and the last one is your classes, so why would you need the third dimension? Why don't you reshape? (I'm not sure of the difference between the two or whether it matters.)
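
For reference, one way a mean (macro) Dice coefficient can be computed in this setup; the metric actually used is not shown in the question, so this is only an assumed sketch for sparse integer targets and softmax predictions, and it illustrates why strong class imbalance pulls the average down:

import keras.backend as K

def dice_coef(y_true, y_pred, smooth = 1.0):
    # y_true: integer class map, y_pred: per-pixel softmax over the classes
    n_classes = K.int_shape(y_pred)[-1]
    y_true_oh = K.one_hot(K.cast(K.flatten(y_true), 'int32'), n_classes)
    y_pred_f = K.reshape(y_pred, (-1, n_classes))
    intersection = K.sum(y_true_oh * y_pred_f, axis = 0)
    union = K.sum(y_true_oh, axis = 0) + K.sum(y_pred_f, axis = 0)
    # per-class Dice averaged over the 7 classes; rare classes lower the mean
    return K.mean((2.0 * intersection + smooth) / (union + smooth))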