PyTorch图中的部分向后

PyTorch图中的部分向后,pytorch,Pytorch,我有一个中等大小的张量x。在这个中等大小的张量上,应用一个计算昂贵的函数(向前和向后)q来获得另一个中等大小的张量y 使用y I计算许多函数来产生一个标量,它们在计算上并不特别昂贵,但是使用了大的内部状态,从而产生了一个大的计算图 现在我想用下面的方法计算x上的梯度 y = q(x) for f in functions res += f(y) res.backward() 这个实现的问题是保留了所有函数f的图。这会导致总内存使用量激增 另一种可能是计算 y = q(x) for

我有一个中等大小的张量x。在这个中等大小的张量上,应用一个计算昂贵的函数(向前和向后)q来获得另一个中等大小的张量y

使用y I计算许多函数来产生一个标量,它们在计算上并不特别昂贵,但是使用了大的内部状态,从而产生了一个大的计算图

现在我想用下面的方法计算x上的梯度

y = q(x)

for f in functions
    res += f(y)

res.backward()
这个实现的问题是保留了所有函数f的图。这会导致总内存使用量激增

另一种可能是计算

y = q(x)

for f in functions
    partial = f(y)
    partial.backward(retain_graph = True)
其优点是,每次函数求值f后,结果都超出范围,图形被释放,从而节省了大量内存。然而,在这种情况下,函数q(x)被向后计算多次,这是非常耗时的


在理想情况下,我希望首先使用类似于第二个示例的代码计算y的梯度,然后只向后计算一次q以获得x的梯度。使用Pytork的正确方法是什么?

我认为这将是实现这一目标的方法:

y = q(x)
z = y.detach()
z.requires_grad_(True)

for f in functions:
    partial = f(y)
    partial.backward(retain_graph = True)
y.backward(z.grad)
z
中累积所有梯度,这
y
但是在另一个计算图形中,然后在第一个图形中传播这些梯度(
z.grad