Evaluating a trained U-Net produces unexpected results in MXNet


I have a trained U-Net semantic-segmentation model. During training, the loss, training accuracy, and test accuracy all look fine:

epoch 16, loss 0.3861, train acc 0.824, test acc 0.807, time 1.535 sec
epoch 17, loss 0.4359, train acc 0.779, test acc 0.801, time 1.524 sec
epoch 18, loss 0.4661, train acc 0.777, test acc 0.803, time 1.607 sec
epoch 19, loss 0.4031, train acc 0.789, test acc 0.838, time 1.475 sec
epoch 20, loss 0.3925, train acc 0.827, test acc 0.830, time 1.495 sec
My problem is that when I try to evaluate the model with the code below, I don't get the expected results. The output is always an array filled almost entirely with 1s, with only a few 0s:

[1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
   1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
   1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
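A quick way to quantify how skewed the prediction is would be counting pixels per class with `np.unique` (a diagnostic sketch on a dummy mask, not code from my script):

```python
import numpy as np

# Dummy mask mimicking the printout above: almost all ones, a single zero
mask = np.ones((512, 512), dtype=np.uint8)
mask[0, 2] = 0
values, counts = np.unique(mask, return_counts=True)
print(dict(zip(values, counts)))  # class index -> pixel count
```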

Evaluation snippet:

import cv2
import numpy as np
import mxnet as mx

ctx = mx.cpu()  # defined elsewhere in my script; shown here for completeness
img_width, img_height = 512, 512

def load_image(img, width, height):
    data = np.transpose(img, (2, 0, 1))  # (H, W, C) -> (C, H, W)
    data = mx.nd.array(data)  # from numpy.ndarray to mxnet ndarray
    print(data.shape)
    data = data.astype('float32')
    # Expand shape into (B x C x H x W)
    return mx.ndarray.expand_dims(data, axis=0)

def post_process_maskB(label, img_cols, img_rows, n_classes, p=0.5):
    return (np.where(label.asnumpy().reshape(img_cols, img_rows) > p, 1, 0)).astype('uint8')

def main():
    net = UNet(channels=3, num_class=2)
    net.load_parameters('./checkpoints/epoch_0010_model.params', ctx=ctx)
    image_path = './data/train/image/0.png'

    # load an image for prediction
    testimg = cv2.imread(image_path, 1)
    imgX = load_image(testimg, img_width, img_height)
    print(imgX.shape)
    data = imgX.astype(np.float32)

    # Run prediction
    out = net(data).argmax(axis=1)
    print(out.shape)
    mask = post_process_maskB(out, 512, 512, 2, p=0.5)

    print(mask.shape)
The network input is (1, 3, 512, 512) and the result is (1, 512, 512).

The network architecture is as follows:

UNet(
  (input_conv): BaseConvBlock(
    (conv1): Conv2D(3 -> 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv2): Conv2D(64 -> 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (down_conv_0): DownSampleBlock(
    (maxPool): MaxPool2D(size=(2, 2), stride=(2, 2), padding=(0, 0), ceil_mode=False, global_pool=False, pool_type=max, layout=NCHW)
    (conv): BaseConvBlock(
      (conv1): Conv2D(64 -> 6, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2D(6 -> 6, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
  (down_conv_1): DownSampleBlock(
    (maxPool): MaxPool2D(size=(2, 2), stride=(2, 2), padding=(0, 0), ceil_mode=False, global_pool=False, pool_type=max, layout=NCHW)
    (conv): BaseConvBlock(
      (conv1): Conv2D(6 -> 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2D(12 -> 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
  (down_conv_2): DownSampleBlock(
    (maxPool): MaxPool2D(size=(2, 2), stride=(2, 2), padding=(0, 0), ceil_mode=False, global_pool=False, pool_type=max, layout=NCHW)
    (conv): BaseConvBlock(
      (conv1): Conv2D(12 -> 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2D(24 -> 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
  (down_conv_3): DownSampleBlock(
    (maxPool): MaxPool2D(size=(2, 2), stride=(2, 2), padding=(0, 0), ceil_mode=False, global_pool=False, pool_type=max, layout=NCHW)
    (conv): BaseConvBlock(
      (conv1): Conv2D(24 -> 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2D(48 -> 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
  (up_conv_0): UpSampleBlock(
    (up): Conv2DTranspose(24 -> 48, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (conv): BaseConvBlock(
      (conv1): Conv2D(48 -> 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2D(24 -> 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
  (up_conv_1): UpSampleBlock(
    (up): Conv2DTranspose(12 -> 24, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (conv): BaseConvBlock(
      (conv1): Conv2D(24 -> 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2D(12 -> 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
  (up_conv_2): UpSampleBlock(
    (up): Conv2DTranspose(6 -> 12, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (conv): BaseConvBlock(
      (conv1): Conv2D(12 -> 6, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2D(6 -> 6, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
  (up_conv_3): UpSampleBlock(
    (up): Conv2DTranspose(3 -> 6, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (conv): BaseConvBlock(
      (conv1): Conv2D(67 -> 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2D(3 -> 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
  (output_conv): Conv2D(3 -> 2, kernel_size=(1, 1), stride=(1, 1))
)
I suspect the problem is in the following call:

out = net(data).argmax(axis=1)
Any help would be greatly appreciated.
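To make my suspicion concrete: `argmax(axis=1)` already collapses the two class channels into hard 0/1 class indices, so the later 0.5 threshold in the post-processing step leaves them unchanged (a minimal sketch with random values, not the actual network output):

```python
import numpy as np

rng = np.random.default_rng(0)
logits = rng.standard_normal((1, 2, 8, 8))  # hypothetical 2-class output for an 8x8 image
out = logits.argmax(axis=1)                 # (1, 8, 8); every value is the winning class index
mask = np.where(out.reshape(8, 8) > 0.5, 1, 0)
print((mask == out.reshape(8, 8)).all())    # thresholding the argmax output changes nothing
```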