Python 使用输入队列的Tensorflow训练被卡住
我正在尝试建立一个类似于教程中的NN培训 我的代码如下所示:Python 使用输入队列的Tensorflow训练被卡住,python,multithreading,queue,tensorflow,Python,Multithreading,Queue,Tensorflow,我正在尝试建立一个类似于教程中的NN培训 我的代码如下所示: def train(): init_op = tf.initialize_all_variables() sess = tf.Session() sess.run(init_op) coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess=sess, coord=coord) step = 0
def train():
init_op = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init_op)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
step = 0
try:
while not coord.should_stop():
step += 1
print 'Training step %i' % step
training = train_op()
sess.run(training)
except tf.errors.OutOfRangeError:
print 'Done training - epoch limit reached.'
finally:
coord.request_stop()
coord.join(threads)
sess.close()
与
教程说
这些要求您在运行任何训练或推理步骤之前调用tf.train.start\u queue\u runners
,否则它将永远挂起
。我正在调用tf.train.start\u queue\u runner
,但是train()
的执行在第一次出现sess.run(training)
时仍然会被卡住
有人知道我做错了什么吗?每次尝试运行训练循环时,您都在重新定义网络 请记住,TensorFlow定义了一个执行图,然后执行它。您想在运行循环之外调用train_op(),在调用
初始化所有变量和tf.train.start_queue_runner
MIN_NUM_EXAMPLES_IN_QUEUE = 10
NUM_PRODUCING_THREADS = 1
NUM_CONSUMING_THREADS = 1
def train_op():
images, true_labels = inputs()
predictions = NET(images)
true_labels = tf.cast(true_labels, tf.float32)
loss = tf.nn.softmax_cross_entropy_with_logits(predictions, true_labels)
return OPTIMIZER.minimize(loss)
def inputs():
filenames = [os.path.join(FLAGS.train_dir, filename)
for filename in os.listdir(FLAGS.train_dir)
if os.path.isfile(os.path.join(FLAGS.train_dir, filename))]
filename_queue = tf.train.string_input_producer(filenames,
num_epochs=FLAGS.training_epochs, shuffle=True)
example_list = [_read_and_preprocess_image(filename_queue)
for _ in xrange(NUM_CONSUMING_THREADS)]
image_batch, label_batch = tf.train.shuffle_batch_join(
example_list,
batch_size=FLAGS.batch_size,
capacity=MIN_NUM_EXAMPLES_IN_QUEUE + (NUM_CONSUMING_THREADS + 2) * FLAGS.batch_size,
min_after_dequeue=MIN_NUM_EXAMPLES_IN_QUEUE)
return image_batch, label_batch