Python Pytorch张量索引:如何通过包含索引的张量收集行

Python Pytorch张量索引:如何通过包含索引的张量收集行,python,indexing,pytorch,Python,Indexing,Pytorch,我有张量: ids:形状(7000,1),包含像[[1]、[0]、[2]、…]这样的索引。 x:形状(7000,3,255) idstensor对应选择的x的粗体标注尺寸索引进行编码。 我想在结果向量中收集选定的切片: 结果:形状(7000255) 背景: 我对3个元素中的每一个都有一些分数(shape=(7000,3)),只想选择分数最高的一个。因此,我使用了这个函数 ids = torch.argmax(scores,1,True) 给我最大的ID。我已经尝试使用“聚集”功能执行此操作:

我有张量:

ids:形状(7000,1),包含像
[[1]、[0]、[2]、…]这样的索引。

x:形状(7000,3,255)

ids
tensor对应选择的
x
的粗体标注尺寸索引进行编码。 我想在结果向量中收集选定的切片:

结果:形状(7000255)

背景:

我对3个元素中的每一个都有一些分数(shape=(7000,3)),只想选择分数最高的一个。因此,我使用了这个函数

ids = torch.argmax(scores,1,True)
给我最大的ID。我已经尝试使用“聚集”功能执行此操作:

result = x.gather(1,ids)

但这不起作用。

这里有一个解决方案,您可以寻找

ids = ids.repeat(1, 255).view(-1, 1, 255)
举例如下:

x=torch.arange(24).视图(4,3,2)
"""
张量([[0,1],,
[ 2,  3],
[ 4,  5]],
[[ 6,  7],
[ 8,  9],
[10, 11]],
[[12, 13],
[14, 15],
[16, 17]],
[[18, 19],
[20, 21],
[22, 23]]])
"""
ids=torch.randint(0,3,size=(4,1))
"""
张量([[0],
[2],
[0],
[2]])
"""
idx=ids。重复(1,2)。查看(4,1,2)
"""
张量([[0,0]],
[[2, 2]],
[[0, 0]],
[[2, 2]]])
"""
火炬收集(x,1,idx)
"""
张量([[0,1]],
[[10, 11]],
[[12, 13]],
[[22, 23]]])
"""

以David Ng为例,我找到了另一种方法:

idx = ids.flatten() + torch.arange(0,4*3,3)

tensor([ 0,  5,  6, 11])



x.view(-1,2)[idx]

tensor([[ 0,  1],
        [10, 11],
        [12, 13],
        [22, 23]])

另一种解决方案可以在维度更高的情况下提供更好的内存读取模式

# data
x = torch.arange(60).reshape(3, 4, 5)
# index
y = torch.randint(0, 4, (12,), dtype=torch.int64).reshape(3, 4)
# result
z = x[torch.arange(x.shape[0]).repeat_interleave(x.shape[1]), y.flatten()]
z = z.reshape(x.shape)
x,y,z的示例结果如下

Tensor([[[ 0,  1,  2,  3,  4],
     [ 5,  6,  7,  8,  9],
     [10, 11, 12, 13, 14],
     [15, 16, 17, 18, 19]],

    [[20, 21, 22, 23, 24],
     [25, 26, 27, 28, 29],
     [30, 31, 32, 33, 34],
     [35, 36, 37, 38, 39]],

    [[40, 41, 42, 43, 44],
     [45, 46, 47, 48, 49],
     [50, 51, 52, 53, 54],
     [55, 56, 57, 58, 59]]])
tensor([[1, 1, 2, 3],
    [3, 1, 1, 0],
    [1, 1, 1, 1]])
tensor([[[ 5,  6,  7,  8,  9],
     [ 5,  6,  7,  8,  9],
     [10, 11, 12, 13, 14],
     [15, 16, 17, 18, 19]],

    [[35, 36, 37, 38, 39],
     [25, 26, 27, 28, 29],
     [25, 26, 27, 28, 29],
     [20, 21, 22, 23, 24]],

    [[45, 46, 47, 48, 49],
     [45, 46, 47, 48, 49],
     [45, 46, 47, 48, 49],
     [45, 46, 47, 48, 49]]])

很好用!非常感谢你。我还添加了我自己找到的另一个解决方案。但我不确定哪一个更快