TensorFlow: blocking tf.contrib.staging.StagingArea get() and put() operations


Working environment

  • TensorFlow release version: 1.3.0-rc2
  • TensorFlow git version: v1.3.0-rc1-994-gb93fd37
  • Operating system: CentOS Linux 7.2.1511 (Core)
Problem scenario

I am using TensorFlow StagingArea ops to make my input pipeline more efficient. Below is part of the code snippet that builds the input pipeline:

    import tensorflow as tf  # TensorFlow 1.x API (tf.contrib is not available in 2.x)

    # train_iterator and val_iterator are tf.data iterators created earlier (not shown)
    train_put_op_list = []
    train_get_op_list = []
    val_put_op_list = []
    val_get_op_list = []
    with tf.variable_scope(tf.get_variable_scope()) as vscope:
        for i in range(4):
            with tf.device('/gpu:%d' % i):
                with tf.name_scope('GPU-Tower-%d' % i) as scope:
                    # One training and one validation staging area per GPU;
                    # capacity=0 means the staging area is unbounded
                    trainstagingarea = tf.contrib.staging.StagingArea(dtypes=[tf.float32, tf.int32],
                                                                      shapes=[[64, 221, 221, 3], [64]],
                                                                      capacity=0)
                    valstagingarea = tf.contrib.staging.StagingArea(dtypes=[tf.float32, tf.int32],
                                                                    shapes=[[128, 221, 221, 3], [128]],
                                                                    capacity=0)
                    train_put_op_list.append(trainstagingarea.put(train_iterator.get_next()))
                    val_put_op_list.append(valstagingarea.put(val_iterator.get_next()))
                    train_get_op_list.append(trainstagingarea.get())
                    val_get_op_list.append(valstagingarea.get())
                    # worktype == 1 selects the training batch, otherwise the validation batch
                    with tf.device('/cpu:0'):
                        worktype = tf.get_variable("wt", [], initializer=tf.zeros_initializer(), trainable=False)
                    workcondition = tf.equal(worktype, 1)
                    #elem = tf.cond(workcondition, lambda: train_iterator.get_next(), lambda: val_iterator.get_next())
                    elem = tf.cond(workcondition, lambda: train_get_op_list[i], lambda: val_get_op_list[i])
                    # This is followed by the network construction and optimizer
Now, at execution time, I first run the put() ops a few times and then go on to run the iterations, as shown below:

    with tf.Session(config=config) as sess:
        sess.run(init_op)
        sess.run(iterator_training_op)
        sess.run(iterator_validation_op)
        sess.run(tf.assign(worktype, 0))  # 0 selects the validation branch of tf.cond
        # Pre-fill every staging area with 4 batches
        for i in range(4):
            sess.run(train_put_op_list)
            sess.run(val_put_op_list)
        writer = tf.summary.FileWriter('.', graph=tf.get_default_graph())
        epoch = 0
        iter = 0
        previous = 0
        while epoch < 10:
            try:
                if PROCESSINGTYPE == 'validation':  # compare strings with ==, not `is`
                    # Stage one more validation batch, then run the step that consumes it via get()
                    sess.run(val_put_op_list)
                    [val_accu, summaries, numsamp] = sess.run([running_accuracy, validation_summary_op, processed])
                    previous += numsamp
                    print("Running Accuracy = {} : Number of sample processed = {} ".format(val_accu, previous))
                else:
                    sess.run(train_put_op_list)
                    [loss_value, _, train_accu, summaries, batch_accu, numsamp] = sess.run(
                        [total_loss, apply_gradient_op, running_accuracy, training_summary_op, batch_accuracy, processed])
                    #Remaining part of the code (not important for question)
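Notice that at the start of "with tf.Session() as sess:", the put() ops are run 4 times to fill the staging areas that the get() ops read from. The output is also limited to 4 lines. This suggests that the sess.run(val_put_op_list) call inside the while loop does not do anything. Therefore, when get() is called through sess.run(running_accuracy), the StagingArea is found empty after 4 lines, and the code blocks.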
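A different reading of the same four-line output is possible; it does not come from the original thread. In the pipeline code, train_get_op_list[i] and val_get_op_list[i] are created outside the tf.cond lambdas. The TensorFlow 1.x tf.cond documentation notes that operations defined outside true_fn and false_fn are executed unconditionally whenever a branch needs them. If that is the case here, every step dequeues from both staging areas.

The pure-Python bookkeeping model below tests this reading for one GPU tower. It assumes that running_accuracy depends on elem, that each area starts with 4 batches, and that the validation loop puts one validation batch per step. The simulate function is hypothetical; it counts how many steps finish before some get() would find an empty area.

```python
def simulate(prefill, steps, gets_inside_cond):
    """Count validation steps that finish before some get() would block.

    Hypothetical model of one GPU tower with a training and a validation
    staging area; the validation loop puts one validation batch per step.
    """
    depth = {"train": prefill, "val": prefill}
    completed = 0
    for _ in range(steps):
        depth["val"] += 1  # models sess.run(val_put_op_list)
        if gets_inside_cond:
            consumed = ["val"]  # only the selected branch dequeues
        else:
            consumed = ["train", "val"]  # both get() ops were built outside tf.cond
        if any(depth[name] == 0 for name in consumed):
            return completed  # a get() on an empty area would block here
        for name in consumed:
            depth[name] -= 1
        completed += 1
    return completed


print(simulate(prefill=4, steps=10, gets_inside_cond=False))  # 4
print(simulate(prefill=4, steps=10, gets_inside_cond=True))   # 10
```

When both get() ops run on every step, the training area drains after exactly four steps. That matches the four printed lines, each of which adds 512 samples (4 towers times 128). In this model, creating the get() inside the branch functions (for example lambda: trainstagingarea.get()) stops the unused area from being dequeued. Whether that change alone fixes the real graph has not been verified.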

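  • Is my analysis of this problem correct?
  • What is the correct way to use the get() and put() ops here?
  • If the StagingArea is full and put() is blocked, does that also block the whole program? The TensorFlow documentation does not say anything about this.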
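On the last question: in TensorFlow 1.x, sess.run does not return until the ops it runs have finished. A put() waiting on a full staging area therefore keeps that sess.run call, and the Python thread that made it, waiting. One common way around this is to issue the put() calls from a separate thread, much like the queue runners mentioned in the comments. That is an assumption here, not something stated in the thread.

The sketch below uses Python's queue.Queue with maxsize=2 as a stand-in for a bounded staging area. It is only an analogy, and run_demo and producer are hypothetical names.

```python
import queue
import threading
import time


def run_demo(num_items=5, capacity=2):
    """Show that a put() blocked on a full queue stalls only its own thread."""
    area = queue.Queue(maxsize=capacity)  # stand-in for a bounded staging area
    main_thread_work = []

    def producer():
        for item in range(num_items):
            area.put(item)  # waits while the queue already holds `capacity` items

    worker = threading.Thread(target=producer, daemon=True)
    worker.start()

    # Wait until the queue is full (or the producer is done), so that the
    # producer's next put() has to wait for free space.
    while not area.full() and worker.is_alive():
        time.sleep(0.01)

    # The main thread can still do other work while the producer waits.
    for step in range(3):
        main_thread_work.append(step)

    # Consuming items (the analogue of get()) lets the waiting put() proceed.
    received = [area.get() for _ in range(num_items)]
    worker.join(timeout=5)
    return main_thread_work, received, worker.is_alive()


work, received, still_running = run_demo()
print(work, received, still_running)  # [0, 1, 2] [0, 1, 2, 3, 4] False
```

In this analogy, the full queue stalls only the producer thread. The main thread keeps running, and items still come out in FIFO order.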

Take a look at this; it fixes some deadlocks and may make it into 1.4.0. Disclaimer: I am not a TensorFlower.

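It works just like a regular queue: calling get() on an empty staging area or put() on a full one will hang session.run. Have you seen this usage example? Note that it has extra logic to start the queue runners for the dataset, but nothing like that happens here.

Each staging area has a capacity of 5. At the start, 4 put() ops are run, and then one put op is run in the loop. Then one get() op is run, followed by another put op. Also, if you read my question carefully and study the output, you will see what I am asking.

Oh, I did not study the code; my comment was only about your last question, what happens when the staging area is full. Do you have any ideas? Maybe you could isolate the problem better (e.g. with tf.Print).

Unfortunately, this apparently did not make it into version 1.4.1.

It looks like this has been added in 1.5.0-rc0.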
# Validation is done first and the following is the output
Running Accuracy = 0.0 : Number of sample processed = 512
Running Accuracy = 0.00390625 : Number of sample processed = 1024
Running Accuracy = 0.0 : Number of sample processed = 1536
Running Accuracy = 0.001953125 : Number of sample processed = 2048
# The code hangs here
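The comment comparing a staging area to a regular queue can be illustrated with Python's queue.Queue. This is an analogy only. The short timeouts turn what would be a hang into a visible result, whereas the sess.run calls in the question simply wait. The try_get and try_put helpers are hypothetical, and maxsize=5 mirrors the capacity of 5 mentioned in the comments.

```python
import queue


def try_get(area, timeout=0.1):
    """Return the next item, or None if get() would have blocked (area empty)."""
    try:
        return area.get(timeout=timeout)
    except queue.Empty:
        return None


def try_put(area, item, timeout=0.1):
    """Return True if the item was stored, False if put() would have blocked (area full)."""
    try:
        area.put(item, timeout=timeout)
        return True
    except queue.Full:
        return False


area = queue.Queue(maxsize=5)  # bounded, like a staging area with capacity 5
print([try_put(area, batch) for batch in range(6)])  # [True, True, True, True, True, False]
print([try_get(area) for _ in range(6)])             # [0, 1, 2, 3, 4, None]
```

As with StagingArea's capacity=0, a queue.Queue created with maxsize=0 is unbounded, so in that case only get() on an empty queue can block.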