Python 如何根据pytorch中的另一张量选择索引

Python 如何根据pytorch中的另一张量选择索引,python,pytorch,indices,Python,Pytorch,Indices,这项任务似乎很简单,但我不知道怎么做 所以我有两个张量: 具有形状(2,5,2)的索引张量索引,其中最后一个维度对应于x和y维度的索引 形状为(2,5,2,16,16)的“值张量”值,其中我希望最后两个维度使用x和y索引进行选择 更具体地说,指数在0到15之间,我想得到一个输出: out = value[:, :, :, x_indices, y_indices] 因此,输出的形状应为(2,5,2)。有人能帮我吗?非常感谢 编辑: 我尝试了gather的建议,但不幸的是,它似乎不起作用(我

这项任务似乎很简单,但我不知道怎么做

所以我有两个张量:

  • 具有形状
    (2,5,2)
    的索引张量
    索引
    ,其中最后一个维度对应于x和y维度的索引
  • 形状为(2,5,2,16,16)的“值张量”
    ,其中我希望最后两个维度使用x和y索引进行选择
更具体地说,指数在0到15之间,我想得到一个输出:

out = value[:, :, :, x_indices, y_indices]
因此,输出的形状应为
(2,5,2)
。有人能帮我吗?非常感谢

编辑:

我尝试了gather的建议,但不幸的是,它似乎不起作用(我更改了维度,但没关系):

首先,我生成一个坐标网格:

y_t = torch.linspace(-1., 1., 16, device='cpu').reshape(16, 1).repeat(1, 16).unsqueeze(-1)
x_t = torch.linspace(-1., 1., 16, device='cpu').reshape(1, 16).repeat(16, 1).unsqueeze(-1)
grid = torch.cat((y_t, x_t), dim=-1).permute(2, 0, 1).unsqueeze(0)
grid = grid.unsqueeze(1).repeat(1, 3, 1, 1, 1)
在下一步中,我将创建一些索引。在这种情况下,我总是采用索引1:

indices = torch.ones([1, 3, 2], dtype=torch.int64)
接下来,我将使用您的方法:

indices = indices.unsqueeze(-1).unsqueeze(-1)
new_coords = torch.gather(grid, -1, indices).squeeze(-1).squeeze(-1)
最后,我手动为x和y坐标选择索引1:

new_coords_manual = grid[:, :, :, 1, 1]
这将输出以下新坐标:

new_coords
tensor([[[-1.0000, -0.8667],
         [-1.0000, -0.8667],
         [-1.0000, -0.8667]]])

new_coords_manual
tensor([[[-0.8667, -0.8667],
         [-0.8667, -0.8667],
         [-0.8667, -0.8667]]])

如您所见,它只适用于一维。您知道如何解决此问题吗?

您可以将前三个轴展平并应用:

如文档页面所述,这将执行:

out[i][j][k]=input[i][index[i][j][k][k]#如果dim==1


我明白了,再次感谢@Ivan的帮助!:)

问题是,我在最后一个维度上没有被压缩,而我应该在中间维度中被挤压,所以指数在最后:

y_t = torch.linspace(-1., 1., 16, device='cpu').reshape(16, 1).repeat(1, 16).unsqueeze(-1)
x_t = torch.linspace(-1., 1., 16, device='cpu').reshape(1, 16).repeat(16, 1).unsqueeze(-1)
grid = torch.cat((y_t, x_t), dim=-1).permute(2, 0, 1).unsqueeze(0)
grid = grid.unsqueeze(1).repeat(2, 3, 1, 1, 1)

indices = torch.ones([2, 3, 2], dtype=torch.int64).unsqueeze(-2).unsqueeze(-2)
new_coords = torch.gather(grid, 3, indices).squeeze(-2).squeeze(-2)

new_coords_manual = grid[:, :, :, 1, 1]

现在
new\u coords
equals
new\u coords\u manual

你能不能给出一个最小的例子
index
value
以及所需的输出?在制作
new\u coords\u manual
时可以获得所需的输出谢谢你的帮助!这确实适用于批量大小为1的情况,但似乎在批量大小大于1时也会遇到同样的问题:/I还尝试在x和y坐标中拆分问题,并应用
index_y=index[:,:,0]。unsqueze(-1)。unsqueze(-1)。unsqueze(-1)。unsqueze(-1)
后跟
new_y=torch.gather(grid,3,index_y)。挤压(-1)。挤压(-1)
。在对x值执行相同操作并在dim=4上聚集之后,我连接了张量。但我从你的第一个建议中得到了确切的结果。
y_t = torch.linspace(-1., 1., 16, device='cpu').reshape(16, 1).repeat(1, 16).unsqueeze(-1)
x_t = torch.linspace(-1., 1., 16, device='cpu').reshape(1, 16).repeat(16, 1).unsqueeze(-1)
grid = torch.cat((y_t, x_t), dim=-1).permute(2, 0, 1).unsqueeze(0)
grid = grid.unsqueeze(1).repeat(2, 3, 1, 1, 1)

indices = torch.ones([2, 3, 2], dtype=torch.int64).unsqueeze(-2).unsqueeze(-2)
new_coords = torch.gather(grid, 3, indices).squeeze(-2).squeeze(-2)

new_coords_manual = grid[:, :, :, 1, 1]