Tensorflow 如何可视化带有边界框的训练模型以进行对象检测

Tensorflow 如何可视化带有边界框的训练模型以进行对象检测,tensorflow,deep-learning,object-detection,bounding-box,Tensorflow,Deep Learning,Object Detection,Bounding Box,我正在尝试绘制带有标签和预测的花朵图像,它们都有一个边界框。我使用的是预先训练好的exception模型的一些较低层 我已将输出层设置为4,因为边界框将有四个坐标: loc\u输出=keras.layers.Dense(4)(平均值) 为了简单起见,我只是使用tf.random.uniform将标签的四个坐标设置为随机数 如何使用matplotlib编写生成如下内容的函数: 这里有一种方法可以实现你想要的。但是,请注意,使用tf.random.uniform的虚拟边界框没有什么意义,默认情况下

我正在尝试绘制带有标签和预测的花朵图像,它们都有一个边界框。我使用的是预先训练好的
exception
模型的一些较低层

我已将输出层设置为4,因为边界框将有四个坐标:

loc\u输出=keras.layers.Dense(4)(平均值)

为了简单起见,我只是使用
tf.random.uniform
将标签的四个坐标设置为随机数

如何使用
matplotlib
编写生成如下内容的函数:


这里有一种方法可以实现你想要的。但是,请注意,使用
tf.random.uniform
的虚拟边界框没有什么意义,默认情况下
minval=0,maxval=1
,因此,您的虚拟坐标将在该范围内给出值,这不适用于边界框,这就是为什么在下面的演示中,我们将使用定标器值(假设为150)重新缩放坐标,希望您得到该点


培训后,准备用于推理的测试集

import numpy as np
import matplotlib.pyplot as plt

print(class_names)
test_set = test_set_raw.map(preprocess).batch(1).prefetch(1)
test_set = test_set.map(add_random_bounding_boxes)
['dandelion', 'daisy', 'tulips', 'sunflowers', 'roses']
使用
matplotlib
显示功能

for i, (X,y) in enumerate(test_set.take(1)):
    # true labels 
    true_label = y[0].numpy()
    true_bboxs = y[1].numpy()

    # model predicts 
    pred_label, pred_boxes = model.predict(X)
    pred_label = np.argmax(pred_label, axis=-1)

    # rescaling 
    dummy_true_boxes = (true_bboxs*150).astype(np.int32).clip(min=0, max=224)
    dummy_predict_boxes = (pred_boxes*150).astype(np.int32).clip(min=0, max=224)

    # Info printing 
    print('GT bbox scores: ', true_bboxs)
    print('PRED bbox scores: ', pred_boxes)
    print('After Rescaling and Clipped True BBOX: ', dummy_true_boxes)
    print('After Rescaling and Clipped Pred BBOX: ', dummy_predict_boxes)
    print('True label : {}, Predicted label {}'.format(class_names[int(true_label)], 
                                                       class_names[int(pred_label)]))

    plt.figure(figsize=(10, 10))
    plt.axis("off")
    plt.imshow(X[0])
    ax = plt.gca()

    for tbox, tcls, pbox, pcls in zip(dummy_true_boxes, true_label, dummy_predict_boxes, pred_label):
        # gt and pred labels 
        ttext = "GT: {}".format(class_names[tcls])
        ptext = "Pred: {}".format(class_names[pcls])

        # gt and pred co-ordinates 
        tx1, ty1, x2, y2 = tbox     # xmin, ymin, xmax, ymax
        tw, th = x2 - tx1, y2 - ty1  # width (w) = xmax - xmin ; height (h) = ymax - ymin

        px1, py1, x2, y2 = pbox    # xmin, ymin, xmax, ymax
        pw, ph = x2 - px1, y2 - py1  # width (w) = xmax - xmin ; height (h) = ymax - ymin


        patch = plt.Rectangle(
            [tx1, ty1], tw, th, fill=False, edgecolor=[0, 1, 0], linewidth=1
        )
        ax.add_patch(patch)
        ax.text(
            tx1,
            ty1,
            ttext,
            bbox={"facecolor": [1, 1, 1], "alpha": 0.5},
            clip_box=ax.clipbox,
            clip_on=True,
        )

        patch = plt.Rectangle(
            [px1, py1], pw, ph, fill=False, edgecolor=[1, 1, 1], linewidth=1
        )
        ax.add_patch(patch)
        ax.text(
            px1,
            py1,
            ptext,
            bbox={"facecolor": [1, 1, 1], "alpha": 0.5},
            clip_box=ax.clipbox,
            clip_on=True,
        )
    plt.show()

for i, (X,y) in enumerate(test_set.take(1)):
    # true labels 
    true_label = y[0].numpy()
    true_bboxs = y[1].numpy()

    # model predicts 
    pred_label, pred_boxes = model.predict(X)
    pred_label = np.argmax(pred_label, axis=-1)

    # rescaling 
    dummy_true_boxes = (true_bboxs*150).astype(np.int32).clip(min=0, max=224)
    dummy_predict_boxes = (pred_boxes*150).astype(np.int32).clip(min=0, max=224)

    # Info printing 
    print('GT bbox scores: ', true_bboxs)
    print('PRED bbox scores: ', pred_boxes)
    print('After Rescaling and Clipped True BBOX: ', dummy_true_boxes)
    print('After Rescaling and Clipped Pred BBOX: ', dummy_predict_boxes)
    print('True label : {}, Predicted label {}'.format(class_names[int(true_label)], 
                                                       class_names[int(pred_label)]))

    plt.figure(figsize=(10, 10))
    plt.axis("off")
    plt.imshow(X[0])
    ax = plt.gca()

    for tbox, tcls, pbox, pcls in zip(dummy_true_boxes, true_label, dummy_predict_boxes, pred_label):
        # gt and pred labels 
        ttext = "GT: {}".format(class_names[tcls])
        ptext = "Pred: {}".format(class_names[pcls])

        # gt and pred co-ordinates 
        tx1, ty1, x2, y2 = tbox     # xmin, ymin, xmax, ymax
        tw, th = x2 - tx1, y2 - ty1  # width (w) = xmax - xmin ; height (h) = ymax - ymin

        px1, py1, x2, y2 = pbox    # xmin, ymin, xmax, ymax
        pw, ph = x2 - px1, y2 - py1  # width (w) = xmax - xmin ; height (h) = ymax - ymin


        patch = plt.Rectangle(
            [tx1, ty1], tw, th, fill=False, edgecolor=[0, 1, 0], linewidth=1
        )
        ax.add_patch(patch)
        ax.text(
            tx1,
            ty1,
            ttext,
            bbox={"facecolor": [1, 1, 1], "alpha": 0.5},
            clip_box=ax.clipbox,
            clip_on=True,
        )

        patch = plt.Rectangle(
            [px1, py1], pw, ph, fill=False, edgecolor=[1, 1, 1], linewidth=1
        )
        ax.add_patch(patch)
        ax.text(
            px1,
            py1,
            ptext,
            bbox={"facecolor": [1, 1, 1], "alpha": 0.5},
            clip_box=ax.clipbox,
            clip_on=True,
        )
    plt.show()
GT bbox scores:  [[0.75246954 0.36959255 0.18266702 0.7125735 ]]
PRED bbox scores:  [[1.1755341  0.98745024 0.90438926 1.285707  ]]
After Rescaling and Clipped True BBOX:  [[112  55  27 106]]
After Rescaling and Clipped Pred BBOX:  [[176 148 135 192]]
True label : tulips, Predicted label sunflowers