Python 火炬数据集循环太远

Python 火炬数据集循环太远,python,pytorch,Python,Pytorch,为什么这个数据集试图迭代通过最后一个元素 from torch.utils.data.dataset import Dataset class DumbDataset(Dataset): def __init__(self, dct): self.dct = dct self.mapping = dict(enumerate(dct)) def __getitem__(self, index): return self.dct[se

为什么这个数据集试图迭代通过最后一个元素

from torch.utils.data.dataset import Dataset
class DumbDataset(Dataset):
    def __init__(self, dct):
        self.dct = dct
        self.mapping = dict(enumerate(dct))
    def __getitem__(self, index):
        return self.dct[self.mapping[index]]

    def __len__(self):
        print('called')
        return len(self.dct)

ds = DumbDataset({'a': 'aword', 'b': 'another_words'})

for k in ds: print(k)

这将引发KeyError:2,我不理解这一点,因为对象的长度是2。迭代器是否应该在耗尽后停止迭代?

代码引发
KeyError
的原因是
Dataset
,因此,当在for循环中使用时,Python会退回到从索引
0
开始,并调用
\uuu getitem\uuu
,直到引发
索引器,如前所述。您可以修改
数据集
,使其在索引超出范围时引发一个
索引器
,从而像这样工作

def __getitem__(self, index):
    if index >= len(self): raise IndexError
    return self.dct[self.mapping[index]]
然后是你的循环

for k in ds:
    print(k)
我会像你期望的那样工作。另一方面,torch数据集的典型模板是,您可以使用索引在它们之间循环

for i in range(len(ds)):
    k = ds[k]
    print(k)
或者将它们包装在一个
数据加载器中,该加载器将成批返回元素

generator = DataLoader(ds)
for k in generator:
    print(k)