Matrix 使用index_select将一个PyTorch张量索引为另一个PyTorch张量

Matrix 使用index_select将一个PyTorch张量索引为另一个PyTorch张量,matrix,indexing,pytorch,Matrix,Indexing,Pytorch,我有一个3 x 3的PyTorch LongTensor,看起来像这样: A = [0, 0, 0] [1, 2, 2] [1, 2, 3] 我想我们用它来索引一个4 x 2的浮动张量,像这样: B = [0.4, 0.5] [1.2, 1.4] [0.8, 1.9] [2.4, 2.9] 我的预期输出是下面的2 x 3 x 3浮点张量: C[0,:,:] = [0.4, 0.4, 0.4] [1.2, 0.8,

我有一个3 x 3的PyTorch LongTensor,看起来像这样:

A = 
    [0, 0, 0]
    [1, 2, 2]
    [1, 2, 3]
我想我们用它来索引一个4 x 2的浮动张量,像这样:

B = 
    [0.4, 0.5]
    [1.2, 1.4]
    [0.8, 1.9]
    [2.4, 2.9]
我的预期输出是下面的2 x 3 x 3浮点张量:

C[0,:,:] = 
    [0.4, 0.4, 0.4]
    [1.2, 0.8, 0.8]
    [1.2, 0.8, 2.4]

C[1,:,:] =
    [0.5, 0.5, 0.5]
    [1.4, 1.9, 1.9]
    [1.4, 1.9, 2.9]
换句话说,矩阵
A
是索引和广播矩阵
B
A
B
的索引矩阵,因此该操作本质上是一个索引操作

如何使用该函数实现这一点?如果解决方案涉及添加或排列维度,那没关系。

使用
index\u select()
要求索引值位于向量而不是张量中。但只要格式正确,该函数就可以为您处理广播。最后一件必须做的事情是重塑输出,我相信是由于广播

成功执行此操作的一个班轮是

torch.index_select(B, 0, A.view(-1)).view(3,-1,2).permute(2,0,1)
A.view(-1)
将索引矩阵矢量化

\视图(3,-1,2)
重新塑造回索引矩阵的形状,但考虑到新的额外维度大小2(因为我正在索引一个nx2矩阵)

最后,
\uuu.permute(2,0,1)
重新塑造矩阵,以便输出在单独的通道(而不是每列)中查看
B
的每个维度