Python 为什么我的数据集没有';即使我设置了数据集,也不会停止。重复(1)

Python 为什么我的数据集没有';即使我设置了数据集,也不会停止。重复(1),python,tensorflow,iterator,dataset,Python,Tensorflow,Iterator,Dataset,我有一个训练数据集和一个测试数据集 #training dataset dataset_train = tf.data.TFRecordDataset(files_train) dataset_train = dataset_train.map(...) dataset_train = dataset_train.shuffle(...) dataset_train = dataset_train.batch(...) dataset_train = dataset_train.repeat(1

我有一个训练数据集和一个测试数据集

#training dataset
dataset_train = tf.data.TFRecordDataset(files_train)
dataset_train = dataset_train.map(...)
dataset_train = dataset_train.shuffle(...)
dataset_train = dataset_train.batch(...)
dataset_train = dataset_train.repeat(1)
iterator_train = dataset_train.make_initializable_iterator()

#test dataset
dataset_test = tf.data.TFRecordDataset(files_test)
dataset_test = dataset_test.map(...)
dataset_test = dataset_test.shuffle(...)
dataset_test = dataset_test.batch(...)
dataset_test = dataset_test.repeat(...)
iterator_test = dataset_test.make_initializable_iterator()

#for switch between two datasets.
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(handle, dataset_train.output_types, dataset_train.output_shapes)
image_batch, label_batch = iterator.get_next()
在会议期间,我有:

# in tf.Session()
train_iterator_handle = sess.run(train_iterator.string_handle())
val_iterator_handle = sess.run(test_iterator.string_handle())
sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])

#start training, switch to training dataset
sess.run(iterator_train.initializer) 
while True:
    try:
        sess.run([train_step, ...])

        if global_step % N == 0: # test
            #start test, switch to test dataset
            sess.run(iterator_test.initializer)
            while True:
                try:
                    sess.run([acc_update, ...])
                except tf.errors.OutOfRangeError:
                    print("test finished")
                    break
            #test finished, switch back to training dataset
            sess.run(iterator_train.initializer) 
    except tf.errors.OutOfRangeError:
        print("training finished")
        break
我从TF的API中读到,训练数据集迭代器可以从上次离开的地方继续,,我认为训练数据集应该在迭代所有数据时停止,因为我使用:

dataset_train = dataset_train.repeat(1)
但实际上,我的程序运行并且没有停止。
所以我想我一定是在什么地方犯了一个严重的错误。有人能帮我吗

验证后的这一行
sess.run(迭代器\u train.initializer)
将重置您的列车生成器状态,因此它将从一开始就继续读取。我想,
N
比train迭代器中的步骤少,所以它不会停止

若您只是想在验证后继续训练,请不要再次调用训练迭代器初始值设定项