Python PyTorch batch indexing


So, my network output looks like this:

output = tensor([[[ 0.0868, -0.2623],
     [ 0.0716, -0.2668],
     [ 0.0584, -0.2549],
     [ 0.0482, -0.2386],
     [ 0.0410, -0.2234],
     [ 0.0362, -0.2111],
     [ 0.0333, -0.2018],
     [ 0.0318, -0.1951],
     [ 0.0311, -0.1904],
     [ 0.0310, -0.1873],
     [ 0.0312, -0.1851],
     [ 0.0315, -0.1837],
     [ 0.0318, -0.1828],
     [ 0.0322, -0.1822],
     [ 0.0324, -0.1819],
     [ 0.0327, -0.1817],
     [ 0.0328, -0.1815],
     [ 0.0330, -0.1815],
     [ 0.0331, -0.1814],
     [ 0.0332, -0.1814],
     [ 0.0333, -0.1814],
     [ 0.0333, -0.1814],
     [ 0.0334, -0.1814],
     [ 0.0334, -0.1814],
     [ 0.0334, -0.1814]],

    [[ 0.0868, -0.2623],
     [ 0.0716, -0.2668],
     [ 0.0584, -0.2549],
     [ 0.0482, -0.2386],
     [ 0.0410, -0.2234],
     [ 0.0362, -0.2111],
     [ 0.0333, -0.2018],
     [ 0.0318, -0.1951],
     [ 0.0311, -0.1904],
     [ 0.0310, -0.1873],
     [ 0.0312, -0.1851],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164]],

    [[ 0.0868, -0.2623],
     [ 0.0716, -0.2668],
     [ 0.0584, -0.2549],
     [ 0.0482, -0.2386],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164]],

    [[ 0.0868, -0.2623],
     [ 0.0716, -0.2668],
     [ 0.0584, -0.2549],
     [ 0.0482, -0.2386],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164]],

    [[ 0.0868, -0.2623],
     [ 0.0716, -0.2668],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164]],

    [[ 0.0868, -0.2623],
     [ 0.0716, -0.2668],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164]],

    [[ 0.0868, -0.2623],
     [ 0.0716, -0.2668],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164]],

    [[ 0.0868, -0.2623],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164]]])
Its shape is

[8, 24, 2]

where 8 is my batch size. Now I want to get one data point from each batch, at the following positions:

index = tensor([24, 10,  3,  3,  1,  1,  1,  0])

That is, the 24th value from the first batch, the 10th from the second, and so on.

Now I'm struggling with the syntax. I tried

torch.gather(output, 0, index)

but it keeps telling me that my dimensions don't match. Trying

output[:, index]

just picks the values at all of the indices for every batch. What is the correct syntax to get these values?

To select just one element per batch, you need to enumerate the batch indices, which can be done easily with

output[torch.arange(output.size(0)), index]

This essentially creates tuples between the enumerating tensor and your index tensor to access the data, which results in indexing output[0, 24], output[1, 10], and so on.
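As a minimal self-contained sketch of this pattern (toy random data in place of the real network output, and with the out-of-range 24 replaced by the maximum valid index, 23):

```python
import torch

output = torch.randn(8, 24, 2)                    # toy stand-in for the network output
index = torch.tensor([23, 10, 3, 3, 1, 1, 1, 0])  # 24 clamped to the max valid index, 23

# Each batch index is paired with its position index: (0, 23), (1, 10), ...
selected = output[torch.arange(output.size(0)), index]
print(selected.shape)  # torch.Size([8, 2])
```

Each row i of `selected` is `output[i, index[i], :]`, so the result has one (2-element) data point per batch.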
First, note that for an output of shape [8, 24, 2], the maximum index along the second dimension is 23, so I changed your index to

index = torch.tensor([23, 10,  3,  3,  1,  1,  1,  0])
output = torch.randn((8,24,2)) # Toy data to represent your output
The simplest solution is to use a for loop:

data_pts = torch.zeros((8,2)) # Tensor to store desired values

for i,j in enumerate(index):
    data_pts[i, :] = output[i, j, :]
However, if you want to vectorize the indexing, you can simply index along all dimensions at once. For example,

data_pts_vectorized = output[range(8), index, :] 
Since the index vector is ordered, you can use range to generate the indices for the first dimension.

You can confirm that both approaches give the same result:

assert(torch.all(data_pts == data_pts_vectorized))
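Incidentally, since the question originally tried torch.gather: gather can also do this, but its index tensor must have the same number of dimensions as the source, so the index has to be expanded to shape [8, 1, 2] and gathered along dim=1. A sketch under the same toy setup (not part of the original answer):

```python
import torch

output = torch.randn(8, 24, 2)
index = torch.tensor([23, 10, 3, 3, 1, 1, 1, 0])

# gather along dim=1: result[i, 0, k] = output[i, idx[i, 0, k], k]
idx = index.view(-1, 1, 1).expand(-1, 1, output.size(2))  # shape [8, 1, 2]
data_pts_gather = output.gather(1, idx).squeeze(1)        # shape [8, 2]

# Matches the advanced-indexing form
assert torch.equal(data_pts_gather, output[torch.arange(8), index])
```

The dimension mismatch in the question came from gathering along dim 0 with a 1-D index; gather needs the index to match the source's rank.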