迭代器重置时的tensorflow数据集洗牌行为

迭代器重置时的tensorflow数据集洗牌行为,tensorflow,tensorflow-datasets,Tensorflow,Tensorflow Datasets,我发现reshuffle\u每次迭代参数到tf.Dataset.shuffle有点混乱。考虑这个代码: import tensorflow as tf flist = ["trimg1", "trimg2", "trimg3", "trimg4"] filenames = tf.constant(flist) train_x_dataset = tf.data.Dataset.from_tensor_slices((filenames)) train_x_dataset = train_x

我发现
reshuffle\u每次迭代
参数到
tf.Dataset.shuffle
有点混乱。考虑这个代码:

import tensorflow as tf

flist = ["trimg1", "trimg2", "trimg3", "trimg4"]

filenames = tf.constant(flist)

train_x_dataset = tf.data.Dataset.from_tensor_slices((filenames))
train_x_dataset = train_x_dataset.shuffle(buffer_size=10, reshuffle_each_iteration=False)

it_train_x = train_x_dataset.make_initializable_iterator()

next_sample = it_train_x.get_next()

with tf.Session() as sess:
    for epoch in range(3):
        sess.run(it_train_x.initializer)
        print("Starting epoch ", epoch)
        while True:
            try:
                s = sess.run(next_sample)
                print("Sample: ", s)
            except tf.errors.OutOfRangeError:
                break
代码输出:

Starting epoch  0
Sample:  b'trimg2'
Sample:  b'trimg1'
Sample:  b'trimg3'
Sample:  b'trimg4'
Starting epoch  1
Sample:  b'trimg4'
Sample:  b'trimg3'
Sample:  b'trimg2'
Sample:  b'trimg1'
Starting epoch  2
Sample:  b'trimg3'
Sample:  b'trimg2'
Sample:  b'trimg4'
Sample:  b'trimg1'
即使每次迭代的
reshuffle\u
都是
False
,tensorflow仍然会在数据集迭代后重新洗牌。是否有其他方法重置迭代器?每次迭代时,
重新洗牌的预期行为是什么

我知道我可以修复
种子
,每次都得到相同的顺序,问题是
如何重新调整每个迭代
的工作


我还知道,更惯用的方法是使用
repeat()
,但在我的例子中,每个历元的实际样本数都会不同。

我怀疑TensorFlow仍然会在for循环的每次迭代中重新排列数据集,因为迭代器在每次迭代中都会初始化。每次初始化迭代器时,都会对数据集应用shuffle函数

预期的行为是迭代器初始化一次,
reshuffle\u每次迭代
允许您选择是否在重复数据时重新洗牌(每次对原始数据进行迭代)

我不知道如何重新编写代码以处理可变数量的样本,但以下是使用
repeat()
函数修改的代码,以证明我的声明:

flist = ["trimg1", "trimg2", "trimg3", "trimg4"]

filenames = tf.constant(flist)

train_x_dataset = tf.data.Dataset.from_tensor_slices((filenames))
train_x_dataset = train_x_dataset.shuffle(buffer_size=10, reshuffle_each_iteration=False)
train_x_dataset = train_x_dataset.repeat(4)

it_train_x = train_x_dataset.make_initializable_iterator()

next_sample = it_train_x.get_next()

with tf.Session() as sess:
    sess.run(it_train_x.initializer)
    while True:
        try:
            s = sess.run(next_sample)
            print("Sample: ", s)
        except tf.errors.OutOfRangeError:
            break
输出:

Sample:  trimg1
Sample:  trimg2
Sample:  trimg4
Sample:  trimg3
Sample:  trimg1
Sample:  trimg2
Sample:  trimg4
Sample:  trimg3
Sample:  trimg1
Sample:  trimg2
Sample:  trimg4
Sample:  trimg3
Sample:  trimg1
Sample:  trimg2
Sample:  trimg4
Sample:  trimg3
然而,如果我设置
reshuffle\u each\u iteration=True
,我会得到:

Sample:  trimg3
Sample:  trimg1
Sample:  trimg2
Sample:  trimg4
Sample:  trimg3
Sample:  trimg2
Sample:  trimg1
Sample:  trimg4
Sample:  trimg3
Sample:  trimg1
Sample:  trimg2
Sample:  trimg4
Sample:  trimg4
Sample:  trimg1
Sample:  trimg2
Sample:  trimg3
希望这有帮助

编辑:我的主张的进一步证据:这些在TensorFlow代码库中。在本例中,使用一次迭代,因此只初始化一次。批量大小为10的数据用于大小为10的数据,因此每次调用
迭代器.get_next()
都会遍历整个源数据。代码检查该函数的每个后续调用是否返回相同的(无序)数组

关于的讨论可能会进一步阐明不同迭代器的预期用途和预期行为,并可能帮助您找到特定问题的解决方案