Python tf.py_func()使用lambda函数在循环中意外输出

Python tf.py_func()使用lambda函数在循环中意外输出,python,tensorflow,Python,Tensorflow,在实现相同的简单功能时,我观察到Numpy和pure Tensorflow的不同行为,这些功能在for循环中的迭代中共享一个变量 让我们从纯Numpy版本开始: def my_func(x, k): return np.tile(x,k) x = np.ones((1), np.int64) for i in range(1,3): x = my_func(x, i) print(x) 这将产生预期的输出。最初x是[1]。在第一次迭代中,它被复制一次以生成[1]。然后在下一

在实现相同的简单功能时,我观察到Numpy和pure Tensorflow的不同行为,这些功能在for循环中的迭代中共享一个变量

让我们从纯Numpy版本开始:

def my_func(x, k):
    return np.tile(x,k)

x = np.ones((1), np.int64)
for i in range(1,3):
    x = my_func(x, i)

print(x)
这将产生预期的输出。最初
x
[1]
。在第一次迭代中,它被复制一次以生成
[1]
。然后在下一次迭代中,将结果复制两次,生成最终输出
[1]

类似的方法也会在纯Tensorflow中产生相同的预期输出:

x = tf.constant([1], tf.int64)
for i in range(1,3):
    x = tf.tile(x, [i])

with tf.Session() as sess:
    xx = sess.run(x)
    print(xx)
输出为
[1]

现在我正尝试使用基本相同的东西,我无法理解为什么我看到了不同的输出。此代码:

import tensorflow as tf
import numpy as np

def my_func(x, k):
    return np.tile(x,k)

x = tf.constant([1], tf.int64)
for i in range(1,3):
    x = tf.py_func(lambda y: my_func(y, i), [x], tf.int64)

with tf.Session() as sess:
    xx = sess.run(x)
    print(xx)
生成意外结果
[1]

为什么会这样?
py_func
是否具有某些属性,使其无法与共享(张量)变量名(在本例中是在每次循环迭代中更新的变量
x
)配合使用


请注意,这是一个简单的例子,重现了这个问题,其功能很容易在纯Tensorflow中重现。在我的实际应用程序中,需要使用
tf.py_func
,因为该功能更加复杂。

如果没有lambda函数,它将按预期工作:

import tensorflow as tf
import numpy as np

def my_func(x, k):
    return np.tile(x,k)

x = tf.constant([1], tf.int64)
for i in range(1,3):
    x = tf.py_func(my_func, [x, i], tf.int64)

with tf.Session() as sess:
    xx = sess.run(x)
    print(xx)
返回
[1]

编辑

我找到了原因:
lambda y:my_func(y,I)
通过引用而不是值保存I。因此,for循环的最后一个i值应用于循环中的所有
py_func
。下面是一个简单的示例,说明了问题:

import tensorflow as tf

def my_func(x, y):
  return x - y

x1 = tf.constant([0], tf.float32)
for i in range(2):
    x1 = tf.py_func(lambda y: my_func(y, i), [x1], tf.float32)

x2 = tf.constant([0], tf.float32)
x2 = tf.py_func(lambda y: my_func(y, 0), [x2], tf.float32)
x2 = tf.py_func(lambda y: my_func(y, 1), [x2], tf.float32)

with tf.Session() as sess:
    print(sess.run(x1))
    print(sess.run(x2))

干得好!这一次真让我抓狂。