Python 如何在每次转发后释放临时消耗的GPU内存?

Python 如何在每次转发后释放临时消耗的GPU内存?,python,memory-leaks,out-of-memory,gpu,pytorch,Python,Memory Leaks,Out Of Memory,Gpu,Pytorch,我有一门课是这样的: class Stem(nn.Module): def __init__(self): super(Stem, self).__init__() self.out_1 = BasicConv2D(3, 32, kernelSize = 3, stride = 2) self.out_2 = BasicConv2D(32, 32, kernelSize = 3, stride = 1) self.out_

我有一门课是这样的:

class Stem(nn.Module):

    def __init__(self):
        super(Stem, self).__init__()
        self.out_1 = BasicConv2D(3, 32, kernelSize = 3, stride = 2)
        self.out_2 = BasicConv2D(32, 32, kernelSize = 3, stride = 1)
        self.out_3 = BasicConv2D(32, 64, kernelSize = 3, stride = 1, padding = 1)


    def forward(self, x):
        x = self.out_1(x)
        x = self.out_2(x)
        x = self.out_3(x)

        return x
Stem
的属性
out\u 1,2,3
是以下类别的实例:

class BasicConv2D(nn.Module):

    def __init__(self, inChannels, outChannels, kernelSize, stride, padding = 0):
        super(BasicConv2D, self).__init__()
        self.conv = nn.Conv2d(inChannels, outChannels,
                            kernel_size = kernelSize,
                            stride = stride,
                            padding = padding, bias = False)
        self.bn = nn.BatchNorm2d(outChannels,
                                    eps = 0.001,
                                    momentum = 0.1,
                                    affine = True)
        self.relu = nn.ReLU(inplace = False)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        y = self.relu(x)
        return y
训练时,在
Stem.forward()
中,
nvidia smi
告诉您每一行将消耗
x
MBs GPU内存,但在
Stem.forward()
完成后,内存不会被释放,导致训练快速崩溃,GPU内存不足

因此,问题是:
如何释放临时占用的GPU内存?

您的型号看起来不错,因此您可能希望大致了解pytorch如何管理内存分配。我怀疑您只是让指向返回值(y)的指针保持活动状态(例如,通过累积损失或类似的方式)。由于pytorch存储了整个附加的计算图,因此永远不会释放内存


有关更详细的讨论,请参阅,尤其是。

我认为您的pytorch模型没有任何问题。您是否检查了批处理生成过程中是否存在任何泄漏,或者是否检查了存储结果的方式等?我同意@AlStream的说法,问题可能存在于代码的其他部分。例如,您可以将代码发布到培训循环中。或者检查您的生成步骤,如@AlStream Suggered@AlStream,我认为这是因为pytorch默认情况下会在后退之前保留计算图,所以问题是如何在后退之前清除计算图您不能在后退之前清除计算图,因为您需要它来进行后退。您可以尝试使用检查点来减少内存消耗