Python tensorflow中的优化器如何访问在单独函数中创建的变量

Python tensorflow中的优化器如何访问在单独函数中创建的变量,python,tensorflow,namespaces,Python,Tensorflow,Namespaces,代码中感兴趣的行后面跟着多个散列(#)符号 为了理解这个目的,我在tensorflow中运行一个简单的线性回归。我使用的代码是: def generate_dataset(): #y = 2x+e where is the normally distributed error x_batch = np.linspace(-1,1,101) y_batch = 2*x_batch +np.random.random(*x_batch.shape)*0.3 return x_batch, y_bat

代码中感兴趣的行后面跟着多个散列(#)符号

为了理解这个目的,我在tensorflow中运行一个简单的线性回归。我使用的代码是:

def generate_dataset():
#y = 2x+e where is the normally distributed error
x_batch = np.linspace(-1,1,101)
y_batch = 2*x_batch +np.random.random(*x_batch.shape)*0.3
return x_batch, y_batch

def linear_regression():   ##################
x = tf.placeholder(tf.float32, shape = (None,), name = 'x')
y = tf.placeholder(tf.float32, shape = (None,), name = 'y')
with tf.variable_scope('lreg') as scope: ################
    w = tf.Variable(np.random.normal()) ##################
    y_pred = tf.multiply(w,x)
    loss = tf.reduce_mean(tf.square(y_pred - y))
return x,y, y_pred, loss
def run():
x_batch, y_batch = generate_dataset()
x, y, y_pred, loss = linear_regression()
optimizer = tf.train.GradientDescentOptimizer(0.2).minimize(loss)

init = tf.global_variables_initializer()
with tf.Session() as session:
    session.run(init) 
    feed_dict = {x: x_batch, y: y_batch}
    for _ in range(30):
        loss_val, _ = session.run([loss, optimizer], feed_dict)
        print('loss:', loss_val.mean())
    y_pred_batch = session.run(y_pred, {x:x_batch})

    print(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)) ############
    print(session.run(fetches = [w])) #############
run()      
通过对“w”或“lreg/w”的fetch调用,我似乎无法获取变量的值(它实际上是op?“w”),如果我理解正确,这是因为“w”是在线性回归()中定义的,并且它不允许其命名空间运行()。但是,我可以通过对其变量名“lreg/vairable:0”的提取调用来访问“w”。优化器工作正常,更新应用完美


优化器是如何访问“w”并应用更新的?如果您能让我了解一下op“w”是如何在线性回归()和运行()之间共享的,那就太好了。

您创建的每个op和变量都是tensorflow中的一个节点。当您没有显式地创建一个图时,就像在您的例子中一样,就会使用一个默认的图

此行将w添加到默认图形中

 w = tf.Variable(np.random.normal())
此行访问图形以执行计算

loss_val, _ = session.run([loss, optimizer], feed_dict)
您可以像这样检查图形

tf.get_default_graph().as_graph_def()

非常感谢您的回复。我有一个后续问题:当从run()运行时,为什么print(session.run(fetches=[w])会抛出错误?NameError:未定义名称“w”。我想提醒您,print(session.run(fetches=['lreg/variable:0'])确实会为我获取“w”的值。您必须将python变量和tensorflow变量分开。仅仅因为tensorflow图中有一个名为w的变量,并不意味着python在当前范围内定义了一个名为w的变量。