Python 在Tensorflow 2.0中,迭代无限重复的tf.data数据集的正确方法是什么

Python 在Tensorflow 2.0中,迭代无限重复的tf.data数据集的正确方法是什么,python,tensorflow,Python,Tensorflow,TF2.0文档建议使用python for循环迭代数据集: for step, (x_batch_train, y_batch_train) in enumerate(train_dataset): # do training 问题是,如果数据集无限期地重复(据我所知,这是出于性能原因),这个循环将永远不会结束 我目前正在做的是设置一些我想重复的历史学和训练步骤: train_iter = iter(train_dataset) for i in range(num_epochs):

TF2.0文档建议使用python for循环迭代数据集:

for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
    # do training
问题是,如果数据集无限期地重复(据我所知,这是出于性能原因),这个循环将永远不会结束

我目前正在做的是设置一些我想重复的历史学和训练步骤:

train_iter = iter(train_dataset)
for i in range(num_epochs):
    # do some setup
    for step in range(num_batches):
        (x_batch, y_batch) = next(train_iter)
        # do training
    # log metrics
我不确定的是,这是否会对我的培训过程产生负面影响。这会让我的训练运行得更慢吗?或者我会通过这样运行训练来阻止Tensorflow优化我的代码吗? 最重要的是,设置一个历元中要处理的批数可能有点烦人,因为我想在数据管道中进行随机扩充。因此,我的数据集中唯一样本的数量在不同的培训课程中可能会有所不同。但这不是什么大问题


我试图通过谷歌找到这个问题的答案,但不幸的是没有成功。

代码的问题

train_iter = iter(train_dataset)
for i in range(num_epochs):
    # do some setup
    for step in range(num_batches):
        (x_batch, y_batch) = next(train_iter)
是指每个
历元
模型
以相同的顺序查看
,这是无效的

该代码的输出示例如下所示:

tf.Tensor(4, shape=(), dtype=int64) tf.Tensor(3, shape=(), dtype=int64)
tf.Tensor(0, shape=(), dtype=int64) tf.Tensor(1, shape=(), dtype=int64)
tf.Tensor(8, shape=(), dtype=int64) tf.Tensor(2, shape=(), dtype=int64)
tf.Tensor(6, shape=(), dtype=int64) tf.Tensor(9, shape=(), dtype=int64)
tf.Tensor(7, shape=(), dtype=int64) tf.Tensor(5, shape=(), dtype=int64)
tf.Tensor(4, shape=(), dtype=int64) tf.Tensor(3, shape=(), dtype=int64)
tf.Tensor(0, shape=(), dtype=int64) tf.Tensor(1, shape=(), dtype=int64)
tf.Tensor(8, shape=(), dtype=int64) tf.Tensor(2, shape=(), dtype=int64)
tf.Tensor(6, shape=(), dtype=int64) tf.Tensor(9, shape=(), dtype=int64)
tf.Tensor(7, shape=(), dtype=int64) tf.Tensor(5, shape=(), dtype=int64)
tf.Tensor(4, shape=(), dtype=int64) tf.Tensor(3, shape=(), dtype=int64)
tf.Tensor(0, shape=(), dtype=int64) tf.Tensor(1, shape=(), dtype=int64)
tf.Tensor(8, shape=(), dtype=int64) tf.Tensor(2, shape=(), dtype=int64)
tf.Tensor(6, shape=(), dtype=int64) tf.Tensor(9, shape=(), dtype=int64)
tf.Tensor(7, shape=(), dtype=int64) tf.Tensor(5, shape=(), dtype=int64)
import tensorflow as tf

dataset = tf.data.Dataset.range(10)
dataset = dataset.shuffle(buffer_size=5, reshuffle_each_iteration=True)
iter(dataset)

buffer_size = 10
batch_size = 2

for epoch in range(num_epochs):
    dataset_epoch = dataset.batch(batch_size)
    for x, y in dataset_epoch:
      print(x,y)
如上所述,每个
历元
对应的值相同,或者换句话说,
在每个
历元
重复(
4,0,8,6,7
3,1,2,9,5
重复三次)

以不同顺序传递批处理的优化有效方法是使用参数,
reshuffle\u each\u iteration=True
。示例代码如下所示:

