如何识别pytorch中批次的错误分类

如何识别pytorch中批次的错误分类,pytorch,imshow,Pytorch,Imshow,我有一个这样的脚本,其中使用了成批的图像 correct = 0 total = 0 incorrect_classification=[] for (i, [images, labels]) in enumerate(test_loader): images = Variable(images.view(-1, n_pixel*n_pixel)) outputs = net(images) _, predicted = torch.min(outputs.data, 1) to

我有一个这样的脚本,其中使用了成批的图像

correct = 0
total = 0
incorrect_classification=[]
for (i, [images, labels]) in enumerate(test_loader):
  images = Variable(images.view(-1, n_pixel*n_pixel))
  outputs = net(images)
  _, predicted = torch.min(outputs.data, 1)
  total += labels.size(0)                    
  correct += (predicted == labels).sum() 
print('Accuracy: %d %%' %
      (100 * correct / total))
批处理大小为10时,每个枚举返回10倍图像大小的张量。如何将所有错误的分类保存到数组中?错误的\u分类或错误的img及其概率保存到字典中,以便稍后使用can plt.imshow检查它们

如果批量大小为1,我可以使用:

if (predicted==labels).item()==0:
    incorrect_examples.append(images.numpy())
但是如果指定了批量大小(比如每批100张图像),我应该如何保存错误的分类


提前感谢您的回答。

正如@zihaozhihao的评论中所说,
images[predicted==labels]
应该完成这项工作

换句话说,您将获得索引掩码,然后使用此掩码访问所需的图像:

correct = 0
total = 0
incorrect_examples=[]
for (i, [images, labels]) in enumerate(test_loader):
    images = Variable(images.view(-1, n_pixel*n_pixel))
    outputs = net(images)
    _, predicted = torch.min(outputs.data, 1)
    total += labels.size(0)                    
    correct += (predicted == labels).sum() 
    print('Accuracy: %d %%' % (100 * correct / total))

    # if (predicted==labels).item()==0:
    #     incorrect_examples.append(images.numpy())

    idxs_mask = (predicted == labels).view(-1)
    incorrect_examples.append(images[idxs_mask].numpy()) 
视图(-1)
将展平遮罩,用于遮罩图像张量的批次通道

在循环结束时(在循环外),列表中的iten
不正确的示例
将具有形状
[批大小,n像素,n像素]
,为了方便起见,您可以通过将它们串联在一个张量中对它们进行分组:

incorrect_images = torch.cat(incorrect_examples)
# incorrect_images.size() -> (n_incorrect_images, n_pixel, n_pixel)

可能尝试使用
图像[predicted==labels]
来获取错误的图像?