将tensorflow Defun视为闭包

将tensorflow Defun视为闭包,tensorflow,Tensorflow,我在tensorflow中使用Defun decorator时遇到问题。也就是说,Defun不能关闭在外部创建的任何TF ops。下面是一个独立的示例,展示了我想要做的事情。请注意,张量x属于自定义_op调用内外的不同图形。Defun代码创建一个临时图形,将该图形转换为函数proto,然后将其合并到原始图形中。代码在第一步崩溃,因为我们结束的张量不在新的临时图中。有办法解决这个问题吗?能够结束事情会非常有帮助 import tensorflow as tf from tensor

我在tensorflow中使用Defun decorator时遇到问题。也就是说,Defun不能关闭在外部创建的任何TF ops。下面是一个独立的示例,展示了我想要做的事情。请注意,张量x属于自定义_op调用内外的不同图形。Defun代码创建一个临时图形,将该图形转换为函数proto,然后将其合并到原始图形中。代码在第一步崩溃,因为我们结束的张量不在新的临时图中。有办法解决这个问题吗?能够结束事情会非常有帮助

    import tensorflow as tf
    from tensorflow.python.framework import function

    w = tf.Variable(1.0)
    function_factory = lambda x: x*w

    @function.Defun(x=tf.float32)

    def custom_op(x):
        print('graph for x inside custom_op: ', x.graph)
        return function_factory(x)

    x = tf.constant(2.0)

    print('graph for x outside custom_op: ', x.graph)
    y = custom_op(x)

    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        sess.run(y)

否,
Defun
decorator不会捕获所有内容。您需要显式地传入
w
,如下例所示:

import tensorflow as tf
from tensorflow.python.framework import function

w = tf.Variable(1.0)

@function.Defun(tf.float32, tf.float32)
def custom_op(x, w):
    print('graph for x inside custom_op: ', x.graph)
    return x * w

x = tf.constant(2.0)
print('graph for x outside custom_op: ', x.graph)
y = custom_op(x, tf.identity(w))

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    sess.run(y)
(如果需要,我们可以添加更多完整的捕获支持。)