Python类型错误:';张量';对字典排序时,对象不可调用
这是我的密码。导入的包未显示。我正在尝试将CIFAR-10测试数据输入alexnet。最后的字典需要分类,这样我才能找到最常见的分类。请帮忙,我什么都试过了 徖Python类型错误:';张量';对字典排序时,对象不可调用,python,pytorch,Python,Pytorch,这是我的密码。导入的包未显示。我正在尝试将CIFAR-10测试数据输入alexnet。最后的字典需要分类,这样我才能找到最常见的分类。请帮忙,我什么都试过了 徖 alexnet = models.alexnet(pretrained=True) transform = transforms.Compose([ #[1] transforms.Resize(256), #[2] transforms.CenterCrop(224),
alexnet = models.alexnet(pretrained=True)
transform = transforms.Compose([ #[1]
transforms.Resize(256), #[2]
transforms.CenterCrop(224), #[3]
transforms.ToTensor(), #[4]
transforms.Normalize( #[5]
mean=[0.485, 0.456, 0.406], #[6]
std=[0.229, 0.224, 0.225] #[7]
)])
# Getting the CIFAR-10 dataset
dataset = CIFAR10(root='data/', download=True, transform=transform)
test_dataset = CIFAR10(root='data/', train=False, transform=transform)
classes = dataset.classes
#print(classes)
torch.manual_seed(43)
val_size = 10000
train_size = len(dataset) - val_size
train_ds, val_ds = random_split(dataset, [train_size, val_size])
#print(len(train_ds), len(val_ds))
batch_size=100
train_loader = DataLoader(train_ds, batch_size, shuffle=True, num_workers=8, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size, num_workers=8, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size, num_workers=8, pin_memory=True)
with open("/home/shaan/Computer Science/CS4442/Ass4/imagenet_classes.txt") as f:
classes = eval(f.read())
holder = []
dic = {}
current = ''
#data_iter = iter(test_loader)
#images,labels = data_iter.next()
#alexnet.eval()
with torch.no_grad():
for data in test_loader:
images, labels = data
out = alexnet(images)
#print(out.shape)
for j in range(0,batch_size):
sorted, indices = torch.sort(out,descending=True)
percentage = F.softmax(out,dim=1)[j]*100
results = [(classes[i.item()],percentage[i].item()) for i in indices[j][:5]]
holder.append(results[0][0])
holder.sort()
for z in holder:
if current != z:
count = 1
dic[z] = count
current = z
else:
count = count + 1
dic[z] = count
current = z
这就是我得到错误的地方:
for w in sorted(dic, key=dic.get, reverse=True):
print(w, dic[w])
这条线就是问题所在
sorted, indices = torch.sort(out,descending=True)
您创建了一个名为sorted
的变量,该变量的名称与出错时调用的sorted
函数的名称完全相同
把这个换成其他类似的东西
sorted_out, indices = torch.sort(out,descending=True)