更新tensorflow中的变量值

更新tensorflow中的变量值,tensorflow,Tensorflow,我有一个关于通过tensorflow python api更新张量值的基本问题 考虑代码片段: x = tf.placeholder(shape=(None,10), ... ) y = tf.placeholder(shape=(None,), ... ) W = tf.Variable( randn(10,10), dtype=tf.float32 ) yhat = tf.matmul(x, W) 现在让我们假设我想要实现某种迭代更新W值的算法(例如,一些优化算法)。这将包括以下步骤: f

我有一个关于通过tensorflow python api更新张量值的基本问题

考虑代码片段:

x = tf.placeholder(shape=(None,10), ... )
y = tf.placeholder(shape=(None,), ... )
W = tf.Variable( randn(10,10), dtype=tf.float32 )
yhat = tf.matmul(x, W)
现在让我们假设我想要实现某种迭代更新W值的算法(例如,一些优化算法)。这将包括以下步骤:

for i in range(max_its):
     resid = y_hat - y
     W = f(W , resid) # some update 
这里的问题是,LHS上的
W
是一个新的张量,而不是
W
中使用的
yhat=tf.matmul(x,W)
!也就是说,将创建一个新变量,并且在我的“模型”中使用的
W
的值不会更新

现在解决这个问题的一个办法是

 for i in range(max_its):
     resid = y_hat - y
     W = f(W , resid) # some update 
     yhat = tf.matmul( x, W)
这将导致为我的循环的每个迭代创建一个新的“模型”


有没有更好的方法(在python中)来实现这一点,而不必为循环的每次迭代创建一大堆新模型——而是更新原始的张量
W
“就地”可以说?

变量有一个赋值方法。试试看:
W.assign(f(W,resid))
@aarbelle简洁的回答是正确的,我会把它展开一点,以防有人需要更多信息。下面最后两行用于更新W

x = tf.placeholder(shape=(None,10), ... )
y = tf.placeholder(shape=(None,), ... )
W = tf.Variable(randn(10,10), dtype=tf.float32 )
yhat = tf.matmul(x, W)

...

for i in range(max_its):
    resid = y_hat - y
    update = W.assign(f(W , resid)) # do not forget to initialize tf variables. 
    # "update" above is just a tf op, you need to run the op to update W.
    sess.run(update)

准确地说,答案应该是sess.run(W.assign(f(W,resid))。然后使用
sess.run(W)
显示更改