Python 3.x PyTorch:用于培训和测试/验证的不同正向方法

Python 3.x PyTorch:用于培训和测试/验证的不同正向方法,python-3.x,neural-network,pytorch,transformer,seq2seq,Python 3.x,Neural Network,Pytorch,Transformer,Seq2seq,我目前正在尝试扩展基于FairSeq/PyTorch的功能。在训练期间,我需要训练两个编码器:一个是目标样本,另一个是源样本 因此,当前正向函数如下所示: def forward(self, src_tokens=None, src_lengths=None, prev_output_tokens=None, **kwargs): encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs) d

我目前正在尝试扩展基于FairSeq/PyTorch的功能。在训练期间,我需要训练两个编码器:一个是目标样本,另一个是源样本

因此,当前正向函数如下所示:

def forward(self, src_tokens=None, src_lengths=None, prev_output_tokens=None, **kwargs):
    encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
    decoder_out = self.decoder(prev_output_tokens, encoder_out=encoder_out, **kwargs)
    return decoder_out
def forward_test(self, src_tokens=None, src_lengths=None, prev_output_tokens=None, **kwargs):
    encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
    decoder_out = self.decoder(prev_output_tokens, encoder_out=encoder_out, **kwargs)
    return decoder_out

def forward_train(self, src_tokens=None, src_lengths=None, prev_output_tokens=None, **kwargs):
    encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
    autoencoder_out = self.encoder(tgt_tokens, src_lengths=src_lengths, **kwargs)
    concat = some_concatination_func(encoder_out, autoencoder_out)
    decoder_out = self.decoder(prev_output_tokens, encoder_out=concat, **kwargs)
    return decoder_out
import torch


class Network(torch.nn.Module):
    def __init__(self):
        super().__init__()
        ...

    # You could split it into two functions but both should be called by forward
    def forward(
        self, src_tokens=None, src_lengths=None, prev_output_tokens=None, **kwargs
    ):
        encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
        if self.train:
            return self.decoder(prev_output_tokens, encoder_out=encoder_out, **kwargs)
        autoencoder_out = self.encoder(tgt_tokens, src_lengths=src_lengths, **kwargs)
        concat = some_concatination_func(encoder_out, autoencoder_out)
        return self.decoder(prev_output_tokens, encoder_out=concat, **kwargs)
基于此,我想要这样的东西:

def forward(self, src_tokens=None, src_lengths=None, prev_output_tokens=None, **kwargs):
    encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
    decoder_out = self.decoder(prev_output_tokens, encoder_out=encoder_out, **kwargs)
    return decoder_out
def forward_test(self, src_tokens=None, src_lengths=None, prev_output_tokens=None, **kwargs):
    encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
    decoder_out = self.decoder(prev_output_tokens, encoder_out=encoder_out, **kwargs)
    return decoder_out

def forward_train(self, src_tokens=None, src_lengths=None, prev_output_tokens=None, **kwargs):
    encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
    autoencoder_out = self.encoder(tgt_tokens, src_lengths=src_lengths, **kwargs)
    concat = some_concatination_func(encoder_out, autoencoder_out)
    decoder_out = self.decoder(prev_output_tokens, encoder_out=concat, **kwargs)
    return decoder_out
import torch


class Network(torch.nn.Module):
    def __init__(self):
        super().__init__()
        ...

    # You could split it into two functions but both should be called by forward
    def forward(
        self, src_tokens=None, src_lengths=None, prev_output_tokens=None, **kwargs
    ):
        encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
        if self.train:
            return self.decoder(prev_output_tokens, encoder_out=encoder_out, **kwargs)
        autoencoder_out = self.encoder(tgt_tokens, src_lengths=src_lengths, **kwargs)
        concat = some_concatination_func(encoder_out, autoencoder_out)
        return self.decoder(prev_output_tokens, encoder_out=concat, **kwargs)
有没有办法做到这一点

编辑: 由于我需要扩展FairSeqEncoderModel,因此我有以下限制:

编辑2:
在Fairseq中传递给forward函数的参数可以通过实现您自己的标准来更改,例如,请参见,其中
sample['net\u input']
被传递给模型的
\u call\u
函数,该函数调用
forward
方法。

默认情况下,调用
模型()
invoke
forward
方法,该方法在您的案例中是向前训练的,因此您只需要在模型类中为测试/评估路径定义新方法,如下所示:

