如何在Pytorch中为手动优化器清除渐变

如何在Pytorch中为手动优化器清除渐变,pytorch,Pytorch,我想做的是计算投影梯度上升w.r.t输入X(在我的例子中是图像),同时禁用参数梯度 这是我的密码: def func(x: torch.Tensor, y: torch.Tensor, network: nn.Module, loss_func: nn.Module, eta: float = .1, steps: int = 10): network.requires_grad_(False)#don't calculate grads

我想做的是计算投影梯度上升w.r.t输入X(在我的例子中是图像),同时禁用参数梯度

这是我的密码:

def func(x: torch.Tensor, y: torch.Tensor,
                       network: nn.Module, loss_func: nn.Module, eta: float = .1, steps: int = 10):

    network.requires_grad_(False)#don't calculate grads for the parameters of the network
    x_copy = x.clone().requires_grad_(True)#copy from the main input 
    for i in range(steps):  
        pred = network(x_copy)#calculating predictions for the network
        loss = loss_func(pred,y)#calculation of the loss
        x_copy.retain_grad()#retain gradients of the input
        
        loss.backward()
        
        x_copy = x_copy + eta * x_copy.grad.sign()#gradient calculation


    return x_copy
        
通常,为了清除梯度,我们在每次
loss.backward()
之前执行
optimizer.zero\u grad()
。 但是如何使用手动梯度优化器呢?现在我得到一个错误:

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time