PyTorch: TorchScript for a module that creates convolutions in forward


I'm using PyTorch 1.4 and need to export a model that contains convolutions inside a loop in forward:

import torch

class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()

    def forward(self, x):
        for i in range(5):
            # A new Conv1d is constructed on every call, inside forward
            conv = torch.nn.Conv1d(1, 1, 2*i+3)
            x = torch.nn.ReLU()(conv(x))
        return x


torch.jit.script(MyCell())
This produces the following error:

RuntimeError: 
Arguments for call are not valid.
The following variants are available:

  _single(float[1] x) -> (float[]):
  Expected a value of type 'List[float]' for argument 'x' but instead found type 'Tensor'.

  _single(int[1] x) -> (int[]):
  Expected a value of type 'List[int]' for argument 'x' but instead found type 'Tensor'.

The original call is:
  File "***/torch/nn/modules/conv.py", line 187
                 padding=0, dilation=1, groups=1,
                 bias=True, padding_mode='zeros'):
        kernel_size = _single(kernel_size)
                      ~~~~~~~ <--- HERE
        stride = _single(stride)
        padding = _single(padding)
'Conv1d.__init__' is being compiled since it was called from 'Conv1d'
  File "***", line ***
    def forward(self, x):
        for _ in range(5):
            conv = torch.nn.Conv1d(1, 1, 2*i+3)
                   ~~~~~~~~~~~~~~~ <--- HERE
            x = torch.nn.Relu()(conv(x))
        return x
'Conv1d' is being compiled since it was called from 'MyCell.forward'
  File "***", line ***
    def forward(self, x, h):
        for _ in range(5):
            conv = torch.nn.Conv1d(1, 1, 2*i+3)
            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
            x = torch.nn.Relu()(conv(x))
        return x
Storing the convolutions in a plain Python list (self.conv) created in __init__ instead gives:

RuntimeError: 
Module 'MyCell' has no attribute 'conv' (This attribute exists on the Python module, but we failed to convert Python type: 'list' to a TorchScript type.):
  File "***", line ***
    def forward(self, x):
        for i in range(len(self.conv)):
                           ~~~~~~~~~ <--- HERE
            x = torch.nn.Relu()(self.conv[i](x))
        return x
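You can use nn.ModuleList as follows.

Also note that, because of a known bug, you currently cannot subscript an nn.ModuleList (e.g. self.conv[i]) inside a scripted forward. Iterate over it directly instead, as in the workaround below: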

import torch
import torch.nn as nn

class MyCell(nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.conv = nn.ModuleList([torch.nn.Conv1d(1, 1, 2*i+3) for i in range(5)])
        self.relu = nn.ReLU()

    def forward(self, x):
        for mod in self.conv:
            x = self.relu(mod(x))
        return x

>>> torch.jit.script(MyCell())
RecursiveScriptModule(
  original_name=MyCell
  (conv): RecursiveScriptModule(
    original_name=ModuleList
    (0): RecursiveScriptModule(original_name=Conv1d)
    (1): RecursiveScriptModule(original_name=Conv1d)
    (2): RecursiveScriptModule(original_name=Conv1d)
    (3): RecursiveScriptModule(original_name=Conv1d)
    (4): RecursiveScriptModule(original_name=Conv1d)
  )
  (relu): RecursiveScriptModule(original_name=ReLU)
)
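Since the goal is to export the model, here is a minimal sketch that scripts the ModuleList version above, checks it against eager mode, and round-trips it through torch.jit.save / torch.jit.load using an in-memory buffer. The input shape (2, 1, 64) and the random seed are arbitrary illustrative choices, and it assumes a PyTorch version where iterating over an nn.ModuleList in a scripted forward works, as in the output shown above.

```python
import io

import torch
import torch.nn as nn


class MyCell(nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        # Submodules are created once in __init__, not inside forward
        self.conv = nn.ModuleList([nn.Conv1d(1, 1, 2 * i + 3) for i in range(5)])
        self.relu = nn.ReLU()

    def forward(self, x):
        # Iterate over the ModuleList instead of indexing it
        for mod in self.conv:
            x = self.relu(mod(x))
        return x


torch.manual_seed(0)
eager = MyCell().eval()
scripted = torch.jit.script(eager)

x = torch.randn(2, 1, 64)  # (batch, channels, length): illustrative shape
with torch.no_grad():
    y_eager = eager(x)
    y_scripted = scripted(x)

# Export to an in-memory buffer and load it back
buf = io.BytesIO()
torch.jit.save(scripted, buf)
buf.seek(0)
loaded = torch.jit.load(buf)
with torch.no_grad():
    y_loaded = loaded(x)

print(y_scripted.shape)  # torch.Size([2, 1, 34])
```

The five unpadded convolutions shorten the length from 64 to 34, and the loaded module should reproduce the scripted output.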
As an alternative to the suggested approach, you can define the network functionally:

import torch

class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        # Plain Python list of weight tensors (not registered as parameters or buffers)
        self.w = []
        for i in range(5):
            self.w.append(torch.empty(1, 1, 2*i + 3))
            # torch.empty leaves w[i] uninitialized: init it here, maybe make it "requires grad"

    def forward(self, x):
        for i in range(5):
            x = torch.nn.functional.conv1d(x, self.w[i])
            x = torch.nn.functional.relu(x)
        return x
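Neither version pads its convolutions, so each layer shortens the sequence. A small pure-Python sketch of the output-length formula from the torch.nn.Conv1d documentation (used here with stride 1, no padding, dilation 1) shows that the kernels 3, 5, 7, 9, 11 shorten the input by 30 in total, so an input needs at least 31 time steps. The helper names conv1d_out_len and stack_out_len are hypothetical.

```python
from typing import List


def conv1d_out_len(l_in: int, kernel: int, stride: int = 1,
                   padding: int = 0, dilation: int = 1) -> int:
    # Hypothetical helper: the L_out formula from the torch.nn.Conv1d docs
    return (l_in + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def stack_out_len(l_in: int, kernels: List[int]) -> int:
    # Hypothetical helper: apply the formula layer by layer
    length = l_in
    for k in kernels:
        length = conv1d_out_len(length, k)
        if length < 1:
            raise ValueError(f"input length {l_in} is too short for kernel {k}")
    return length


KERNELS = [2 * i + 3 for i in range(5)]  # 3, 5, 7, 9, 11, as in the loops above

print(stack_out_len(100, KERNELS))  # 70
print(stack_out_len(31, KERNELS))   # 1
```

With stride 1 and no padding each layer removes kernel - 1 steps, i.e. 2 + 4 + 6 + 8 + 10 = 30 in total.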