Tensorflow 排队时张量的强制复制

Tensorflow 排队时张量的强制复制,tensorflow,Tensorflow,首先,我不确定这个标题是否很好,但考虑到我对形势的理解,这是我能想到的最好的标题 背景是,我试图理解tensorflow中队列是如何工作的,遇到了以下问题,这让我感到困惑 我有一个变量n,我把它排队到tf.FIFOQueue,然后我增加这个变量。这会重复几次,人们会期望得到类似于0,1,2。。。但是,清空队列时,所有值都相同 更准确地说,代码如下: from __future__ import print_function import tensorflow as tf q = tf.FIF

首先,我不确定这个标题是否很好,但考虑到我对形势的理解,这是我能想到的最好的标题

背景是,我试图理解tensorflow中队列是如何工作的,遇到了以下问题,这让我感到困惑

我有一个变量n,我把它排队到tf.FIFOQueue,然后我增加这个变量。这会重复几次,人们会期望得到类似于0,1,2。。。但是,清空队列时,所有值都相同

更准确地说,代码如下:

from __future__ import print_function

import tensorflow as tf

q = tf.FIFOQueue(10, tf.float32)

n = tf.Variable(0, trainable=False, dtype=tf.float32)
inc = n.assign(n+1)
enqueue = q.enqueue(n)

init = tf.global_variables_initializer()

sess = tf.Session()
sess.run(init)

sess.run(enqueue)
sess.run(inc)

sess.run(enqueue)
sess.run(inc)

sess.run(enqueue)
sess.run(inc)

print(sess.run(q.dequeue()))
print(sess.run(q.dequeue()))
print(sess.run(q.dequeue()))
我希望它能打印出来:

0.0
1.0
2.0
相反,我得到了以下结果:

3.0
3.0
3.0
我好像在把指向n的指针推到队列中,而不是我想要的实际值。然而,我对tensorflow的内部结构并没有任何实际的理解,所以可能还有其他的事情发生了

我试着换衣服

enqueue = q.enqueue(n)

因为答案和给我的印象是它可能会有帮助,但它不会改变结果。我还尝试添加了一个tf.control_dependencies(),但再次说明,当退出队列时,所有值都是相同的


编辑:上面的输出来自于在单CPU计算机上运行代码,当尝试查看tensorflow的不同版本之间是否存在差异时,我注意到如果在带有CPU和GPU的计算机上运行代码,我会得到“预期”结果。事实上,如果我使用CUDA_VISIBLE_DEVICES=“”运行,我会得到上面的结果,而使用CUDA_VISIBLE_DEVICES=“0”运行,我会得到“预期”结果。

要强制非缓存读取,您可以这样做

q.enqueue(tf.add(q, 0))
这是批处理规范化层强制复制的功能

变量如何读取和引用的语义正在修改过程中,因此它们暂时不直观。特别是,我希望
q.enqueue(v.read_value())
强制非缓存读取,但它不能修复TF 0.12rc1上的示例


使用GPU机器将变量放在GPU上,而队列仅为CPU,因此
enqueue
op强制执行GPU->CPU拷贝。

强制执行非缓存读取,您可以执行以下操作

q.enqueue(tf.add(q, 0))
这是批处理规范化层强制复制的功能

变量如何读取和引用的语义正在修改过程中,因此它们暂时不直观。特别是,我希望
q.enqueue(v.read_value())
强制非缓存读取,但它不能修复TF 0.12rc1上的示例


使用GPU机器将变量放在GPU上,而队列仅为CPU,因此
enqueue
op强制GPU->CPU复制。

如果有帮助,我发现其他答案尽管正确,但并不适用于所有数据类型

例如,这适用于浮点或整数,但在n是字符串张量时失败:

q.enqueue(tf.add(n, 0))
当队列使用具有异构类型(例如Int和Float)的元组时,此操作失败:

因此,如果您发现自己陷入了上述任何一种情况,请尝试以下方法:

q.enqueue(tf.add(n, tf.zeros_like(n)))
或者,要将元组t排队:

这甚至适用于字符串张量和异构元组类型

希望有帮助

--


更新:tf.bool类型似乎不能与tf.zeros_like()一起使用。对于这些类型,可能需要显式转换为整数类型。

如果有帮助,我发现其他答案尽管正确,但并不适用于所有数据类型

例如,这适用于浮点或整数,但在n是字符串张量时失败:

q.enqueue(tf.add(n, 0))
当队列使用具有异构类型(例如Int和Float)的元组时,此操作失败:

因此,如果您发现自己陷入了上述任何一种情况,请尝试以下方法:

q.enqueue(tf.add(n, tf.zeros_like(n)))
或者,要将元组t排队:

这甚至适用于字符串张量和异构元组类型

希望有帮助

--


更新:tf.bool类型似乎不能与tf.zeros_like()一起使用。对于这些,可能需要显式转换为整数类型。

另一个解决方法是
q.enqueue\u many([[n]])
而不是
q.enqueue(n)
,它将按值而不是参照排队另一个解决方法是
q.enqueue\u many([[n]])
而不是
q.enqueue(n)
,这将按值而不是按引用排队