Warning: file_get_contents(/data/phpspider/zhask/data//catemap/2/python/298.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
Python 创建为RNN返回数据序列的Pytorch数据集的正确方法?_Python_Deep Learning_Dataset_Pytorch_Rnn - Fatal编程技术网

Python 创建为RNN返回数据序列的Pytorch数据集的正确方法?

Python 创建为RNN返回数据序列的Pytorch数据集的正确方法?,python,deep-learning,dataset,pytorch,rnn,Python,Deep Learning,Dataset,Pytorch,Rnn,我正在尝试对RNN进行时间序列数据的培训,虽然有很多关于如何构建RNN模型的教程,但我在为这项任务构建dataloader对象时遇到了一些问题。所有数据都将是相同的长度,因此也不需要填充。到目前为止,我采用的方法是在dataset类的getitem函数中返回一系列数据,并将长度定义为 len(data) - seq_len + 1 然而,我觉得这有点“黑客”,应该有一个更合适的方法来做到这一点。这种方法似乎令人困惑,我觉得如果与团队合作,就会产生问题。更具体地说,我认为以某种方式重写Pytor

我正在尝试对RNN进行时间序列数据的培训,虽然有很多关于如何构建RNN模型的教程,但我在为这项任务构建dataloader对象时遇到了一些问题。所有数据都将是相同的长度,因此也不需要填充。到目前为止,我采用的方法是在dataset类的getitem函数中返回一系列数据,并将长度定义为

len(data) - seq_len + 1
然而,我觉得这有点“黑客”,应该有一个更合适的方法来做到这一点。这种方法似乎令人困惑,我觉得如果与团队合作,就会产生问题。更具体地说,我认为以某种方式重写Pytorch数据集构造函数中的sampler函数是正确的方法,但我在理解如何实现这一点上遇到了困难。下面是我构建的当前数据集类,有人能告诉我如何修复它吗?先谢谢你

class CustomDataset(Dataset):
  def __init__(self, df, cats, y, seq_l):
    self.n, self.seq_l = len(df), seq_l
    self.cats  = np.array(np.stack([c.values for n,c in df[cats].items()], 1).astype(np.int64))
    self.conts = np.array(np.stack([c.values for n,c in df[[i for i in df.columns if i not in cats]].items()], 1).astype(np.float32))
    self.y = np.array(y)

  def __len__(self): return len(self.y) - self.seq_l + 1

  def __getitem__(self, idx):
    return [
      (torch.from_numpy(self.cats[idx:idx+self.seq_l]),
      torch.from_numpy(self.conts[idx:idx+self.seq_l])),
      self.y[idx+self.seq_l-1]
    ]

如果我理解正确的话,您有时间序列数据,并且您希望通过从中采样来包装具有相同长度的数据批次? 我认为您可以使用Dataset只返回一个数据样本,正如PyTorch开发人员最初打算的那样。您可以使用自己的\u collate\u fn函数将它们堆叠在批处理中,并将其传递给DataLoader类(\u collate\u fn是一个可调用函数,它获取样本列表并返回批处理,通常,例如,填充在那里完成)。因此,您不会有长度的依赖关系(=数据集类中的批大小)。我假设您希望在形成批次时保留样本的顺序(假设您使用时间序列),您可以编写自己的采样器类(或者使用PyTorch中已有的SequentialSampler)。 因此,您将解耦示例表示,在批处理(DataLoader中的_collate _fn)和采样(Sampler类)中形成它们。希望这能有所帮助。

collate\u fn(可调用,可选)–合并一个样本列表以形成一个小批量。
Wow看起来我不知何故错过了这一点,是的,这一切都很有意义,似乎是按照Pytorch开发人员的方式实现的。非常感谢。