
Python: PyTorch pretrained ResNet accuracy differs from the reported ImageNet accuracy

Tags: python, deep-learning, pytorch, resnet, pre-trained-model

CPU: Intel Xeon, GPU: NVIDIA RTX 2080 Ti, Python: 3.6, torch version: 1.4, torchvision version: 0.5.0
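
For reference, the environment above can be confirmed with a short snippet; this is only a sketch that prints the installed versions and checks that the GPU is visible:

import torch
import torchvision

# Print the versions of the setup described above and confirm the GPU is visible
print("torch       :", torch.__version__)        # expected 1.4.x
print("torchvision :", torchvision.__version__)  # expected 0.5.0
print("CUDA found  :", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU         :", torch.cuda.get_device_name(0))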

I downloaded the ILSVRC2012 (ImageNet) dataset via torrent and am evaluating the pretrained model on its validation set.

The model I chose is torchvision.models.resnet50(pretrained=True), and I expected the top-1 accuracy on the validation set to be above 60%.
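
For context, the pretrained weights come from torchvision.models; a minimal sketch of loading them under torchvision 0.5 is below (the roughly 76% top-1 figure for ResNet-50 comes from torchvision's model documentation, not from anything measured here):

import torchvision

# Load the ImageNet-pretrained ResNet-50 (weights are downloaded on first use)
model = torchvision.models.resnet50(pretrained=True)
model.eval()  # inference mode: disables dropout, uses running batch-norm statistics

# The classifier head has 1000 outputs, one per ImageNet class in torchvision's ordering
print(model.fc.out_features)  # 1000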

I found corrected labels for the validation set at the link below,

and cross-checked them against the link below.

I wrote a custom Dataset for the validation set and the evaluation.

But contrary to my expectations, the evaluation accuracy is extremely low, as if the model had never been trained at all: about 0.05%–0.1%, which is roughly the level of random guessing over the 1000 ImageNet classes.

Here is my test code, which I have also uploaded to GitHub.

Running the code below gives an accuracy of about 0.01%.

Please let me know what is wrong with my code.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import PIL.Image as Image
import numpy as np
from torch.utils.data import Dataset, DataLoader
import os
from torchvision import transforms, datasets
import matplotlib.pyplot as plt
import torchvision
from tqdm import tqdm
import time

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

transform = transforms.Compose([transforms.Resize((256, 256)),
                                transforms.CenterCrop((224, 224)),
                                transforms.ToTensor(),
                                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                     std=[0.229, 0.224, 0.225])])

class val_imagenet_dataset(Dataset):
    def __init__(self, root="../imagenet/", transform=None):
        self.transform = transform
        self.root = root
        self.data = []
        self.target = []
        # Ground-truth file: one 1-based class label per line, in validation-image order
        with open(self.root+"/ILSVRC2012_validation_ground_truth.txt") as f:
            label_data = f.readlines()
        # Pair each label with the validation images sorted by file name
        file_list = os.listdir(self.root)
        image_list = [image for image in file_list if image.endswith(".JPEG")]
        image_list.sort()
        for idx, label in enumerate(tqdm(label_data)):
            self.data.append(image_list[idx])
            self.target.append(int(label)-1)  # shift 1-based labels to 0-based indices

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        path, target = self.data[idx], self.target[idx]

        with open(self.root+path, 'rb') as f:
            img = Image.open(f)
            sample = img.convert('RGB')


        if self.transform is not None:
            sample = self.transform(sample)

        return sample, target


def val_model(model, dataloader, device, criterion=None):
    since = time.time()
    acc = 0
    loss = 0
    data_size = len(dataloader.dataset)
    model.eval()      # inference mode: no dropout, running batch-norm statistics
    model.to(device)
    with torch.no_grad():  # gradients are not needed for evaluation
        for inputs, labels in tqdm(dataloader):
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)   # top-1 prediction for each sample
            acc += torch.sum(preds == labels.data)

            if criterion is not None:
                _loss = criterion(outputs, labels)
                loss += _loss.item() * inputs.size(0)
                print("local loss : {:4f}".format(loss))

    acc = acc.double() / data_size
    loss = loss / data_size

    print("Acc : {:4f}".format(acc))
    if criterion is not None:
        print("Loss : {:4f}".format(loss))

val_imagenet = val_imagenet_dataset(transform=transform)
val_dataloader = DataLoader(val_imagenet, batch_size=1, shuffle=False, num_workers=2)

# ImageNet-pretrained ResNet-50 from torchvision, evaluated on the custom validation Dataset
resnet_50_origin = torchvision.models.resnet50(pretrained=True)

val_model(resnet_50_origin, val_dataloader, device=device)
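
As a quick sanity check alongside the full evaluation above, one could score a single validation image and compare the top-5 predicted class indices with the label read from the ground-truth file. The file name below is only an example, and the snippet reuses the transform, device, model, and dataset objects defined above:

# Hypothetical single-image check (example file name, same ../imagenet/ layout as above)
img_path = "../imagenet/ILSVRC2012_val_00000001.JPEG"

with open(img_path, 'rb') as f:
    img = Image.open(f).convert('RGB')

x = transform(img).unsqueeze(0).to(device)  # add a batch dimension

resnet_50_origin.to(device).eval()
with torch.no_grad():
    logits = resnet_50_origin(x)

_, top5 = torch.topk(logits, 5)             # top-5 predicted class indices
print("top-5 predictions :", top5.squeeze(0).tolist())
print("ground-truth label:", val_imagenet.target[0])

If the predictions look reasonable for the image itself but consistently disagree with the stored labels, that would point at the labelling rather than at the model or the preprocessing.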