Indexing 如何利用torch.topk()提供的索引?

Indexing 如何利用torch.topk()提供的索引?,indexing,pytorch,tensor,Indexing,Pytorch,Tensor,假设我有一个pytorch张量x的形状[N,N\u g,2]。它可以看作是N*N_g2d向量。具体而言,x[i,j,:]是ith批次中jth组的2d向量 现在我试图得到每组中前5个长度向量的坐标。因此,我尝试了以下方法: (i) 首先,我使用x_len=(x**2).sum(dim=2).sqrt()来计算它们的长度,结果是x_len.shape=[N,N_g] (ii)然后我使用tk=x_len.topk(5)得到每组中前5个长度 (iii)所需的输出将是形状[N,5,2]的张量x_top5。

假设我有一个pytorch张量
x
的形状
[N,N\u g,2]
。它可以看作是
N*N_g
2d向量。具体而言,
x[i,j,:]
i
th批次中
j
th组的2d向量

现在我试图得到每组中前5个长度向量的坐标。因此,我尝试了以下方法:

(i) 首先,我使用
x_len=(x**2).sum(dim=2).sqrt()
来计算它们的长度,结果是
x_len.shape=[N,N_g]

(ii)然后我使用
tk=x_len.topk(5)
得到每组中前5个长度

(iii)所需的输出将是形状
[N,5,2]
的张量
x_top5
。我自然想到了使用
tk.index
来索引
x
,以获得
x\u top5
。但我失败了,因为似乎不支持这种索引

我该怎么做


一个简单的例子:

x = torch.randn(10,10,2) # N=10 is the batchsize, N_g=10 is the group size
x_len = (x**2).sum(dim=2).sqrt()
tk = x_len.topk(5)

x_top5 = x[tk.indices]

print(x_top5.shape)
# torch.Size([10, 5, 10, 2])

但是,这使
x_top5
成为形状的张量
[10,5,10,2]
,而不是所需的
[10,5,2]

您没有设置dim参数<代码>tk=x_len.topk(5,dim=1)用于每组前5个长度。对于输出:
x[:,tk.index,:]
@DimitriK.Sifoua默认情况下,dim参数为
-1
,与本例中的
dim=1
相同,因此不会产生任何差异。此外,如您所建议的,索引会给出
[10,10,5,2]
而不是
[10,5,2]
。dim参数为OK。问题来自索引,因为您使用二维张量而不是一维张量。这将工作
torch.cat([*map(lambda i:x[i:i+1,tk.index[i],:],range(x.shape[0]))))