Python PyTorch: RuntimeError: Error(s) in loading state_dict for DETRModel

I am trying to load the ResNet-50 pretrained weights, but I get the following error:

RuntimeError: Error(s) in loading state_dict for DETRModel:
    Missing key(s) in state_dict: "model.transformer.encoder.layers.0.self_attn.in_proj_weight", "model.transformer.encoder.layers.0.self_attn.in_proj_bias", "model.transformer.encoder.layers.0.self_attn.out_proj.weight", "model.transformer.encoder.layers.0.self_attn.out_proj.bias", "model.transformer.encoder.layers.0.linear1.weight", "model.transformer.encoder.layers.0.linear1.bias" ........
I am using a public dataset that contains all of DETR's weights and models. Here is the code I am using:

import sys
import torch
import torch.nn as nn
from pathlib import Path
sys.path.append('../input/detrmodels/facebookresearch_detr_master')

# copy pretrained weights to the folder PyTorch will search by default
Path('/root/.cache/torch/hub/').mkdir(exist_ok=True, parents=True)
Path('/root/.cache/torch/hub/checkpoints/').mkdir(exist_ok=True, parents=True)

detr_path = '/root/.cache/torch/hub/checkpoints/detr-r50-e632da11.pth'
resnet50_pretrained = '/root/.cache/torch/hub/checkpoints/resnet50-19c8e357.pth'
detr_hub = '/root/.cache/torch/hub/facebookresearch_detr_master'

!cp ../input/detrmodels/detr-r50-e632da11.pth {detr_path}
!cp ../input/detrmodels/resnet50-19c8e357.pth {resnet50_pretrained}
!cp -R ../input/detrmodels/facebookresearch_detr_master {detr_hub}

DIR_INPUT = '../input/shopee-product-matching'
WEIGHTS_FILE = '../input/detrmodels/resnet50-19c8e357.pth'

class DETRModel(nn.Module):
    def __init__(self, num_classes, num_queries, model_name='detr_resnet50'):
        super(DETRModel, self).__init__()
        self.num_classes = num_classes
        self.num_queries = num_queries

        self.model = torch.hub.load('facebookresearch/detr', model_name, pretrained=True)
     
        self.in_features = self.model.class_embed.in_features

        self.model.class_embed = nn.Linear(in_features=self.in_features,
                                           out_features=self.num_classes)
        self.model.num_queries = self.num_queries

    def forward(self, images):
        return self.model(images)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_classes = 11011
num_queries = 100
model = DETRModel(num_classes=num_classes, num_queries=num_queries, model_name='detr_resnet50')
model = model.to(device)
model.load_state_dict(torch.load(WEIGHTS_FILE))
model.eval()
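The keys listed in the error all begin with `model.`: `nn.Module.state_dict()` names a submodule's parameters after the attribute that holds it, and `DETRModel` keeps the hub network in `self.model`. To see exactly which names disagree, one option is to diff the two key sets directly. The sketch below uses a hypothetical helper, `diff_keys` (not part of PyTorch), that works on any iterables of key strings; the key lists in it are illustrative only.

```python
from typing import Iterable, List, Tuple


def diff_keys(model_keys: Iterable[str],
              ckpt_keys: Iterable[str]) -> Tuple[List[str], List[str]]:
    """Hypothetical helper: compare a model's state_dict keys with a checkpoint's.

    Returns (missing, unexpected), mirroring the two lists that a strict
    load_state_dict() call reports.
    """
    model_set = set(model_keys)
    ckpt_set = set(ckpt_keys)
    missing = sorted(model_set - ckpt_set)     # the model expects these, the file lacks them
    unexpected = sorted(ckpt_set - model_set)  # the file has these, the model does not
    return missing, unexpected


# Illustrative key names only: two taken from the error message above, and two
# that a torchvision ResNet-50 checkpoint contains.
model_keys = [
    "model.transformer.encoder.layers.0.linear1.weight",
    "model.transformer.encoder.layers.0.linear1.bias",
]
ckpt_keys = ["conv1.weight", "fc.weight"]

missing, unexpected = diff_keys(model_keys, ckpt_keys)
print("missing:", missing)
print("unexpected:", unexpected)
```

With PyTorch available, the real call would be along the lines of `diff_keys(model.state_dict().keys(), torch.load(WEIGHTS_FILE).keys())`. Alternatively, `model.load_state_dict(state_dict, strict=False)` returns an object whose `missing_keys` and `unexpected_keys` fields hold the same two lists, without raising on key mismatches.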

How can I fix this?

It looks like the weights should be loaded into the wrapped DETR module instead:
model.model.load_state_dict(torch.load(WEIGHTS_FILE))
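Calling `load_state_dict` on `model.model` sidesteps the naming issue because the inner module's keys are the wrapper's keys without the leading `model.`. An equivalent route, sketched here under the assumption that the checkpoint is a flat mapping of parameter names to tensors, is to prepend the prefix to the checkpoint's keys and keep loading into the wrapper. `add_prefix` is a hypothetical helper, not a PyTorch API.

```python
from typing import Dict, TypeVar

V = TypeVar("V")


def add_prefix(state_dict: Dict[str, V], prefix: str) -> Dict[str, V]:
    """Hypothetical helper: return a copy of state_dict with `prefix` prepended
    to every key that does not already start with it."""
    return {
        (key if key.startswith(prefix) else prefix + key): value
        for key, value in state_dict.items()
    }


# Plain strings stand in for tensors so the example runs without PyTorch.
ckpt = {"transformer.encoder.layers.0.linear1.weight": "w", "class_embed.bias": "b"}
print(add_prefix(ckpt, "model."))
```

In the question's setting this would be `model.load_state_dict(add_prefix(torch.load(WEIGHTS_FILE), "model."))`. Either way, renaming only addresses the `model.` prefix: if the file is a torchvision ResNet-50 classification checkpoint, its keys (for example `conv1.weight` and `fc.weight`) still do not correspond to DETR's parameter names, so a strict load may still report mismatches.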