Python: Renaming the variable scope of a saved model in TensorFlow


Is it possible to rename the variable scope of a given model in TensorFlow?

For example, I created a logistic regression model for MNIST digits by following the tutorial:

with tf.variable_scope('my-first-scope'):
    NUM_IMAGE_PIXELS = 784
    NUM_CLASS_BINS = 10
    x = tf.placeholder(tf.float32, shape=[None, NUM_IMAGE_PIXELS])
    y_ = tf.placeholder(tf.float32, shape=[None, NUM_CLASS_BINS])

    W = tf.Variable(tf.zeros([NUM_IMAGE_PIXELS,NUM_CLASS_BINS]))
    b = tf.Variable(tf.zeros([NUM_CLASS_BINS]))

    y = tf.nn.softmax(tf.matmul(x,W) + b)
    cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
    saver = tf.train.Saver([W, b])

... # some training happens

saver.save(sess, 'my-model')
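Now I want to reload the saved model, which lives in the 'my-first-scope' variable scope, and save everything again to a new file under a new variable scope, 'my-second-scope'.

You can use tf.contrib.framework.list_variables and tf.contrib.framework.load_variable to achieve this, as follows: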

import tensorflow as tf  # TensorFlow 1.x; tf.contrib was removed in 2.x

with tf.Graph().as_default(), tf.Session().as_default() as sess:
  with tf.variable_scope('my-first-scope'):
    NUM_IMAGE_PIXELS = 784
    NUM_CLASS_BINS = 10
    x = tf.placeholder(tf.float32, shape=[None, NUM_IMAGE_PIXELS])
    y_ = tf.placeholder(tf.float32, shape=[None, NUM_CLASS_BINS])

    W = tf.Variable(tf.zeros([NUM_IMAGE_PIXELS,NUM_CLASS_BINS]))
    b = tf.Variable(tf.zeros([NUM_CLASS_BINS]))

    y = tf.nn.softmax(tf.matmul(x,W) + b)
    cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
    saver = tf.train.Saver([W, b])
  sess.run(tf.global_variables_initializer())
  saver.save(sess, 'my-model')

vars = tf.contrib.framework.list_variables('.')
with tf.Graph().as_default(), tf.Session().as_default() as sess:

  new_vars = []
  for name, shape in vars:
    v = tf.contrib.framework.load_variable('.', name)
    new_vars.append(tf.Variable(v, name=name.replace('my-first-scope', 'my-second-scope')))

  saver = tf.train.Saver(new_vars)
  sess.run(tf.global_variables_initializer())
  saver.save(sess, 'my-new-model')

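Based on keveman's answer, I created a Python script that you can run to rename the variables of any TensorFlow checkpoint.

It can replace substrings in variable names and add a prefix to all names. Run it with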

python tensorflow_rename_variables.py --checkpoint_dir=path/to/dir
and the optional arguments

--replace_from=substr --replace_to=substr --add_prefix=abc --dry_run
Here is the core function of the script:

import tensorflow as tf  # TensorFlow 1.x; tf.contrib was removed in 2.x


def rename(checkpoint_dir, replace_from, replace_to, add_prefix, dry_run=False):
    checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
    with tf.Session() as sess:
        for var_name, _ in tf.contrib.framework.list_variables(checkpoint_dir):
            # Load the variable
            var = tf.contrib.framework.load_variable(checkpoint_dir, var_name)

            # Set the new name
            new_name = var_name
            if None not in [replace_from, replace_to]:
                new_name = new_name.replace(replace_from, replace_to)
            if add_prefix:
                new_name = add_prefix + new_name

            if dry_run:
                print('%s would be renamed to %s.' % (var_name, new_name))
            else:
                print('Renaming %s to %s.' % (var_name, new_name))
                # Rename the variable
                var = tf.Variable(var, name=new_name)

        if not dry_run:
            # Save the variables
            saver = tf.train.Saver()
            sess.run(tf.global_variables_initializer())
            saver.save(sess, checkpoint.model_checkpoint_path)
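
The command-line flags listed above could be wired to rename() roughly as follows. This is only a minimal sketch: the original script's argument parsing is not shown here, so build_parser and its defaults are assumptions, and only the flag names are taken from the usage lines.

```python
import argparse


def build_parser():
    """Hypothetical parser; only the flag names come from the usage lines."""
    parser = argparse.ArgumentParser(
        description="Rename variables in a TensorFlow checkpoint.")
    parser.add_argument("--checkpoint_dir", required=True,
                        help="directory that contains the checkpoint files")
    parser.add_argument("--replace_from", default=None,
                        help="substring to replace in variable names")
    parser.add_argument("--replace_to", default=None,
                        help="replacement for --replace_from")
    parser.add_argument("--add_prefix", default=None,
                        help="prefix added to every variable name")
    parser.add_argument("--dry_run", action="store_true",
                        help="only print the planned renames")
    return parser


# Parse a sample command line instead of sys.argv so the sketch runs as is.
args = build_parser().parse_args([
    "--checkpoint_dir=path/to/dir",
    "--replace_from=scope1",
    "--replace_to=scope1/model",
    "--add_prefix=abc/",
    "--dry_run",
])
# The parsed values would then be handed to the function above, e.g.:
# rename(args.checkpoint_dir, args.replace_from, args.replace_to,
#        args.add_prefix, args.dry_run)
print(args)
```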
For example:

python tensorflow_rename_variables.py --checkpoint_dir=path/to/dir --replace_from=scope1 --replace_to=scope1/model --add_prefix=abc/

renames the variable scope1/Variable1 to abc/scope1/model/Variable1.
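
The renaming rule itself needs no TensorFlow and can be checked in isolation. The sketch below mirrors the name-building steps of rename() above (substring replacement first, then the prefix). The helper name compute_new_name is hypothetical.

```python
def compute_new_name(var_name, replace_from=None, replace_to=None,
                     add_prefix=None):
    """Mirror the name-building steps of rename(): replace, then prefix."""
    new_name = var_name
    if replace_from is not None and replace_to is not None:
        # str.replace substitutes every occurrence, not only a leading scope.
        new_name = new_name.replace(replace_from, replace_to)
    if add_prefix:
        new_name = add_prefix + new_name
    return new_name


renamed = compute_new_name("scope1/Variable1", "scope1", "scope1/model", "abc/")
print(renamed)  # abc/scope1/model/Variable1
```

Because the replacement applies to every occurrence of the substring, a name that contains it more than once has all occurrences rewritten, so a dry run is worth doing first.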
Here is another simple script that renames variables and, in doing so, changes their scope names:

import tensorflow as tf

OLD_CHECKPOINT_FILE = "model.ckpt"
NEW_CHECKPOINT_FILE = "model_renamed.ckpt"

vars_to_rename = {
    "scope_1/var1": "scope_2/var1",
    "scope_1/var2": "scope_2/var2",
}
new_checkpoint_vars = {}
reader = tf.train.NewCheckpointReader(OLD_CHECKPOINT_FILE)

for old_name in reader.get_variable_to_shape_map():
    if old_name in vars_to_rename:
        new_name = vars_to_rename[old_name]
    else:
        new_name = old_name
    new_checkpoint_vars[new_name] = tf.Variable(reader.get_tensor(old_name))

init = tf.global_variables_initializer()
saver = tf.train.Saver(new_checkpoint_vars)

with tf.Session() as sess:
    sess.run(init)
    saver.save(sess, NEW_CHECKPOINT_FILE)    
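
Listing every variable by hand in vars_to_rename becomes tedious for larger models. As a sketch, the mapping could instead be generated from the names stored in the checkpoint (for example, the keys of reader.get_variable_to_shape_map()). The helper build_rename_map is hypothetical, and it assumes that scopes are separated by "/" and that only a leading scope should change.

```python
def build_rename_map(names, old_scope, new_scope):
    """Map each name directly under old_scope to the same name under new_scope."""
    old_prefix = old_scope.rstrip("/") + "/"
    new_prefix = new_scope.rstrip("/") + "/"
    mapping = {}
    for name in names:
        # Only a leading scope is rewritten; other names are left out of the
        # mapping, so the script above keeps them unchanged.
        if name.startswith(old_prefix):
            mapping[name] = new_prefix + name[len(old_prefix):]
    return mapping


# Sample names standing in for reader.get_variable_to_shape_map().keys()
names = ["scope_1/var1", "scope_1/var2", "global_step", "other/scope_1/var3"]
vars_to_rename = build_rename_map(names, "scope_1", "scope_2")
print(vars_to_rename)
```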

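This requires that you have already built the graph with the previous scope names, because restoring a checkpoint requires the graph to be defined. If you only have the checkpoint file, is it possible to replace the scope names in it?

I got this error from the script: ValueError: Couldn't find 'checkpoint' file or checkpoints in given directory…

@ryuzakinho, you need to specify the directory that contains the checkpoint files.

Actually, for some reason I had set write_state=False, so the checkpoint file was never created.

When the graph is too large, I ran into this error:
tensorflow.python.framework.errors_impl.InvalidArgumentError: Unable to serialize protocol buffer of type tensorflow.GraphDef as the serialized size (3363746871 bytes) would be larger than the limit (2147483647 bytes)

Indeed, this approach seems to inflate the size of the .meta file.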
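
A likely explanation, offered here as an assumption rather than something the thread confirms: in graph mode, tf.Variable(value) records the loaded array as a constant initializer in the graph, so the graph definition grows with the checkpoint. Under that assumption, a rough pre-check of the checkpoint's shape map against the 2147483647-byte limit from the error message can indicate whether the approach is likely to fail. The helper name and the default of 4 bytes per element (float32) are assumptions.

```python
from functools import reduce
from operator import mul

PROTOBUF_LIMIT_BYTES = 2147483647  # limit quoted in the error message


def estimate_value_bytes(shape_map, bytes_per_element=4):
    """Rough lower bound on the bytes needed to hold every variable's values."""
    total = 0
    for shape in shape_map.values():
        # An empty shape (a scalar) counts as one element.
        total += reduce(mul, shape, 1) * bytes_per_element
    return total


# Hypothetical shape map, in the format of reader.get_variable_to_shape_map()
shape_map = {"scope_1/W": [784, 10], "scope_1/b": [10]}
total = estimate_value_bytes(shape_map)
print(total, total < PROTOBUF_LIMIT_BYTES)  # 31400 True
```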