Python PyTorch:为什么我的dataset类给出了索引超出范围的错误?

Python PyTorch:为什么我的dataset类给出了索引超出范围的错误?,python,pytorch,h5py,Python,Pytorch,H5py,我试图弄清楚为什么我的数据集出现了超出范围的索引错误 考虑以下火炬数据集: # prepare torch data set class MSRH5Processor(torch.utils.data.Dataset): def __init__(self, type, shard=False, **args): # init configurable string self.type = type # init shard for sam

我试图弄清楚为什么我的数据集出现了超出范围的索引错误

考虑以下火炬数据集:

# prepare torch data set
class MSRH5Processor(torch.utils.data.Dataset):
    def __init__(self, type, shard=False, **args):
        # init configurable string
        self.type = type
        # init shard for sampling large ds if specified
        self.shard = shard
        # set seed if given
        self.seed = args
        # set loc
        self.file_path = 'C:\\data\\h5py_embeds\\'
        # set file paths
        self.val_embed_path = self.file_path + 'msr_dev_bert_embeds.h5'

        # if true, initialize the dev data
        if self.type == 'dev':
            # embeds are shaped: [layers, tokens, features]
            self.embeddings = h5py.File(self.val_embed_path, 'r')["embeds"]

    def __len__(self):
        return len(self.embeddings)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        if self.type == 'dev':
            sample = {'embeddings': self.embeddings[idx]}
            return sample

# load dataset
processor = MSRH5Processor(type='dev', shard=False)
# check length
len(processor)  # 22425

# iterate over the samples
count = 0
for step, batch in enumerate(processor):
    count += 1
# error: Index (22425) out of range (0-22424)

with h5py.File('C:\\w266\\h5py_embeds\\msr_dev_bert_embeds.h5', 'r') as f:
    print(f['embeds'].attrs['last_index'])  # 22425
    print(f['embeds'].shape)  # (22425, 128, 768)
    print(len(f['embeds']))  # 22425
如果我手动将数据集长度更改为
100
22424
,我仍然会得到相同的错误。是什么让PyTorch查找索引22425

如果我要创建一个CSV数据集,包含1000个观察值(其中
len=1000
),它将停止在999而不是1000处将索引输入
\uuuuu getitem\uuuuuuu()
方法

编辑:
似乎只有Dataset类和H5py文件存在问题。如果我使用torch数据加载器,它将运行到我的数据集的自然长度。尽管如此,我很想知道Torch是如何为我的H5文件获取这个数字的,这导致它的行为与CSV不同。

要将
数据集
用作iterable,您必须实现
\uu iter\uuuuu
方法或
\uu getitem\uuuuu
具有序列语义。当方法
\uuu getitem\uuu
为某些索引idx引发
索引器
时,迭代停止

数据集的问题在于:

self.embeddings=h5py.File(self.val_embed_path,'r')[“embeds”]
实际上是
h5py.\u hl.dataset.dataset
类型,在索引外请求时引发
ValueError


您必须在类构造函数中加载整个嵌入,以便在索引外访问numpy数组将引发
索引器
,或者在
中的
值错误
上重新抛出
索引器
,我发布了一个答案,但对于未来来说,有一个完整的stacktrace将是非常好的,因为它将节省我们很多时间