How do I load a .pth file for a DistributedDataParallel model?


I trained a model with DistributedDataParallel, which produced a .pth file:

    if args.gpu is not None:
        print('Gpu setting...', args.gpu)
        torch.cuda.set_device(args.gpu)
        model.cuda(args.gpu)
        # When using a single GPU per process and per
        # DistributedDataParallel, we need to divide the batch size
        # ourselves based on the total number of GPUs we have
        args.batch_size = int(args.batch_size / ngpus_per_node)
        args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.gpu],
            # output_device=[args.gpu],
            find_unused_parameters=True,
        )
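For context: DistributedDataParallel keeps the original network in its `module` attribute, and the wrapper's `state_dict()` reports every key with a `module.` prefix (for example `module.head.bias` instead of `head.bias`). The sketch below imitates this with hypothetical stand-in classes (`ToyNet`, `ToyWrapper`) so it runs without PyTorch or GPUs, and adds a hypothetical `unwrap_model` helper that returns the inner network, so the keys it reports match those of an unwrapped model.

```python
from collections import OrderedDict


class ToyNet:
    """Hypothetical stand-in for EfficientDet: only exposes state_dict()."""

    def state_dict(self):
        return OrderedDict([("backbone.conv.weight", 0.1), ("head.bias", 0.2)])


class ToyWrapper:
    """Hypothetical stand-in for DistributedDataParallel.

    Like the real wrapper, it stores the network in `.module` and reports
    every parameter name with a 'module.' prefix.
    """

    def __init__(self, module):
        self.module = module

    def state_dict(self):
        return OrderedDict(
            ("module." + k, v) for k, v in self.module.state_dict().items()
        )


def unwrap_model(model):
    """Hypothetical helper: return the inner network of a wrapper, else the model itself.

    Assumes (duck typing) that any object with a `module` attribute is a
    DDP/DataParallel-style wrapper.
    """
    return getattr(model, "module", model)


ddp_model = ToyWrapper(ToyNet())
print(list(ddp_model.state_dict()))                # ['module.backbone.conv.weight', 'module.head.bias']
print(list(unwrap_model(ddp_model).state_dict()))  # ['backbone.conv.weight', 'head.bias']
```

In the real training script, saving would then look roughly like `torch.save({'state_dict': unwrap_model(model).state_dict()}, path)`, typically done from rank 0 only; the `'state_dict'` key is assumed here to match what the evaluation code reads back.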
Then I tried to evaluate the model:

    self.model = EfficientDet(num_classes=num_class,
                              network=network,
                              W_bifpn=EFFICIENTDET[network]['W_bifpn'],
                              D_bifpn=EFFICIENTDET[network]['D_bifpn'],
                              D_class=EFFICIENTDET[network]['D_class'],
                              is_training=False)

    # self.model = torch.nn.parallel.DistributedDataParallel(self.model, device_ids=[args.gpu], find_unused_parameters=True)
    # self.model = torch.nn.parallel.DistributedDataParallel(self.model)

    if self.weights is not None:
        print('load state dic...', self.weights)
        checkpoint = torch.load(
            self.weights, map_location=lambda storage, loc: storage)
        state_dict = checkpoint['state_dict']
        self.model.load_state_dict(state_dict)
    if torch.cuda.is_available():
        self.model = self.model.cuda()
    self.model.eval()
This fails with a "missing keys" error from `load_state_dict`. How can I load a .pth file from a model trained with DistributedDataParallel?
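If the checkpoint was saved from the wrapped model, a common fix is to strip the `module.` prefix before loading. This assumes the error lists keys such as `module.head.bias` as unexpected and the same names without the prefix as missing. A minimal sketch in plain Python (the helper name `strip_prefix` is hypothetical):

```python
from collections import OrderedDict


def strip_prefix(state_dict, prefix="module."):
    """Return a copy of `state_dict` with `prefix` removed from the keys that start with it.

    Keys without the prefix are kept as-is, so calling this on a checkpoint
    saved from an unwrapped model is harmless. The input is not modified.
    """
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


# Toy checkpoint standing in for checkpoint['state_dict'] (values are placeholders).
saved = OrderedDict([("module.backbone.conv.weight", 0.1), ("module.head.bias", 0.2)])
print(list(strip_prefix(saved)))  # ['backbone.conv.weight', 'head.bias']
```

In the evaluation code this would replace `self.model.load_state_dict(state_dict)` with `self.model.load_state_dict(strip_prefix(state_dict))`. Recent PyTorch releases also provide `torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict, "module.")`, which does the same in place. Wrapping the evaluation model in DistributedDataParallel, as in the commented-out lines, would also make the keys match, but it requires an initialized process group.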