Warning: file_get_contents(/data/phpspider/zhask/data//catemap/2/python/308.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
Python PyTorch Tensor.index_select()如何计算张量输出?_Python_Indexing_Pytorch_Tensor - Fatal编程技术网

Python PyTorch Tensor.index_select()如何计算张量输出?

Python PyTorch Tensor.index_select()如何计算张量输出?,python,indexing,pytorch,tensor,Python,Indexing,Pytorch,Tensor,我无法理解复杂的索引——张量的非连续索引是如何工作的。下面是一个示例代码及其输出 import torch def describe(x): print("Type: {}".format(x.type())) print("Shape/size: {}".format(x.shape)) print("Values: \n{}".format(x)) indices = torch.LongTensor([0,2])

我无法理解复杂的索引——张量的非连续索引是如何工作的。下面是一个示例代码及其输出

import torch

def describe(x):
  print("Type: {}".format(x.type()))
  print("Shape/size: {}".format(x.shape))
  print("Values: \n{}".format(x))


indices = torch.LongTensor([0,2])
x = torch.arange(6).view(2,3)
describe(torch.index_select(x, dim=1, index=indices))
将输出返回为

类型:火炬。传感器形状/尺寸:火炬。尺寸([2,2])值: 张量([[0,2], [3,5]]

有人能解释它是如何到达这个输出张量的吗?
谢谢

您正在从第一个轴(
dim=0
)上的
x
中选择第一个(
索引[0]
)和第三个(
索引[1]
2
)张量。本质上,使用
dim=1
与使用
x[:,index]
在第二个轴上进行直接索引的效果相同

>>> x
tensor([[0, 1, 2],
        [3, 4, 5]])
因此,选择列(因为您看到的是
dim=1
,而不是
dim=0
)哪些索引位于
索引中。想象一下,将一个简单的列表
[0,2]
作为
索引

>>> indices = [0, 2]

>>> x[:, indices[0]] # same as x[:, 0]
tensor([0, 3])

>>> x[:, indices[1]] # same as x[:, 2]
tensor([2, 5])
因此,将索引作为
火炬传递。Tensor
允许您直接对索引的所有元素进行索引,即列
0
2
。类似于NumPy的索引工作方式

>>> x[:, indices]
tensor([[0, 2],
        [3, 5]])

下面是另一个例子来帮助您了解它是如何工作的。由于将
x
定义为
x=torch.arange(9).视图(3,3)
,因此我们有3行(也称为
dim=0
)和3列(也称为
dim=1

注:
torch.index\u select(x,dim,index)
等同于
x.index\u select(dim,index)

>>> indices
tensor([0, 2]) # namely 'first' and 'third'

>>> x = torch.arange(9).view(3, 3)
tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]])

>>> x.index_select(0, indices) # select first and third rows
tensor([[0, 1, 2],
        [6, 7, 8]])

>>> x.index_select(1, indices) # select first and third columns
tensor([[0, 2],
        [3, 5],
        [6, 8]])