Warning: file_get_contents(/data/phpspider/zhask/data//catemap/1/visual-studio-2008/2.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
Python PyTorch张量topk表示维度上的每个张量_Python_Pytorch_Tensor - Fatal编程技术网

Python PyTorch张量topk表示维度上的每个张量

Python PyTorch张量topk表示维度上的每个张量,python,pytorch,tensor,Python,Pytorch,Tensor,我有下面的张量 inp = tensor([[[ 0.0000e+00, 5.7100e+02, -6.9846e+00], [ 0.0000e+00, 4.4070e+03, -7.1008e+00], [ 0.0000e+00, 3.0300e+02, -7.2226e+00], [ 0.0000e+00, 6.8000e+01, -7.2777e+00], [ 1.0000e+00, 5.7100e+02, -6.9846e+00],

我有下面的张量

inp = tensor([[[ 0.0000e+00,  5.7100e+02, -6.9846e+00],
     [ 0.0000e+00,  4.4070e+03, -7.1008e+00],
     [ 0.0000e+00,  3.0300e+02, -7.2226e+00],
     [ 0.0000e+00,  6.8000e+01, -7.2777e+00],
     [ 1.0000e+00,  5.7100e+02, -6.9846e+00],
     [ 1.0000e+00,  4.4070e+03, -7.1008e+00],
     [ 1.0000e+00,  3.0300e+02, -7.2226e+00],
     [ 1.0000e+00,  6.8000e+01, -7.2777e+00]],

    [[ 0.0000e+00,  2.1610e+03, -7.0754e+00],
     [ 0.0000e+00,  6.8000e+01, -7.2259e+00],
     [ 0.0000e+00,  1.0620e+03, -7.2920e+00],
     [ 0.0000e+00,  2.9330e+03, -7.3009e+00],
     [ 1.0000e+00,  2.1610e+03, -7.0754e+00],
     [ 1.0000e+00,  6.8000e+01, -7.2259e+00],
     [ 1.0000e+00,  1.0620e+03, -7.2920e+00],
     [ 1.0000e+00,  2.9330e+03, -7.3009e+00]],

    [[ 0.0000e+00,  4.4070e+03, -7.1947e+00],
     [ 0.0000e+00,  3.5600e+02, -7.2958e+00],
     [ 0.0000e+00,  3.0300e+02, -7.3232e+00],
     [ 0.0000e+00,  1.2910e+03, -7.3615e+00],
     [ 1.0000e+00,  4.4070e+03, -7.1947e+00],
     [ 1.0000e+00,  3.5600e+02, -7.2958e+00],
     [ 1.0000e+00,  3.0300e+02, -7.3232e+00],
     [ 1.0000e+00,  1.2910e+03, -7.3615e+00]]])
形状

torch.Size([3, 8, 3])
我想在dim1中找到topk(k=4)元素,其中要排序的值是dim2(负值)。由此产生的张量形状应为:

torch.Size([3, 4, 3])
我知道如何对单个张量进行topk,但如何同时对多个批次进行topk?

我是这样做的:

val, ind = inp[:, :, 2].squeeze().topk(k=4, dim=1, sorted=True)
new_ind = ind.unsqueeze(-1).repeat(1,1,3)
result = inp.gather(1, new_ind)

我不知道这是否是最好的方法,但它奏效了。

一种方法是将和组合起来,如下所示:

>>> i = torch.arange(x.shape[0]).reshape(x.shape[0], 1, 1)
>>> j = idx.reshape(x.shape[0], -1, 1)
>>> k = torch.arange(x.shape[2]).reshape(1, 1, x.shape[2])
我以形状为(3,4,3)和
k
的随机张量
x
为例

导入火炬 >>>x=火炬的兰特(3,4,3) >>>x 张量([[0.0256,0.7366,0.2528], [0.5596, 0.9450, 0.5795], [0.8265, 0.5469, 0.8304], [0.4223, 0.5206, 0.2898]], [[0.2159, 0.0369, 0.6869], [0.4556, 0.5804, 0.3169], [0.8194, 0.5240, 0.0055], [0.8357, 0.4162, 0.3740]], [[0.3849, 0.0223, 0.9951], [0.2872, 0.5952, 0.6570], [0.1433, 0.8450, 0.6557], [0.0270, 0.9176, 0.3904]]]) 现在,按照所需维度(最后一个维度)对张量进行排序,并获得索引:

>>> _, idx = torch.sort(x[:, :, -1])
>>> k = 2
>>> idx = idx[:, :k]
# idx is = 
tensor([[0, 3],
        [2, 1],
        [3, 2]])
现在生成三对索引
(i,j,k)
,对原始张量进行切片,如下所示:

>>> i = torch.arange(x.shape[0]).reshape(x.shape[0], 1, 1)
>>> j = idx.reshape(x.shape[0], -1, 1)
>>> k = torch.arange(x.shape[2]).reshape(1, 1, x.shape[2])
请注意,一旦您通过
(i,j,k)
索引任何内容,它们将采用
(x.shape[0],k,x.shape[2])
,这是此处所需的输出形状。 现在只需通过i、j和k索引
x

>>> x[i, j, k]
tensor([[[0.0256, 0.7366, 0.2528],
         [0.4223, 0.5206, 0.2898]],

        [[0.8194, 0.5240, 0.0055],
         [0.4556, 0.5804, 0.3169]],

        [[0.0270, 0.9176, 0.3904],
         [0.1433, 0.8450, 0.6557]]])
本质上,我遵循的一般方法是通过索引数组创建相应的张量访问模式,然后使用这些数组作为索引直接切片张量


我实际上是为了升序排序而做的,所以这里我得到了top-k最小元素。一个简单的解决方法是使用
torch.sort(x[:,:,-1],descending=True)

使用
gather()
方法的好例子。不错。