TensorFlow dataset shuffle behavior when the iterator is reset
I find the reshuffle_each_iteration argument to tf.data.Dataset.shuffle a bit confusing. Consider this code:
import tensorflow as tf

flist = ["trimg1", "trimg2", "trimg3", "trimg4"]
filenames = tf.constant(flist)
train_x_dataset = tf.data.Dataset.from_tensor_slices((filenames))
train_x_dataset = train_x_dataset.shuffle(buffer_size=10, reshuffle_each_iteration=False)
it_train_x = train_x_dataset.make_initializable_iterator()
next_sample = it_train_x.get_next()

with tf.Session() as sess:
    for epoch in range(3):
        sess.run(it_train_x.initializer)
        print("Starting epoch ", epoch)
        while True:
            try:
                s = sess.run(next_sample)
                print("Sample: ", s)
            except tf.errors.OutOfRangeError:
                break
The code prints:
Starting epoch 0
Sample: b'trimg2'
Sample: b'trimg1'
Sample: b'trimg3'
Sample: b'trimg4'
Starting epoch 1
Sample: b'trimg4'
Sample: b'trimg3'
Sample: b'trimg2'
Sample: b'trimg1'
Starting epoch 2
Sample: b'trimg3'
Sample: b'trimg2'
Sample: b'trimg4'
Sample: b'trimg1'
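Even though reshuffle_each_iteration is False, TensorFlow still reshuffles the dataset after each pass over it. Is there another way to reset the iterator? What is the expected behavior of reshuffle_each_iteration?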
I know I can fix the seed and get the same order every time. The question is how reshuffle_each_iteration is supposed to work.
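I also know that the more idiomatic approach would be to use repeat(), but in my case the actual number of samples differs from epoch to epoch.

I suspect TensorFlow still reshuffles the dataset on every iteration of the for loop because the iterator is initialized on every iteration. Each time the iterator is initialized, the shuffle function is applied to the dataset again.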
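The expected behavior is that the iterator is initialized once, and reshuffle_each_iteration lets you choose whether the data is reshuffled when it is repeated (that is, on each pass over the original data).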
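The distinction can be illustrated with a small pure-Python model. This is only a sketch of the semantics described in this answer, not TensorFlow's implementation: ToyShuffleDataset and new_iterator are hypothetical names, and the model assumes that, when no seed is given, every newly initialized iterator draws a fresh random seed.

```python
import random


class ToyShuffleDataset:
    """Toy stand-in for dataset.shuffle(...).repeat(count); not TensorFlow code."""

    def __init__(self, items, reshuffle_each_iteration, repeat=1, seed=None):
        self.items = list(items)
        self.reshuffle_each_iteration = reshuffle_each_iteration
        self.repeat = repeat
        self.seed = seed

    def new_iterator(self):
        """Models (re-)initializing an iterator over the dataset."""
        # Assumption: without a user-supplied seed, each initialization
        # picks a fresh base seed, so the order can change between epochs.
        base_seed = self.seed if self.seed is not None else random.randrange(2**32)
        for pass_index in range(self.repeat):
            if self.reshuffle_each_iteration:
                # A different seed per pass: repeated passes may differ.
                pass_seed = base_seed + pass_index
            else:
                # The same seed for every pass made by this iterator.
                pass_seed = base_seed
            order = list(self.items)
            random.Random(pass_seed).shuffle(order)
            yield from order


files = ["trimg1", "trimg2", "trimg3", "trimg4"]
ds = ToyShuffleDataset(files, reshuffle_each_iteration=False, repeat=4)
samples = list(ds.new_iterator())
passes = [samples[i:i + 4] for i in range(0, len(samples), 4)]
print(all(p == passes[0] for p in passes))  # True: one iterator, one order
```

In this model, calling new_iterator() again plays the role of running the initializer for a new epoch: unless seed is set, the order may change even with reshuffle_each_iteration=False, which matches the output in the question.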
I don't know how to rewrite your code to handle a variable number of samples, but here is your code modified to use repeat(), to demonstrate my claim:
import tensorflow as tf

flist = ["trimg1", "trimg2", "trimg3", "trimg4"]
filenames = tf.constant(flist)
train_x_dataset = tf.data.Dataset.from_tensor_slices((filenames))
train_x_dataset = train_x_dataset.shuffle(buffer_size=10, reshuffle_each_iteration=False)
# Repeat the shuffled data four times instead of re-initializing the iterator.
train_x_dataset = train_x_dataset.repeat(4)
it_train_x = train_x_dataset.make_initializable_iterator()
next_sample = it_train_x.get_next()

with tf.Session() as sess:
    # The iterator is initialized only once, before the loop.
    sess.run(it_train_x.initializer)
    while True:
        try:
            s = sess.run(next_sample)
            print("Sample: ", s)
        except tf.errors.OutOfRangeError:
            break
Output:
Sample: trimg1
Sample: trimg2
Sample: trimg4
Sample: trimg3
Sample: trimg1
Sample: trimg2
Sample: trimg4
Sample: trimg3
Sample: trimg1
Sample: trimg2
Sample: trimg4
Sample: trimg3
Sample: trimg1
Sample: trimg2
Sample: trimg4
Sample: trimg3
However, if I set reshuffle_each_iteration=True, I get:
Sample: trimg3
Sample: trimg1
Sample: trimg2
Sample: trimg4
Sample: trimg3
Sample: trimg2
Sample: trimg1
Sample: trimg4
Sample: trimg3
Sample: trimg1
Sample: trimg2
Sample: trimg4
Sample: trimg4
Sample: trimg1
Sample: trimg2
Sample: trimg3
Hope this helps.
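Edit: further evidence for my claim can be found in the tests in the TensorFlow codebase. In that example a one-shot iterator is used, so it is initialized only once. A batch size of 10 is applied to data of size 10, so each call to iterator.get_next() goes through the entire source data. The code checks that every subsequent call to that function returns the same (shuffled) array.

A related discussion may shed further light on the intended use and expected behavior of the different iterators, and may help you find a solution to your specific problem.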