Neural network 在pytorch中反向传播时自动更新自定义图层参数

Neural network 在pytorch中反向传播时自动更新自定义图层参数,neural-network,gradient,pytorch,backpropagation,Neural Network,Gradient,Pytorch,Backpropagation,我有一个pytorch自定义层,定义为: class MyCustomLayer(nn.Module): def __init__(self): super(MyCustomLayer, self).__init__() self.my_parameter = torch.rand(1, requires_grad = True) # the following allows the previously defined parameter to be recog

我有一个pytorch自定义层,定义为:

class MyCustomLayer(nn.Module):
  def __init__(self):
    super(MyCustomLayer, self).__init__()

    self.my_parameter = torch.rand(1, requires_grad = True)

    # the following allows the previously defined parameter to be recognized as a network parameter when instantiating the model
    self.my_registered_parameter = nn.ParameterList([nn.Parameter(self.my_parameter)])

  def forward(self, x):
    return x*self.my_parameter
然后定义使用自定义图层的网络:

class MyNet(nn.Module):
  def __init__(self):
    super(MyNet, self).__init__()
    self.layer1 = MyCustomLayer()

  def forward(self, x):
    x = self.layer1(x)
    return x
现在让我们实例化MyNet并观察问题:

# instantiate MyNet and run it over one input value
model = MyNet()
x = torch.tensor(torch.rand(1))
output = model(x)
criterion = nn.MSELoss()
loss = criterion(1, output)
loss.backward()
迭代模型参数显示自定义图层参数的
None

for p in model.parameters():
    print (p.grad)

None
直接访问该参数时,显示正确的
grad
值:

print(model.layer1.my_parameter.grad)

tensor([-1.4370])
这反过来又会阻止optim步骤自动更新内部参数,并使我不得不手动更新这些参数。有人知道我如何解决这个问题吗?

好吧! 我必须将自定义层中的参数变量调用切换到
nn.ParameterList
对象(即
返回x*self.my_registered_参数[0]
,而不是x*self.my_参数)。在本例中,这意味着将自定义层的参数调用forward方法更改为:

  def forward(self, x):
    return x*self.my_registered_parameter[0]
这是一个很好的地方,通过参考传递


现在optim会按预期更新所有参数

您所做的,即
返回x*self。我的注册参数[0]
有效,因为您使用注册参数计算梯度

调用
nn.Parameter
时,它返回一个新对象,因此用于操作的
self.my_参数
与注册的参数不相同

您可以通过将
my_参数
声明为
nn.参数

self.my_parameter = nn.Parameter(torch.rand(1, requires_grad = True))
self.my_registered_parameter= nn.ParameterList([self.some_parameter])
或者您根本不需要创建
my\u registered\u参数
变量。当您将self.my_参数声明为
nn.parameter
时,它将注册为参数