Python 如何在TensorFlow循环中赋值

Python 如何在TensorFlow循环中赋值,python,filter,tensorflow,while-loop,time-series,Python,Filter,Tensorflow,While Loop,Time Series,我们的目标是在TensorFlow中实现一个递归函数,以随时间过滤信号 输入随后以5-D张量形式呈现[批次、深度、高度、宽度、通道]。我想使用来在深度上迭代,,并根据以前时间步长的值重新分配值。但是,我无法在循环中重新分配变量值 为了简化问题,我创建了问题的一维版本: def condition(i, signal): return tf.less(i, signal.shape[0]) def operation(i, signal): signal = tf.get_var

我们的目标是在TensorFlow中实现一个递归函数,以随时间过滤信号

输入
随后以5-D张量形式呈现
[批次、深度、高度、宽度、通道]
。我想使用来在深度上迭代
,并根据以前时间步长的值重新分配值。但是,我无法在循环中重新分配变量值

为了简化问题,我创建了问题的一维版本:

def condition(i, signal):
    return tf.less(i, signal.shape[0])

def operation(i, signal):
    signal = tf.get_variable("signal")
    signal = signal[i].assign(signal[i-1]*2)
    i = tf.add(i, 1)
    return (i, signal)

with tf.variable_scope("scope"):
    i = tf.constant(1)
    init = tf.constant_initializer(0)
    signal = tf.get_variable("scope", [4], tf.float32, init, trainable = False)
    signal = tf.assign(signal[0], 1.2)

with tf.variable_scope("scope", reuse = True):
    loops_vars = [i, signal]
    i, signal = tf.while_loop(condition, operation, loop_vars, back_prop = False)

with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    i, signal = session.run([i, signal])
返回必须在会话中运行才能计算()的操作

我预计TensorFlow会在循环中链接操作,因此在我运行会话并请求
信号
后执行分配。但是,当我执行给定代码并打印结果时,
信号
包含
[1.2,0,0,0]
I
包含(如预期的那样)
4


这里我的误解是什么?我如何更改代码,以便重新分配
信号的值?

虽然循环变量仅通过body函数的返回值更新,但不应使用自己的分配操作。相反,您需要返回希望
signal
在循环后具有的值,就像
i
一样

另外,您不应该在body或condition中使用
tf.get_variable
,只需使用您收到的参数即可

# ...
def operation(i, signal):
    shape = signal.shape
    signal = tf.concat([signal[:i], [signal[i - 1] * 2], signal[i + 1:]], axis=0)
    signal.set_shape(shape) # Shapes have to be invariant for the loop
    i = tf.add(i, 1)
    return (i, signal)

with tf.variable_scope("scope"):
    i = tf.constant(1)
    init = tf.constant_initializer(1.2) # init signal here and avoid tf.assign
    signal = tf.get_variable("scope", [4], tf.float32, init, trainable = False)
    # signal = tf.assign(signal[0], 1.2) 

# ...

我很确定您正在创建
信号的新副本。查看并检查
tf.get_variable\u scope()。重用
,也可以轻松跳过将
信号
传递给所有while循环函数,因为您可以访问
信号
的任何地方。@NULL否,不创建副本。变量作用域设置为在创建while\u循环之前重用