Python: How do I use plt.imshow and torchvision.utils.make_grid to generate and display an image grid in PyTorch?

I am trying to understand how torchvision interacts with matplotlib to produce a grid of images. Generating images and displaying them one at a time is easy:

import torch
import torchvision
import matplotlib.pyplot as plt

w = torch.randn(10,3,640,640)
for i in range (0,10):
    z = w[i]
    plt.imshow(z.permute(1,2,0))
    plt.show()
However, displaying these images in a grid does not seem to be as straightforward:

w = torch.randn(10,3,640,640)
grid = torchvision.utils.make_grid(w, nrow=5)
plt.imshow(grid)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-61-1601915e10f3> in <module>()
      1 w = torch.randn(10,3,640,640)
      2 grid = torchvision.utils.make_grid(w, nrow=5)
----> 3 plt.imshow(grid)

/anaconda3/lib/python3.6/site-packages/matplotlib/pyplot.py in imshow(X, cmap, norm, aspect, interpolation, alpha, vmin, vmax, origin, extent, shape, filternorm, filterrad, imlim, resample, url, hold, data, **kwargs)
   3203                         filternorm=filternorm, filterrad=filterrad,
   3204                         imlim=imlim, resample=resample, url=url, data=data,
-> 3205                         **kwargs)
   3206     finally:
   3207         ax._hold = washold

/anaconda3/lib/python3.6/site-packages/matplotlib/__init__.py in inner(ax, *args, **kwargs)
   1853                         "the Matplotlib list!)" % (label_namer, func.__name__),
   1854                         RuntimeWarning, stacklevel=2)
-> 1855             return func(ax, *args, **kwargs)
   1856 
   1857         inner.__doc__ = _add_data_doc(inner.__doc__,

/anaconda3/lib/python3.6/site-packages/matplotlib/axes/_axes.py in imshow(self, X, cmap, norm, aspect, interpolation, alpha, vmin, vmax, origin, extent, shape, filternorm, filterrad, imlim, resample, url, **kwargs)
   5485                               resample=resample, **kwargs)
   5486 
-> 5487         im.set_data(X)
   5488         im.set_alpha(alpha)
   5489         if im.get_clip_path() is None:

/anaconda3/lib/python3.6/site-packages/matplotlib/image.py in set_data(self, A)
    651         if not (self._A.ndim == 2
    652                 or self._A.ndim == 3 and self._A.shape[-1] in [3, 4]):
--> 653             raise TypeError("Invalid dimensions for image data")
    654 
    655         if self._A.ndim == 3:

TypeError: Invalid dimensions for image data
Although PyTorch's documentation indicates that w has the correct shape, Python says it does not. So I tried permuting the tensor's dimensions:

    w = torch.randn(10,3,640,640)
    grid = torchvision.utils.make_grid(w.permute(0,2,3,1), nrow=5)
    plt.imshow(grid)
---------------------------------------------------------------------------
    RuntimeError                              Traceback (most recent call last)
    <ipython-input-62-6f2dc6313e29> in <module>()
          1 w = torch.randn(10,3,640,640)
    ----> 2 grid = torchvision.utils.make_grid(w.permute(0,2,3,1), nrow=5)
          3 plt.imshow(grid)

    /anaconda3/lib/python3.6/site-packages/torchvision-0.2.1-py3.6.egg/torchvision/utils.py in make_grid(tensor, nrow, padding, normalize, range, scale_each, pad_value)
         83             grid.narrow(1, y * height + padding, height - padding)\
         84                 .narrow(2, x * width + padding, width - padding)\
    ---> 85                 .copy_(tensor[k])
         86             k = k + 1
         87     return grid

    RuntimeError: The expanded size of the tensor (3) must match the existing size (640) at non-singleton dimension 0

What is going on here? How can I place a batch of randomly generated images into a grid and display them?

You have to convert the grid to a NumPy array first and move the channel axis to the end:

import numpy as np

def show(img):
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1,2,0)), interpolation='nearest')

w = torch.randn(10,3,640,640)
grid = torchvision.utils.make_grid(w, nrow=10, padding=100)
show(grid)

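There is a small mistake in your code: make_grid returns a (C, H, W) tensor, whereas plt.imshow needs the channel as the last dimension. For example, the following works: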

In [107]: import torchvision

# sample input (10 RGB images containing just Gaussian Noise)
In [108]: batch_tensor = torch.randn(*(10, 3, 256, 256))   # (N, C, H, W)

# make grid (2 rows and 5 columns) to display our 10 images
In [109]: grid_img = torchvision.utils.make_grid(batch_tensor, nrow=5)

# check shape
In [110]: grid_img.shape
Out[110]: torch.Size([3, 518, 1292])

# reshape and plot (because MPL needs channel as the last dimension)
In [111]: plt.imshow(grid_img.permute(1, 2, 0))
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Out[111]: <matplotlib.image.AxesImage at 0x7f62081ef080>
This displays the ten noise images in a grid of 2 rows and 5 columns.


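Thanks, kmario23. My mistake was not treating the grid as an image to be displayed, which means the grid has to be permuted to channels-last first :-)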