仅计算Pytorch中前端网络的梯度

仅计算Pytorch中前端网络的梯度,pytorch,Pytorch,我有一个非常简单的问题 假设我有两个网络要训练(即net1、net2)。 net1的输出将在训练时输入net2。 在我的情况下,我只想更新net1: optimizer=Optimizer(net1.parameters(), **kwargs) loss=net2(net1(x)) loss.backward() optimizer.step() 虽然这将实现我的目标,但它会占用太多的冗余内存,因为这将计算net2的梯度(导致OOM错误)。 因此,我尝试了几种方法来解决这个问题: 火炬编号和

我有一个非常简单的问题

假设我有两个网络要训练(即net1、net2)。 net1的输出将在训练时输入net2。 在我的情况下,我只想更新net1:

optimizer=Optimizer(net1.parameters(), **kwargs)
loss=net2(net1(x))
loss.backward()
optimizer.step()
虽然这将实现我的目标,但它会占用太多的冗余内存,因为这将计算net2的梯度(导致OOM错误)。 因此,我尝试了几种方法来解决这个问题:

  • 火炬编号和等级:
  • 没有提高OOM,但删除了所有渐变,包括net1中的渐变

  • 需要_grad=False:
  • 升起了隆隆声

  • 分离():
  • 没有提高OOM,但删除了所有渐变,包括net1中的渐变

  • eval():
  • 升起了隆隆声

    有没有办法只计算前端网络(net1)的梯度以提高内存效率?
    如果您有任何建议,我们将不胜感激。

    首先,让我们试着了解您的方法不起作用的原因

  • 此上下文管理器禁用所有渐变计算
  • 由于
    net1
    需要渐变,因此忽略后续的
    需要\u grad=False
  • 如果在该状态下分离,这意味着渐变计算已经停止
  • Eval只是将net2设置为Eval模式,这根本不影响渐变计算
  • 根据您的体系结构,OOM错误可能已经来自于保存计算图中的所有中间值(通常是CNN中的一个问题),也可能来自于必须存储梯度(在完全连接的网络中更常见)

    您可能正在寻找的是所谓的“检查点”,您甚至不必自己实现它,您可以使用pytorch的检查点API,查看


    这基本上允许您分别计算和处理
    net1
    net2
    的梯度。请注意,您确实需要所有渐变信息才能通过
    net2
    ,否则无法计算渐变wrt<代码>网络1

    你完全正确。我需要所有的梯度信息来训练net1。torch.utils.checkpoint对我来说非常有效!谢谢
    z=net1(x)
    with torch.no_grad():
        loss=net2(z)
    
    net2.requires_grad=False
    loss=net2(net1(x))
    
    z=net1(x)
    loss=net2(z).detach()
    
    net2.eval()
    loss=net2(net1(x))