Warning: file_get_contents(/data/phpspider/zhask/data//catemap/4/maven/5.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
Pytorch 如何在不使用python索引的情况下切片torch张量_Pytorch - Fatal编程技术网

Pytorch 如何在不使用python索引的情况下切片torch张量

Pytorch 如何在不使用python索引的情况下切片torch张量,pytorch,Pytorch,下面我的pytorch代码不断收到jit跟踪程序警告(在pytorch 1.1.0环境中),抱怨“pytorch 1.0跟踪程序警告:将张量转换为Python索引可能会…” 有没有一种方法可以在不使用python索引的情况下实现下面标记为(a)的代码行 N,C,H,W = input.size() Cout=4*C Hout=H//2 Wout=W//2 downsampled=torch.zeros([N,Cout,Hout,Wout], dtype= torch.FloatTensor) d

下面我的pytorch代码不断收到jit跟踪程序警告(在pytorch 1.1.0环境中),抱怨“pytorch 1.0跟踪程序警告:将张量转换为Python索引可能会…”

有没有一种方法可以在不使用python索引的情况下实现下面标记为(a)的代码行

N,C,H,W = input.size()
Cout=4*C
Hout=H//2
Wout=W//2
downsampled=torch.zeros([N,Cout,Hout,Wout], dtype= torch.FloatTensor)
downsampled[:,1:Cout:4,:,:]=input[:,:,0::2,1::2] ---- (A)

我确认jit跟踪程序不再抱怨Pytorch 1.2中的python索引(正如Umang Gupta所评论的)

顺便说一句,我提出了一个没有切片(但仍然使用索引)的实现,如下所示:

import torch

input=torch.arange(100)
input=input.view(10,10)
input=input[None, None, ...].expand(2,3,10,10) #torch.Size([2,3,10,10])

N,C,H,W=input.size()
Cout=4*C
Hout=H//2
Wout=W//2

downsampled=torch.zeros([N,Cout,Hout,Wout],dtype=torch.int8) #torch.Size([2,12,5,5])

dim2_idx=torch.tensor([k for k in range(0,H,2)])
dim3_idx=torch.tensor([k for k in range(1,W,2)])
sliced_input=input.index_select(2,dim2_idx).index_select(3,dim3_idx) #torch.Size([2,3,5,5])

#downsampled.index_select(1,torch.tensor([k for k in range(1,Cout,4)]))=temp <---Error: Can't assign to function call

for idx in range(1,Cout,4):
    downsampled[:,idx,:,:]=sliced_input[:,idx//4,:,:]
导入火炬
输入=torch.arange(100)
输入=输入。视图(10,10)
输入=输入[None,None,…]。展开(2,3,10,10)#火炬。尺寸([2,3,10,10])
N、 C,H,W=input.size()
Cout=4*C
Hout=H//2
Wout=W//2
下采样=火炬零点([N,Cout,Hout,Wout],dtype=火炬int8)#火炬尺寸([2,12,5,5])
dim2_idx=火炬张量([k表示范围(0,H,2)])
dim3_idx=火炬张量([k表示范围(1,W,2)])
切片输入=输入。索引选择(2,dim2\U idx)。索引选择(3,dim3\U idx)#火炬大小([2,3,5,5])

#降采样索引_选择(1,火炬张量([k代表范围内的k(1,Cout,4)]))=温度至少在1.2范围内工作。