Tensorflow 为什么tf.case要构建两次可调用函数?

Tensorflow 为什么tf.case要构建两次可调用函数?,tensorflow,Tensorflow,我试图在一个网络中建立多个分支。所以我在代码中使用了tf.case。但是我发现tf.case总是构建最后一个可调用函数两次,这会导致变量错误:“变量XXX已经存在”(我通过slim创建了变量,变量范围“case/If_x”将不存在,这就是为什么我会得到错误)。这是一个带有输出的测试程序 import numpy as np import tensorflow as tf slim = tf.contrib.slim def fn1(X, Y): with tf.variable_sc

我试图在一个网络中建立多个分支。所以我在代码中使用了tf.case。但是我发现tf.case总是构建最后一个可调用函数两次,这会导致变量错误:“变量XXX已经存在”(我通过slim创建了变量,变量范围“case/If_x”将不存在,这就是为什么我会得到错误)。这是一个带有输出的测试程序

import numpy as np
import tensorflow as tf

slim = tf.contrib.slim

def fn1(X, Y):
    with tf.variable_scope("fn1", reuse=False):
       w = tf.Variable(1.0, name="w")
       #w = slim.variable(name="w", shape=())
    return X*w, Y*w

def fn2(X, Y):
    with tf.variable_scope("fn2", reuse=False):
       w = tf.Variable(2.0, name="w")
       #w = slim.variable(name="w", shape=())
    return X*w, Y*w

def fn3(X, Y):
    with tf.variable_scope("fn3", reuse=False):
       w = tf.Variable(3.0, name="w")
       #w = slim.variable(name="w", shape=())
    return X*w, Y*w

class Test:

    def __init__(self):
        self.Z = tf.placeholder(dtype=tf.int32, shape=())
        self.X = tf.Variable(1.0, name="X")
        self.Y = tf.Variable(2.0, name="Y")

    def build(self):
        self.result = tf.case(
            pred_fn_pairs=[
                    (tf.equal(self.Z, 10), lambda : fn3(self.X, self.Y)),
                    (tf.equal(self.Z, 20), lambda : fn2(self.X, self.Y)),
                    (tf.equal(self.Z, 30), lambda : fn1(self.X, self.Y))],
                    exclusive=False)

test = Test()
test.build()

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())

    tvars = tf.trainable_variables()
    tvars_vals = sess.run(tvars)

    for var, val in zip(tvars, tvars_vals):
        print(var.name) 

    aa = sess.run(test.result, feed_dict={test.Z:20})
    print aa
输出为:

X:0
Y:0
case/If_0/fn1/w:0
case/If_0/fn1_1/w:0
case/If_1/fn2/w:0
case/If_2/fn3/w:0
(2.0, 4.0)

我无法重现这个问题,您是否可以重新启动jupyter内核或python环境,然后重试,这样变量将被删除discarded@naveenmarri您使用哪一个来创建varaible tf.Variable()或slim.Variable()?我在spyder中多次尝试这个程序(每次都重新启动内核)。如果我使用tf.Variable(),程序可以成功运行,显示fn1是我输入的两倍。如果我使用slim.variable()。它会出错。