Python 如何将PyTorch子模块保持在评估模式?

Python 如何将PyTorch子模块保持在评估模式?,python,pytorch,Python,Pytorch,我有一个训练前的模型,我正在和一个正在训练的模型一起使用。我希望预训练模型始终处于eval模式,但另一个模型将在eval和train模式之间来回移动。不过,我仍然希望预训练模型是另一个模型的子模块(例如,使所有参数保持在同一个设备上)。有办法做到这一点吗?下面是一个简单的例子: 来自火炬导入nn的 类固定模块(nn.模块): 通过 类可训练模块(nn.模块): def u u初始(自修复模块): super()。\uuuu init\uuuuu() self.fixed\u模块=fixed\u模

我有一个训练前的模型,我正在和一个正在训练的模型一起使用。我希望预训练模型始终处于eval模式,但另一个模型将在eval和train模式之间来回移动。不过,我仍然希望预训练模型是另一个模型的子模块(例如,使所有参数保持在同一个设备上)。有办法做到这一点吗?下面是一个简单的例子:

来自火炬导入nn的

类固定模块(nn.模块):
通过
类可训练模块(nn.模块):
def u u初始(自修复模块):
super()。\uuuu init\uuuuu()
self.fixed\u模块=fixed\u模块
fixed=FixedModule().eval()
断言不是固定的
可培训=可培训模块(固定)
断言可培训。培训和不可培训。固定\u模块。培训
可训练的
assert trainable.fixed_module.training#我想让它给出一个错误
我知道我可以解决这个问题,例如,一直做

trainable.train()
可培训的。固定的_模块。评估()

但这很容易出错,并且不能很好地与现有代码配合使用。

一种解决方案是覆盖
train
,如下所示:

from torch import nn

class FixedModule(nn.Module):
    pass

class TrainableModule(nn.Module):
    def __init__(self, fixed_module):
        super().__init__()
        self.fixed_module = fixed_module

    def train(self):
        super().train()
        self.fixed_module.eval()

fixed = FixedModule().eval()
assert not fixed.training

trainable = TrainableModule(fixed)
assert trainable.training and not trainable.fixed_module.training

trainable.train()
assert trainable.fixed_module.training  # This gives an error now

您可以在
FixedModule
中覆盖
train
,以防止其改变模式。请注意,
eval
只调用
train(False)
,因此您不需要覆盖它。但是调用
FixedModule.eval
现在什么也做不了,所以必须在init中设置
training=False

来自火炬导入nn的

类固定模块(nn.模块):
定义初始化(自):
super()。\uuuu init\uuuuu()
自我训练=错误
#在调用self.children之前,在此处添加任何其他nn.Module属性
#如果你真的愿意,你也可以在每个孩子身上覆盖“train”,
#但除非有外部参照,否则这似乎有些过分
#到FixedModule的任何子模块
对于self.children()中的模块:
模块eval()
def序列(自模式):
回归自我
类可训练模块(nn.模块):
def u u初始(自修复模块):
super()。\uuuu init\uuuuu()
self.fixed\u模块=fixed\u模块
fixed=FixedModule().eval()
断言不是固定的
可培训=可培训模块(固定)
断言可培训。培训和不可培训。固定\u模块。培训
可训练的
断言不可培训。已修复模块。培训#通过

出于某种原因,我想避免为
固定模块
切换train/eval,但我检查了源代码,并且
train
/
eval
只设置了一个布尔标志(递归地在子模块上),因此没有真正的理由避免。肯定是被否决了,但我可能会接受我的答案,因为
FixedModule
自己处理这个逻辑会更好,这样其他类就不用担心了。