Pytorch:根据索引张量从三维张量中选择列

Pytorch:根据索引张量从三维张量中选择列,pytorch,tensor,Pytorch,Tensor,我有一个维度为[BxLxD]的三维张量M,和一个维度为[B,1]的一维张量idx,包含范围为(0,L-1)的列索引。我想创建一个维度为[BxD]的二维张量N,这样N[I,j]=M[I,idx[I],j]。如何有效地做到这一点 例如: B,L,D = 2,4,2 M = torch.rand(B,L,D) > tensor([[[0.0612, 0.7385], [0.7675, 0.3444], [0.9129, 0.7601],

我有一个维度为
[BxLxD]
的三维张量
M
,和一个维度为
[B,1]
的一维张量
idx
,包含范围为
(0,L-1)
的列索引。我想创建一个维度为
[BxD]
的二维张量
N
,这样
N[I,j]=M[I,idx[I],j]
。如何有效地做到这一点

例如:

B,L,D = 2,4,2

M = torch.rand(B,L,D)

>

tensor([[[0.0612, 0.7385],
         [0.7675, 0.3444],
         [0.9129, 0.7601],
         [0.0567, 0.5602]],

        [[0.5450, 0.3749],
         [0.4212, 0.9243],
         [0.1965, 0.9654],
         [0.7230, 0.6295]]])


idx = torch.randint(0, L, size = (B,))

>

tensor([3, 0])

N = get_N(M, idx)

Expected output:

>

tensor([[0.0567, 0.5602], 
       [0.5450, 0.3749]])
谢谢

import torch

B,L,D = 2,4,2

def get_N(M, idx):
    return M[torch.arange(B), idx, :].squeeze()

M = torch.tensor([[[0.0612, 0.7385],
                   [0.7675, 0.3444],
                   [0.9129, 0.7601],
                   [0.0567, 0.5602]],

                   [[0.5450, 0.3749],
                   [0.4212, 0.9243],
                   [0.1965, 0.9654],
                   [0.7230, 0.6295]]])
idx = torch.tensor([3,0])
N = get_N(M, idx)
print(N)
结果:

tensor([[0.0567, 0.5602],
        [0.5450, 0.3749]])
沿二维切片

结果:

tensor([[0.0567, 0.5602],
        [0.5450, 0.3749]])

两个维度。

你能提供一个如果你觉得我的答案有用,请考虑接受它。如果你发现它不起作用,一定要告诉我出了什么问题。:)如果你觉得我的答案有用,请考虑接受。如果你发现它不起作用,一定要告诉我出了什么问题。:)