Warning: file_get_contents(/data/phpspider/zhask/data//catemap/2/tensorflow/5.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
Python Tensor Flow shuffle_batch()块在历元结束时_Python_Tensorflow - Fatal编程技术网

Python Tensor Flow shuffle_batch()块在历元结束时

Python Tensor Flow shuffle_batch()块在历元结束时,python,tensorflow,Python,Tensorflow,我正在使用tf.train.shuffle\u batch()创建成批的输入图像。它包括一个min_after_dequeue参数,该参数确保内部队列中有指定数量的元素,如果没有,则阻止所有其他元素 images, label_batch = tf.train.shuffle_batch( [image, label], batch_size=FLAGS.batch_size, num_threads=num_preprocess_threads, capacity=FLAGS.

我正在使用tf.train.shuffle\u batch()创建成批的输入图像。它包括一个min_after_dequeue参数,该参数确保内部队列中有指定数量的元素,如果没有,则阻止所有其他元素

images, label_batch = tf.train.shuffle_batch(
  [image, label],
  batch_size=FLAGS.batch_size,
  num_threads=num_preprocess_threads,
  capacity=FLAGS.min_queue_size + 3 * FLAGS.batch_size,
  min_after_dequeue=FLAGS.min_queue_size)
在一个时代结束时,当我进行评估时(我确信这在训练中也是正确的,但我没有测试过),一切都会阻塞。我发现,这是在同一时刻,内部无序批处理队列在退出队列元素后将剩下不到min的元素。在程序中的这个时候,我理想地只想将剩余的元素排出来,但我不确定如何排


显然,当您知道没有更多的元素可以使用.close()方法排队时,TF队列中的这种类型的阻塞可以被关闭。但是,由于基础队列隐藏在函数中,如何调用该方法?

您可以正确地认为,当队列中的出列后
min\u元素少于
min\u时,运行该操作将停止出列线程阻塞

该函数创建一个在后台线程中对队列执行操作的。如果按如下方式启动,并通过,您将能够干净地关闭队列(基于示例):


下面是我最终开始工作的代码,尽管有一大堆警告说我排队的元素被取消了

lv = tf.constant(label_list)

label_fifo = tf.FIFOQueue(len(filenames),tf.int32,shapes=[[]])
# if eval_data:
    # num_epochs = 1
# else:
    # num_epochs = None
file_fifo = tf.train.string_input_producer(filenames, shuffle=False, capacity=len(filenames))
label_enqueue = label_fifo.enqueue_many([lv])


reader = tf.WholeFileReader()
result.key, value = reader.read(file_fifo)
image = tf.image.decode_jpeg(value, channels=3)
image.set_shape([128,128,3])
result.uint8image = image
result.label = label_fifo.dequeue()

images, label_batch = tf.train.shuffle_batch(
  [result.uint8image, result.label],
  batch_size=FLAGS.batch_size,
  num_threads=num_preprocess_threads,
  capacity=FLAGS.min_queue_size + 3 * FLAGS.batch_size,
  min_after_dequeue=FLAGS.min_queue_size)

#in eval file:
label_enqueue, images, labels = load_input.inputs()
#restore from checkpoint in between
coord = tf.train.Coordinator()
try:
  threads = []
  for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
    threads.extend(qr.create_threads(sess, coord=coord, daemon=True,
                                     start=True))

  num_iter = int(math.ceil(FLAGS.num_examples / FLAGS.batch_size))
  true_count = 0  # Counts the number of correct predictions.
  total_sample_count = num_iter * FLAGS.batch_size

  sess.run(label_enqueue)
  step = 0
  while step < num_iter and not coord.should_stop():
    end_epoch = False
    if step > 0:
        for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
            #check if not enough elements in queue
            size = qr._queue.size().eval()
            if size - FLAGS.batch_size < FLAGS.min_queue_size:
                end_epoch = True
    if end_epoch:
        #enqueue more so that we can finish
        sess.run(label_enqueue)
    #actually run step
    predictions = sess.run([top_k_op])
lv=tf.常数(标签列表)
label_fifo=tf.FIFOQueue(len(文件名),tf.int32,shapes=[[]))
#如果评估数据:
#num_epochs=1
#其他:
#num_epochs=无
file\u fifo=tf.train.string\u input\u producer(文件名,shuffle=False,capacity=len(文件名))
label\u enqueue=label\u fifo.enqueue\u many([lv])
reader=tf.WholeFileReader()
result.key,value=reader.read(文件\u fifo)
image=tf.image.decode_jpeg(值,通道=3)
image.set_形状([128128,3])
result.uint8image=image
result.label=label\u fifo.dequeue()
图像,标签\u批=tf.train.shuffle\u批(
[result.uint8图像,result.label],
批次大小=标志。批次大小,
num_threads=num_preprocess_threads,
容量=FLAGS.min\u队列大小+3*FLAGS.batch\u大小,
退出队列后的最小队列=标志。最小队列大小)
#在eval文件中:
标签\排队,图像,标签=加载\输入。输入()
#从中间的检查点恢复
coord=tf.train.Coordinator()
尝试:
线程=[]
对于tf.get_集合(tf.GraphKeys.QUEUE_runner)中的qr:
扩展(qr.create_)线程(sess,coord=coord,daemon=True,
开始=真)
num_iter=int(math.ceil(FLAGS.num_示例/FLAGS.batch_大小))
true_count=0#统计正确预测的数量。
总样本计数=数量*FLAGS.batch\u大小
sess.run(标签排队)
步长=0
而步骤0:
对于tf.get_集合(tf.GraphKeys.QUEUE_runner)中的qr:
#检查队列中是否没有足够的元素
size=qr.\u queue.size().eval()
如果大小-FLAGS.batch\u size
有一个可选参数允许\u较小的\u最终\u批次


“allow_minger_final_batch:(可选)布尔值。如果为True,则如果队列中剩余的项目不足,则允许最终批次更小。”

这是我最初的想法,但它不起作用。为了让它完成一个历元,我被迫在最后显式地将更多元素排队。我认为问题可能在于,我对标签(整数)使用FIFO,对文件名使用string_input_producer,然后将从string_input_producer指定的文件加载的图像发送到shuffle_批处理中。我将在上面的问题中添加代码。您是否可以将有效的代码移到一个答案中并接受它,以便让其他人更容易看到您的问题确实有答案,而不必阅读整个问题。谢谢
lv = tf.constant(label_list)

label_fifo = tf.FIFOQueue(len(filenames),tf.int32,shapes=[[]])
# if eval_data:
    # num_epochs = 1
# else:
    # num_epochs = None
file_fifo = tf.train.string_input_producer(filenames, shuffle=False, capacity=len(filenames))
label_enqueue = label_fifo.enqueue_many([lv])


reader = tf.WholeFileReader()
result.key, value = reader.read(file_fifo)
image = tf.image.decode_jpeg(value, channels=3)
image.set_shape([128,128,3])
result.uint8image = image
result.label = label_fifo.dequeue()

images, label_batch = tf.train.shuffle_batch(
  [result.uint8image, result.label],
  batch_size=FLAGS.batch_size,
  num_threads=num_preprocess_threads,
  capacity=FLAGS.min_queue_size + 3 * FLAGS.batch_size,
  min_after_dequeue=FLAGS.min_queue_size)

#in eval file:
label_enqueue, images, labels = load_input.inputs()
#restore from checkpoint in between
coord = tf.train.Coordinator()
try:
  threads = []
  for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
    threads.extend(qr.create_threads(sess, coord=coord, daemon=True,
                                     start=True))

  num_iter = int(math.ceil(FLAGS.num_examples / FLAGS.batch_size))
  true_count = 0  # Counts the number of correct predictions.
  total_sample_count = num_iter * FLAGS.batch_size

  sess.run(label_enqueue)
  step = 0
  while step < num_iter and not coord.should_stop():
    end_epoch = False
    if step > 0:
        for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
            #check if not enough elements in queue
            size = qr._queue.size().eval()
            if size - FLAGS.batch_size < FLAGS.min_queue_size:
                end_epoch = True
    if end_epoch:
        #enqueue more so that we can finish
        sess.run(label_enqueue)
    #actually run step
    predictions = sess.run([top_k_op])