Warning: file_get_contents(/data/phpspider/zhask/data//catemap/2/python/300.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
Python 如何动态索引pytorch中的张量?_Python_Deep Learning_Pytorch_Torch_Tensor - Fatal编程技术网

Python 如何动态索引pytorch中的张量?

Python 如何动态索引pytorch中的张量?,python,deep-learning,pytorch,torch,tensor,Python,Deep Learning,Pytorch,Torch,Tensor,例如,我得到一个张量: tensor = torch.rand(12, 512, 768) 我得到了一个索引列表,比如: [0,2,3,400,5,32,7,8,321,107,100,511] 在给定索引列表的情况下,我希望从维度2上的512个元素中选择1个元素。然后张量的大小会变成(12,1768) 有办法吗?是的,您可以使用索引直接将其切片,然后使用将二维张量提升为三维张量: # inputs In [6]: tensor = torch.rand(12, 512, 768) In [

例如,我得到一个张量:

tensor = torch.rand(12, 512, 768)
我得到了一个索引列表,比如:

[0,2,3,400,5,32,7,8,321,107,100,511]
在给定索引列表的情况下,我希望从维度2上的512个元素中选择1个元素。然后张量的大小会变成
(12,1768)


有办法吗?

是的,您可以使用索引直接将其切片,然后使用将二维张量提升为三维张量:

# inputs
In [6]: tensor = torch.rand(12, 512, 768)
In [7]: idx_list = [0,2,3,400,5,32,7,8,321,107,100,511]

# slice using the index and then put a singleton dimension along axis 1
In [8]: for idx in idx_list:
   ...:     sampled_tensor = torch.unsqueeze(tensor[:, idx, :], 1)
   ...:     print(sampled_tensor.shape)
   ...:     
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])

或者,如果您想要更简洁的代码并且不想使用,请使用:

In [11]: for idx in idx_list:
    ...:     sampled_tensor = tensor[:, [idx], :]
    ...:     print(sampled_tensor.shape)
    ...:     
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])


注意:如果您希望仅对
idx\u列表中的一个
idx
进行切片,则无需对
循环使用
,还有一种方法,即使用PyTorch,并使用索引和:

当你调用
张量[:,idx\u张量,:]
你会得到一个形状的张量:
(12,lenu of_idx\u list,768)

其中,第二个维度取决于索引的数量

使用此张量可拆分为一系列形状的张量:
(12,1768)

所以最后,
张量列表
包含以下形状的张量:

[torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768])]
[torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768])]