tf.Tensor(4, shape=(), dtype=int64) tf.Tensor(3, shape=(), dtype=int64)
tf.Tensor(0, shape=(), dtype=int64) tf.Tensor(1, shape=(), dtype=int64)
tf.Tensor(8, shape=(), dtype=int64) tf.Tensor(2, shape=(), dtype=int64)
tf.Tensor(6, shape=(), dtype=int64) tf.Tensor(9, shape=(), dtype=int64)
tf.Tensor(7, shape=(), dtype=int64) tf.Tensor(5, shape=(), dtype=int64)
tf.Tensor(4, shape=(), dtype=int64) tf.Tensor(3, shape=(), dtype=int64)
tf.Tensor(0, shape=(), dtype=int64) tf.Tensor(1, shape=(), dtype=int64)
tf.Tensor(8, shape=(), dtype=int64) tf.Tensor(2, shape=(), dtype=int64)
tf.Tensor(6, shape=(), dtype=int64) tf.Tensor(9, shape=(), dtype=int64)
tf.Tensor(7, shape=(), dtype=int64) tf.Tensor(5, shape=(), dtype=int64)
tf.Tensor(4, shape=(), dtype=int64) tf.Tensor(3, shape=(), dtype=int64)
tf.Tensor(0, shape=(), dtype=int64) tf.Tensor(1, shape=(), dtype=int64)
tf.Tensor(8, shape=(), dtype=int64) tf.Tensor(2, shape=(), dtype=int64)
tf.Tensor(6, shape=(), dtype=int64) tf.Tensor(9, shape=(), dtype=int64)
tf.Tensor(7, shape=(), dtype=int64) tf.Tensor(5, shape=(), dtype=int64)
import tensorflow as tf

dataset = tf.data.Dataset.range(10)
dataset = dataset.shuffle(buffer_size=5, reshuffle_each_iteration=True)
iter(dataset)

buffer_size = 10
batch_size = 2

for epoch in range(num_epochs):
    dataset_epoch = dataset.batch(batch_size)
    for x, y in dataset_epoch:
      print(x,y)
上述代码的输出如下所示,可以观察到任何批次对应的值都没有重复:

tf.Tensor(2, shape=(), dtype=int64) tf.Tensor(0, shape=(), dtype=int64)
tf.Tensor(3, shape=(), dtype=int64) tf.Tensor(1, shape=(), dtype=int64)
tf.Tensor(7, shape=(), dtype=int64) tf.Tensor(6, shape=(), dtype=int64)
tf.Tensor(9, shape=(), dtype=int64) tf.Tensor(4, shape=(), dtype=int64)
tf.Tensor(5, shape=(), dtype=int64) tf.Tensor(8, shape=(), dtype=int64)
tf.Tensor(0, shape=(), dtype=int64) tf.Tensor(5, shape=(), dtype=int64)
tf.Tensor(6, shape=(), dtype=int64) tf.Tensor(7, shape=(), dtype=int64)
tf.Tensor(4, shape=(), dtype=int64) tf.Tensor(9, shape=(), dtype=int64)
tf.Tensor(3, shape=(), dtype=int64) tf.Tensor(2, shape=(), dtype=int64)
tf.Tensor(8, shape=(), dtype=int64) tf.Tensor(1, shape=(), dtype=int64)
tf.Tensor(2, shape=(), dtype=int64) tf.Tensor(5, shape=(), dtype=int64)
tf.Tensor(1, shape=(), dtype=int64) tf.Tensor(7, shape=(), dtype=int64)
tf.Tensor(6, shape=(), dtype=int64) tf.Tensor(8, shape=(), dtype=int64)
tf.Tensor(9, shape=(), dtype=int64) tf.Tensor(3, shape=(), dtype=int64)
tf.Tensor(0, shape=(), dtype=int64) tf.Tensor(4, shape=(), dtype=int64)

希望这有帮助。学习愉快

在TF2.0中,通常不会无限期地重复数据集,而是在每个历元中创建一个新的迭代器。你试过了吗?看到性能下降了吗?好的,很高兴知道,谢谢!我没有特别注意到这方面的性能下降/还没有进一步调查。目前,我只是对我的数据管道性能进行一般性的故障排除,并试图使其更高效。