Warning: file_get_contents(/data/phpspider/zhask/data//catemap/2/tensorflow/5.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
Tensorflow 如何将作用域下的所有变量添加到某个集合中_Tensorflow - Fatal编程技术网

Tensorflow 如何将作用域下的所有变量添加到某个集合中

Tensorflow 如何将作用域下的所有变量添加到某个集合中,tensorflow,Tensorflow,在tensorflow python API中,tf.get_variable具有一个参数collections,用于将创建的变量添加到指定的集合中。但是tf。变量范围没有。 将变量范围下的所有变量添加到某个集合中的建议方法是什么?我不认为有直接的方法。您可以在Tensorflow的github问题跟踪器上提交功能请求 我可以建议您尝试两种解决方法: 迭代tf.all_variables()的结果,并提取名称类似于“../scope_name/…”的变量。作用域名称编码在变量名称中,由/字符分

在tensorflow python API中,tf.get_variable具有一个参数collections,用于将创建的变量添加到指定的集合中。但是tf。变量范围没有。
将变量范围下的所有变量添加到某个集合中的建议方法是什么?

我不认为有直接的方法。您可以在Tensorflow的github问题跟踪器上提交功能请求

我可以建议您尝试两种解决方法:

  • 迭代
    tf.all_variables()
    的结果,并提取名称类似于
    “../scope_name/…”
    的变量。作用域名称编码在变量名称中,由
    /
    字符分隔

  • 围绕tf.VariableScope和tf.get_variable()编写包装器,将在范围内创建的变量存储在数据结构中


我希望这有帮助

我成功地做到了这一点:

import tensorflow as tf

def var_1():
    with tf.variable_scope("foo") as foo_scope:
        assert foo_scope.name == "ll/foo"
        a = tf.get_variable("a", [2, 2])
    return foo_scope

def var_2(foo_scope):
    with tf.variable_scope("bar"):
        b = tf.get_variable("b", [2, 2])
        with tf.variable_scope("baz") as other_scope:
            c = tf.get_variable("c", [2, 2])
            assert other_scope.name == "ll/bar/baz"
            with tf.variable_scope(foo_scope) as foo_scope2:
                d = tf.get_variable("d", [2, 2])
                assert foo_scope2.name == "ll/foo"  # Not changed.

def main():
    with tf.variable_scope("ll"):
        scp = var_1()
        var_2(scp)
        all_default_global_variables = tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES)
        my_collection = tf.get_collection('my_collection') # create my collection
        ll_foo_variables = []
        for variable in all_default_global_variables:
            if "ll/foo" in variable.name:
                ll_foo_variables.append(variable)
        tf.add_to_collection('my_collection', ll_foo_variables)

        variables_in_my_collection = tf.get_collection_ref("my_collection")
        print(variables_in_my_collection)

main()
您可以在我的代码a、b、c和d中看到,只有a和d具有相同的作用域名称
ll/foo

过程: 首先,我添加默认情况下在tf.GraphKeys.GLOBAL_variables集合中创建的所有变量,然后创建一个名为my_collection的集合,然后仅将范围名称中带有“ll/foo”的变量添加到my_集合中

我得到的是我所期望的:

[[<tf.Variable 'll/foo/a:0' shape=(2, 2) dtype=float32_ref>, <tf.Variable 'll/foo/d:0' shape=(2, 2) dtype=float32_ref>]]
[,]
将tensorflow作为tf导入
对于tf.global_变量(scope='model')中的var:
tf.add_到_集合(tf.GraphKeys.MODEL_变量,var)

如果您感兴趣的话,也可以迭代
可训练的\u变量,而不是使用
全局\u变量
。在这两种情况下,您不仅可以捕获使用
get_variable()
手动创建的变量,还可以捕获任何
tf.layers
调用创建的变量

您可以只获取范围内的所有变量,而不获取集合:

tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='my_scope')