Python 使用PyTorch加载模型、优化器和调度器后,培训损失增加

Python 使用PyTorch加载模型、优化器和调度器后,培训损失增加,python,pytorch,Python,Pytorch,我在一个研究计算机视觉项目中工作,我有一个特殊的问题,在崩溃或中断后不允许我恢复训练,这是我加载检查点的代码: from torch.optim import lr_scheduler N_EPOCHS = 120 if load_weights: optimizer = torch.optim.Adam(model.parameters(),lr=checkpoint['last_lr'], weight_decay=weight_decay) optimizer.load_s

我在一个研究计算机视觉项目中工作,我有一个特殊的问题,在崩溃或中断后不允许我恢复训练,这是我加载检查点的代码:

from torch.optim import lr_scheduler
N_EPOCHS = 120
if load_weights:
    optimizer = torch.optim.Adam(model.parameters(),lr=checkpoint['last_lr'], weight_decay=weight_decay)
    optimizer.load_state_dict(checkpoint['optimizer'])
    scheduler = lr_scheduler.StepLR(optimizer, step_size=step_of_scheduler, gamma=0.9, last_epoch=loaded_epochs)
    scheduler.load_state_dict(checkpoint['scheduler'])
else:
    optimizer = torch.optim.Adam(model.parameters(),lr=initial_lr, weight_decay=weight_decay)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=step_of_scheduler, gamma=0.9)
我甚至尝试添加
optimizer.state_dict()['param\u groups'][0]['params']=checkpoint['optimizer']['param\u groups'][0]['params']
,但结果更糟,我在验证后保存检查点的代码是:

# Checkpoint
       checkpoint = {'model':model.state_dict(),
                     'epoch':loaded_epochs + epoch + 1,
                     'last_validation_acc': val_acc_db_avg[-1],
                     'hyperparameters': hyperparameters,
                     'last_lr': optimizer.param_groups[0]['lr'],
                     'best_general_val': max(val_acc_db_avg),
                     'last_train_loss': train_loss_db[-1],
                     'optimizer': optimizer.state_dict(),
                     'scheduler': scheduler.state_dict(),
                     'img_resize': 512 if lr_scaled else 256,
                     'best_train_acc': max(train_acc_db),
                     'lr_scaled': lr_scaled
                     }

       # Backup in drive
       torch.save(checkpoint, filename_of_checkpoint)
我确实在训练开始前调用了model.train(),还调用了scheduler.step(),也就是说,我的训练损失从0.81增加到了0.89或0.90。如果我手动分配参数,我不知道这里会出什么问题,因为我也正确加载了模型状态dict

编辑:这是我加载模型状态dict的方式:

if load_weights:
    model.load_state_dict(checkpoint_load['model'])

# Transfering model to GPU if available
model = model.to(DEVICE)

将权重加载到模型中如何?您在给定的代码中没有提到这一点。我添加了如何加载权重的代码,这发生在定义了模型的架构之后。如何将权重加载到模型中?您在给定的代码中没有提到这一点。我添加了如何加载权重的代码,这发生在定义模型架构之后