代码:

类FooBar(nn.Module):
“”“用于测试/调试的虚拟网络。”。
"""
定义初始化(自):
super()。\uuuu init\uuuuu()
...
def前进(自身,x):
#这里将有火车进站
...
def蒸发测试(自身,x):
#这将是评估/测试前进
...
示例:

model=FooBar()#初始化模型
#火车时刻
pred=model(x)#调用引擎盖下的forward()方法
#测试/评估时间
测试前=模型评估测试(x)
评论:
我建议您将这两个正向路径拆分为两个单独的方法,因为这样更易于调试并避免反向传播时可能出现的问题。

首先,您应该始终使用并定义
正向
而不是在
torch.nn.Module
实例上调用的其他方法

绝对不要重载
eval()
,如PyTorch()定义的as评估方法所示。
此方法允许将模型内的层置于评估模式(例如,对层的特定更改,如
退出的推断模式
批处理规范

此外,您应该使用
\uuuu call\uuuu
魔术方法调用它。为什么?因为钩子和其他PyTorch特定的东西是以这种方式正确注册的

其次,不要使用建议的一些外部
模式
字符串变量。这就是PyTorch中的
train
变量的作用,它是区分模型处于
eval
模式还是
train
模式的标准

也就是说,你最好这样做:

def forward(self, src_tokens=None, src_lengths=None, prev_output_tokens=None, **kwargs):
    encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
    decoder_out = self.decoder(prev_output_tokens, encoder_out=encoder_out, **kwargs)
    return decoder_out
def forward_test(self, src_tokens=None, src_lengths=None, prev_output_tokens=None, **kwargs):
    encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
    decoder_out = self.decoder(prev_output_tokens, encoder_out=encoder_out, **kwargs)
    return decoder_out

def forward_train(self, src_tokens=None, src_lengths=None, prev_output_tokens=None, **kwargs):
    encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
    autoencoder_out = self.encoder(tgt_tokens, src_lengths=src_lengths, **kwargs)
    concat = some_concatination_func(encoder_out, autoencoder_out)
    decoder_out = self.decoder(prev_output_tokens, encoder_out=concat, **kwargs)
    return decoder_out
import torch


class Network(torch.nn.Module):
    def __init__(self):
        super().__init__()
        ...

    # You could split it into two functions but both should be called by forward
    def forward(
        self, src_tokens=None, src_lengths=None, prev_output_tokens=None, **kwargs
    ):
        encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
        if self.train:
            return self.decoder(prev_output_tokens, encoder_out=encoder_out, **kwargs)
        autoencoder_out = self.encoder(tgt_tokens, src_lengths=src_lengths, **kwargs)
        concat = some_concatination_func(encoder_out, autoencoder_out)
        return self.decoder(prev_output_tokens, encoder_out=concat, **kwargs)

您可以(也可以说应该)将上述方法分为两个单独的方法,但这并不太糟糕,因为这样的函数非常简短且可读。只要坚持PyTorch的处理方式(如果可能的话),而不是一些临时解决方案。不,反向传播不会有问题,为什么会有呢?

我不知道问题是什么。只需在模型类中提出这两个函数。然后,在训练时,在输入上使用
model.forward\u train
。在测试时,在输入上使用
model.forward\u test
。请注意,在这种情况下,您不能执行
model(input)
,因为PyTorch认为这等同于
model.forward(input)
,因此会抛出一个错误。哦,是的,您是对的,我的坏--
eval
已保留,我必须更改它,谢谢!谢谢@Szymon Maszke,我不知道这个变量。我只是合乎逻辑地思考了一下,然后想出了一个方法,这样就只需要编写前进函数。但是非常感谢。@Szymon Maszke非常感谢你的回答!列车变量非常有用。但是,我不确定如何将目标令牌传递给forward方法,因为我使用fairseq train命令进行训练,该命令只传递源令牌。@qwertz我不确定是否遵循。如果您不想使用JIT导出代码,至少可以将任何变量传递给
forward
\uuuu init\uuuu
。什么是
fair seq
列车指令?你能提供一个例子和一些来源的链接吗?它与标准PyTorch有何不同?@Szymon Maszke我正在实现Fairseq框架中的类“FairseqEncoderDecoderModel”。我使用这个命令进行训练:所以我自己并不实际处理数据加载,也不知道fairseq train cmd命令最终在哪里调用forward方法。