TensorFlow: blocked tf.contrib.staging StagingArea get() and put() operations
Working environment

- TensorFlow release version: 1.3.0-rc2
- TensorFlow git version: v1.3.0-rc1-994-gb93fd37
- Operating system: CentOS Linux 7.2.1511 (Core)

Problem scenario

I am using the TensorFlow StagingArea ops to improve the efficiency of my input pipeline. Here is part of the code snippet that builds the input pipeline:
```python
train_put_op_list = []
train_get_op_list = []
val_put_op_list = []
val_get_op_list = []
with tf.variable_scope(tf.get_variable_scope()) as vscope:
    for i in range(4):
        with tf.device('/gpu:%d' % i):
            with tf.name_scope('GPU-Tower-%d' % i) as scope:
                trainstagingarea = tf.contrib.staging.StagingArea(
                    dtypes=[tf.float32, tf.int32],
                    shapes=[[64, 221, 221, 3], [64]],
                    capacity=0)
                valstagingarea = tf.contrib.staging.StagingArea(
                    dtypes=[tf.float32, tf.int32],
                    shapes=[[128, 221, 221, 3], [128]],
                    capacity=0)
                train_put_op_list.append(trainstagingarea.put(train_iterator.get_next()))
                val_put_op_list.append(valstagingarea.put(val_iterator.get_next()))
                train_get_op_list.append(trainstagingarea.get())
                val_get_op_list.append(valstagingarea.get())
                with tf.device('/cpu:0'):
                    worktype = tf.get_variable("wt", [], initializer=tf.zeros_initializer(), trainable=False)
                workcondition = tf.equal(worktype, 1)
                # elem = tf.cond(workcondition, lambda: train_iterator.get_next(), lambda: val_iterator.get_next())
                elem = tf.cond(workcondition, lambda: train_get_op_list[i], lambda: val_get_op_list[i])
                # This is followed by the network construction and optimizer
```
Now, at execution time, I first run the put() ops a few times and then proceed to run the iterations, as shown below:
```python
with tf.Session(config=config) as sess:
    sess.run(init_op)
    sess.run(iterator_training_op)
    sess.run(iterator_validation_op)
    sess.run(tf.assign(worktype, 0))
    for i in range(4):
        sess.run(train_put_op_list)
        sess.run(val_put_op_list)
    writer = tf.summary.FileWriter('.', graph=tf.get_default_graph())
    epoch = 0
    iter = 0
    previous = 0
    while(epoch < 10):
        try:
            if PROCESSINGTYPE == 'validation':
                sess.run(val_put_op_list)
                [val_accu, summaries, numsamp] = sess.run([running_accuracy, validation_summary_op, processed])
                previous += numsamp
                print("Running Accuracy = {} : Number of sample processed = {} ".format(val_accu, previous))
            else:
                sess.run(train_put_op_list)
                [loss_value, _, train_accu, summaries, batch_accu, numsamp] = sess.run([total_loss, apply_gradient_op, running_accuracy, training_summary_op, batch_accuracy, processed])
            # Remaining part of the code (not important for the question)
```
You can notice that at the start of the `with tf.Session() as sess:` block, the put() ops are run 4 times, and the output is also limited to 4 lines. That is to say, the `sess.run(val_put_op_list)` inside the while loop does not appear to do anything, so when the subsequent `sess.run(...)` triggers the get() ops, the StagingArea is found to be empty after 4 iterations, and the code blocks.
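The hang described above can be modelled with an ordinary bounded FIFO. This is a minimal sketch, with Python's `queue.Queue` standing in for the StagingArea; the capacity and batch payloads are illustrative, not taken from the original code:

```python
import queue

# queue.Queue stands in for the StagingArea (illustrative capacity).
staging = queue.Queue(maxsize=5)

# Warm-up: the pre-loop sess.run(val_put_op_list) calls, executed 4 times.
for step in range(4):
    staging.put(("batch", step))

# A loop that only ever consumes -- which is effectively what happens if
# the in-loop put() op does not refill the buffer.
drained = []
while True:
    try:
        drained.append(staging.get(block=False))
    except queue.Empty:
        # StagingArea.get() would block here instead of raising, which
        # matches the observed hang after exactly 4 output lines.
        break

print(len(drained))  # only the 4 warmed-up batches ever come out
```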
- Is my analysis of the problem correct?
- What is the correct way to use the get() and put() ops here?
- If the StagingArea is full and put() blocks, does that also block the entire code? The TensorFlow documentation does not say anything about this.
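The full-buffer case asked about in the last question behaves like any bounded queue (the comments below confirm this). A sketch with Python's `queue.Queue` as a stand-in for the StagingArea, with an illustrative capacity:

```python
import queue

staging = queue.Queue(maxsize=2)  # stand-in for a StagingArea of capacity 2
staging.put("batch-0")
staging.put("batch-1")            # the buffer is now full

try:
    # A blocking put() here would hang the producer, just as
    # StagingArea.put() hangs sess.run() when the area is at capacity.
    staging.put("batch-2", block=False)
    blocked = False
except queue.Full:
    blocked = True

print(blocked)
```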
Comments:

- Take a look at this fix; it resolves some of the deadlocks and may make it into 1.4.0. Disclaimer: I am not a tensorflower.
- It works like a regular queue: a get() on an empty staging area, or a put() on a full one, will hang session.run(). Have you seen this usage example? Note that it has extra logic to start queue runners for the dataset, none of which happens here.
- Each staging area has a capacity of 5. At the start, 4 put() ops are run, then one put op inside the loop; after that a get() op runs, followed by another put op. Also, if you read my question carefully and study the output, you will see my problem.
- Oh, I did not study the code; that remark was only about your last question, on what happens when the staging area is full.
- Do you have any ideas?
- Maybe you can isolate the problem further (e.g. with tf.Print). Unfortunately, this apparently has not made it into 1.4.1 yet. It looks like it has been added in 1.5.0-rc0.
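The pattern the comments point at (warm the buffer up with a few put() calls, then pair every get() with one put() in steady state) can be sketched with a plain Python queue standing in for one StagingArea; the iterator and capacity are illustrative assumptions:

```python
import queue

staging = queue.Queue(maxsize=5)   # stand-in for one StagingArea
data = iter(range(100))            # stand-in for the tf.data iterator

# Warm-up: prime the buffer so the first get() has something to consume.
staging.put(next(data))

consumed = []
for step in range(8):
    # Steady state: pair every get() with one put(), so the buffer can
    # neither drain (which would block get) nor overfill (which would
    # block put).
    staging.put(next(data))
    consumed.append(staging.get())

print(consumed)  # batches come out in FIFO order: [0, 1, ..., 7]
```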
```python
# Validation is done first and the following is the output
Running Accuracy = 0.0 : Number of sample processed = 512
Running Accuracy = 0.00390625 : Number of sample processed = 1024
Running Accuracy = 0.0 : Number of sample processed = 1536
Running Accuracy = 0.001953125 : Number of sample processed = 2048
# The code hangs here
```