Warning: file_get_contents(/data/phpspider/zhask/data//catemap/2/python/305.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
Python 如何从pytorch中给定位置的每行中获取值?_Python_Pytorch - Fatal编程技术网

Python 如何从pytorch中给定位置的每行中获取值?

Python 如何从pytorch中给定位置的每行中获取值?,python,pytorch,Python,Pytorch,如何根据包含每行位置的一维数组从二维火炬数组中获取值: 例如: a = torch.randn((5,5)) >>> a tensor([[ 0.0740, -0.3129, 0.7814, -0.0519, 1.3503], [ 1.1985, 0.2098, -0.0326, 0.3922, 0.5037], [-1.4334, 1.4047, -0.6607, -1.8024, -0.0088], [ 1.211

如何根据包含每行位置的一维数组从二维火炬数组中获取值:

例如:

a = torch.randn((5,5))
>>> a
tensor([[ 0.0740, -0.3129,  0.7814, -0.0519,  1.3503],
        [ 1.1985,  0.2098, -0.0326,  0.3922,  0.5037],
        [-1.4334,  1.4047, -0.6607, -1.8024, -0.0088],
        [ 1.2116,  0.5928,  1.4041,  1.0494, -0.1146],
        [ 0.4173,  1.0482,  0.5244, -2.1767,  0.5264]])

b = torch.randint(0,5, (5,))
>>> b
tensor([1, 0, 1, 3, 2])
desired output:
tensor([-0.3129,
        1.1985,
        1.4047,
        1.0494,
        0.5244])
我想在张量
b

例如:

a = torch.randn((5,5))
>>> a
tensor([[ 0.0740, -0.3129,  0.7814, -0.0519,  1.3503],
        [ 1.1985,  0.2098, -0.0326,  0.3922,  0.5037],
        [-1.4334,  1.4047, -0.6607, -1.8024, -0.0088],
        [ 1.2116,  0.5928,  1.4041,  1.0494, -0.1146],
        [ 0.4173,  1.0482,  0.5244, -2.1767,  0.5264]])

b = torch.randint(0,5, (5,))
>>> b
tensor([1, 0, 1, 3, 2])
desired output:
tensor([-0.3129,
        1.1985,
        1.4047,
        1.0494,
        0.5244])
这里,通过张量
b
选择给定位置的每个元素

我试过:

for index in range(b.size(-1)):
    val = torch.cat((val,a[index,b[index]].view(1,-1)), dim=0) if val is not None else a[index,b[index]].view(1,-1)


>>> val
tensor([[-0.3129],
        [ 1.1985],
        [ 1.4047],
        [ 1.0494],
        [ 0.5244]])
然而,有张量索引的方法吗? 我尝试了两种使用张量索引的解决方案,但没有一种有效。

您可以使用

a.聚集(1,b.取消聚集(1)) 张量([-0.3129], [ 1.1985], [ 1.4047], [ 1.0494], [ 0.5244]]) 或

>a[范围(len(a)),b].取消查询(1)
张量([-0.3129],
[ 1.1985],
[ 1.4047],
[ 1.0494],
[ 0.5244]])