如何使用批处理规范而不忘记刚刚在Pytorch中使用的批处理统计信息?

如何使用批处理规范而不忘记刚刚在Pytorch中使用的批处理统计信息?,pytorch,Pytorch,我处在一个不寻常的环境中,我不应该使用跑步统计数据(因为这会被视为作弊,例如元学习)。但是,我经常对一组点(实际上是5点)进行向前传递,然后我只想使用前面的统计数据对1点进行评估,但是batch norm忘记了它刚才使用的批次统计数据。我试图硬编码它应该是的值,但我得到了奇怪的错误(即使我取消了pytorch代码本身的注释,比如检查维度大小) 如何硬编码以前的批次统计信息,以便批次规范在新的单个数据点上工作,然后为新的下一批次重置它们 注意:我不想更改批处理规范层类型 我尝试过的示例代码: de

我处在一个不寻常的环境中,我不应该使用跑步统计数据(因为这会被视为作弊,例如元学习)。但是,我经常对一组点(实际上是5点)进行向前传递,然后我只想使用前面的统计数据对1点进行评估,但是batch norm忘记了它刚才使用的批次统计数据。我试图硬编码它应该是的值,但我得到了奇怪的错误(即使我取消了pytorch代码本身的注释,比如检查维度大小)

如何硬编码以前的批次统计信息,以便批次规范在新的单个数据点上工作,然后为新的下一批次重置它们

注意:我不想更改批处理规范层类型

我尝试过的示例代码:

def set_tracking_running_stats(model):
    for attr in dir(model):
        if 'bn' in attr:
            target_attr = getattr(model, attr)
            target_attr.track_running_stats = True
            target_attr.running_mean = torch.nn.Parameter(torch.zeros(target_attr.num_features, requires_grad=False))
            target_attr.running_var = torch.nn.Parameter(torch.ones(target_attr.num_features, requires_grad=False))
            target_attr.num_batches_tracked = torch.nn.Parameter(torch.tensor(0, dtype=torch.long), requires_grad=False)
            # target_attr.reset_running_stats()
    return
我最喜欢的评论错误是:

    raise ValueError('expected 2D or 3D input (got {}D input)'
ValueError: expected 2D or 3D input (got 1D input)

pytorch论坛:

IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)