Python 输出掩码数组显示为空白

Python 输出掩码数组显示为空白,python,tensorflow,keras,tensor,semantic-segmentation,Python,Tensorflow,Keras,Tensor,Semantic Segmentation,我试图从一个语义分割模型显示对单个图像的预测,这是一个掩码数组 输出为4维张量(1160160,2),并转换为图像阵列(160160,2) 然后我尝试用matplotlib显示这个图像 最终输出为空。不知道这里发生了什么 这是我的密码: import numpy as np from keras.preprocessing import image img_path = train_files[10] img = tf.keras.preprocessing.image.load_img(im

我试图从一个语义分割模型显示对单个图像的预测,这是一个掩码数组

输出为4维张量(1160160,2),并转换为图像阵列(160160,2) 然后我尝试用matplotlib显示这个图像

最终输出为空。不知道这里发生了什么

这是我的密码:

import numpy as np
from keras.preprocessing import image

img_path = train_files[10]
img = tf.keras.preprocessing.image.load_img(img_path, target_size=(160,160))
img = tf.keras.preprocessing.image.img_to_array(img, dtype=np.uint8)
input_tensor = np.expand_dims(img, 0)

val_preds = model.predict(input_tensor)
print('Shape: ',input_tensor.shape)

test_preds = model.predict(input_tensor)
test_preds
输出:

Shape:  (1, 160, 160, 3)
array([[[[ 3.514218 , -5.554067 ],
         [ 6.0965424, -6.619807 ],
         [ 5.996727 , -4.5980225],
         ...,
         [ 6.0274673, -6.033106 ],
         [ 4.853381 , -5.9041157],
         [ 5.04994  , -6.3008943]],

        [[ 6.222473 , -6.143579 ],
         [ 7.506066 , -7.0598283],
         [ 8.403559 , -5.5494676],
         ...,
         [ 6.839488 , -6.9970145],
         [ 7.473741 , -6.0447083],
         [ 6.659992 , -6.404495 ]],

        [[ 6.3930387, -6.465418 ],
         [ 8.294706 , -8.931659 ],
         [ 8.272978 , -8.10188  ],
         ...,
         [ 6.964575 , -7.5066137],
         [ 8.075083 , -6.632631 ],
         [ 6.83106  , -8.529422 ]],

        ...,

        [[ 5.9942446, -6.7861133],
         [ 7.1510434, -7.7782245],
         [ 7.5247774, -5.3231616],
         ...,
         [ 7.8443217, -7.826961 ],
         [ 8.614312 , -7.102455 ],
         [ 7.219681 , -6.6550455]],

        [[ 6.123674 , -6.6143103],
         [ 8.129679 , -7.597713 ],
         [ 4.2799473, -6.4181447],
         ...,
         [ 7.4701095, -9.676035 ],
         [ 8.42034  , -5.923585 ],
         [ 7.5446596, -9.011229 ]],

        [[ 6.1459126, -6.66231  ],
         [ 6.444136 , -7.5591726],
         [ 6.3249736, -5.059468 ],
         ...,
         [ 5.995722 , -6.515067 ],
         [ 7.4314594, -6.307902 ],
         [ 6.6631203, -6.6196494]]]], dtype=float32)
(160, 160, 1)
array([0], dtype=uint8)
在图像上显示此遮罩阵列:

pred = np.array(test_preds*255).squeeze().round()
mask = np.argmax(pred, axis=-1)            
# mask = cv2.resize(mask, (1280, 720), interpolation=cv2.INTER_NEAREST)
mask = np.expand_dims(mask, axis=-1)    
print(mask.shape)
img = keras.preprocessing.image.array_to_img(mask)
print(np.unique(img))
输出:

Shape:  (1, 160, 160, 3)
array([[[[ 3.514218 , -5.554067 ],
         [ 6.0965424, -6.619807 ],
         [ 5.996727 , -4.5980225],
         ...,
         [ 6.0274673, -6.033106 ],
         [ 4.853381 , -5.9041157],
         [ 5.04994  , -6.3008943]],

        [[ 6.222473 , -6.143579 ],
         [ 7.506066 , -7.0598283],
         [ 8.403559 , -5.5494676],
         ...,
         [ 6.839488 , -6.9970145],
         [ 7.473741 , -6.0447083],
         [ 6.659992 , -6.404495 ]],

        [[ 6.3930387, -6.465418 ],
         [ 8.294706 , -8.931659 ],
         [ 8.272978 , -8.10188  ],
         ...,
         [ 6.964575 , -7.5066137],
         [ 8.075083 , -6.632631 ],
         [ 6.83106  , -8.529422 ]],

        ...,

        [[ 5.9942446, -6.7861133],
         [ 7.1510434, -7.7782245],
         [ 7.5247774, -5.3231616],
         ...,
         [ 7.8443217, -7.826961 ],
         [ 8.614312 , -7.102455 ],
         [ 7.219681 , -6.6550455]],

        [[ 6.123674 , -6.6143103],
         [ 8.129679 , -7.597713 ],
         [ 4.2799473, -6.4181447],
         ...,
         [ 7.4701095, -9.676035 ],
         [ 8.42034  , -5.923585 ],
         [ 7.5446596, -9.011229 ]],

        [[ 6.1459126, -6.66231  ],
         [ 6.444136 , -7.5591726],
         [ 6.3249736, -5.059468 ],
         ...,
         [ 5.995722 , -6.515067 ],
         [ 7.4314594, -6.307902 ],
         [ 6.6631203, -6.6196494]]]], dtype=float32)
(160, 160, 1)
array([0], dtype=uint8)

您的输出不在
[0,1]
范围内,因此当乘以255时,您不会在
[0255]
uint8范围内结束,这可能是它无法正确显示的原因。我建议您检查输出节点上的激活功能,或者弄清楚应该如何对输出进行后期处理。