Python: how to restore a model in TensorFlow

First, I trained a model with tf.contrib.gan as below, and the training runs fine:

tf.contrib.gan.gan_train(
        train_ops,
        hooks=(
                [tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps),
                 tf.train.LoggingTensorHook([status_message], every_n_iter=10)] +
                sync_hooks),
        logdir=FLAGS.train_log_dir,
        master=FLAGS.master,
        is_chief=FLAGS.task == 0,
        config=conf
    )
Then I wanted to evaluate the model. When I try to restore the checkpoint as follows:

with tf.name_scope('inputs'):
  real_images, one_hot_labels, _, num_classes = data_provider.provide_data(
    FLAGS.batch_size, FLAGS.dataset_dir)
  logits, end_points_des, feature, net = dcgan.discriminator(real_images)

  variables_to_restore = slim.get_model_variables()
  restorer = tf.train.Saver(variables_to_restore)

  with tf.Session() as sess:
          ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
          restorer.restore(sess, ckpt.model_checkpoint_path)
  def name_in_checkpoint(var):
      if "Discriminator/" in var.op.name:
         return var.op.name.replace("Discriminator/", "Discriminator/Discriminator/")

  logits, end_points_des, feature, net = dcgan.discriminator(real_images) 
  variables_to_restore = slim.get_model_variables()
  variables_to_restore = {name_in_checkpoint(var): var for var in variables_to_restore}
  restorer = tf.train.Saver(variables_to_restore)
I get an exception:

      2018-04-11 20:05:03.304089: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key Discriminator/fully_connected_layer2/weights not found in checkpoint
      2018-04-11 20:05:03.304280: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key Discriminator/conv0/BatchNorm/Discriminator/conv0/BatchNorm/moving_mean/local_step not found in checkpoint
      2018-04-11 20:05:03.304484: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key Discriminator/conv0/BatchNorm/beta not found in checkpoint
      2018-04-11 20:05:03.305197: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key Discriminator/fully_connected_layer2/biases not found in checkpoint

I am using TF 1.7rc1.
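Each of these warnings means the Saver looked up a graph variable's name, verbatim, as a key in the checkpoint and did not find it. Below is a minimal pure-Python sketch of that comparison; it needs no TensorFlow. `missing_keys` is a hypothetical helper. The checkpoint keys are illustrative and assume the doubled `Discriminator/Discriminator/` prefix that the answer below reports.

```python
def missing_keys(graph_names, checkpoint_keys):
    """Return the graph variable names that have no identical checkpoint key."""
    available = set(checkpoint_keys)
    return [name for name in graph_names if name not in available]


# Illustrative graph names, taken from the warnings above.
graph_names = [
    "Discriminator/conv0/BatchNorm/beta",
    "Discriminator/fully_connected_layer2/weights",
    "Discriminator/fully_connected_layer2/biases",
]
# Hypothetical checkpoint keys, assuming the first scope is nested twice.
checkpoint_keys = [
    "Discriminator/Discriminator/conv0/BatchNorm/beta",
    "Discriminator/Discriminator/fully_connected_layer2/weights",
    "Discriminator/Discriminator/fully_connected_layer2/biases",
]

# All three graph names are reported missing, matching the warnings.
print(missing_keys(graph_names, checkpoint_keys))
```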

Actually, the problem was in the generated graph. These are the steps I took to solve it:

Step 1: Print all variables stored in the checkpoint file with the following code:

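from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
file_name = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)  # path of the latest checkpoint
# tensor_name='' with all_tensors=False prints every key with its dtype and shape
print_tensors_in_checkpoint_file(file_name, tensor_name='', all_tensors=False)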
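Once the keys are printed, a quick way to spot a scope that was nested twice is to compare each key's first two path components. This is a small sketch that assumes you have copied the printed key names into a Python list. `has_duplicated_leading_scope` is a hypothetical helper, and the keys are illustrative.

```python
def has_duplicated_leading_scope(key):
    """Return True if the first scope of a key is repeated, e.g. "A/A/b"."""
    parts = key.split("/")
    return len(parts) >= 2 and parts[0] == parts[1]


# Illustrative keys; replace them with the names printed by the code above.
checkpoint_keys = [
    "Discriminator/Discriminator/conv0/BatchNorm/beta",
    "Discriminator/Discriminator/fully_connected_layer2/weights",
    "global_step",
]

for key in checkpoint_keys:
    print(key, has_duplicated_leading_scope(key))
```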
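Step 2: I then noticed that every key in the checkpoint repeats the first scope ("Discriminator"), so the keys are stored as Discriminator/Discriminator/.... That scope was set during training, but the graph I build when loading the model contains only a single Discriminator/ prefix. So I had to account for that extra part by mapping each variable name to its checkpoint key, as follows: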

with tf.name_scope('inputs'):
  real_images, one_hot_labels, _, num_classes = data_provider.provide_data(
      FLAGS.batch_size, FLAGS.dataset_dir)
  # Build the discriminator only once; a second call would try to create
  # the same variables again.
  logits, end_points_des, feature, net = dcgan.discriminator(real_images)

  def name_in_checkpoint(var):
    # The checkpoint stores "Discriminator/Discriminator/...", while this graph
    # names its variables "Discriminator/...": map graph name -> checkpoint key.
    if "Discriminator/" in var.op.name:
      return var.op.name.replace("Discriminator/", "Discriminator/Discriminator/")
    # Variables outside the scope keep their own name.
    return var.op.name

  variables_to_restore = slim.get_model_variables()
  # tf.train.Saver accepts a dict {checkpoint_key: variable}.
  variables_to_restore = {name_in_checkpoint(var): var
                          for var in variables_to_restore}
  restorer = tf.train.Saver(variables_to_restore)
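The renaming can be tried out on plain strings before building a Saver. This is a minimal sketch that assumes the `Discriminator` scope name from the question. `checkpoint_key_for` and `build_restore_map` are hypothetical helpers, and strings stand in for `tf.Variable` objects. Unlike `str.replace`, which rewrites every occurrence of `Discriminator/`, this version only prepends the scope once, to the leading prefix. It also fails early if any mapped key is still absent.

```python
def checkpoint_key_for(graph_name, scope="Discriminator"):
    """Map a graph variable name to the key it is assumed to have in the checkpoint.

    Assumes the checkpoint nests `scope` twice ("Discriminator/Discriminator/...")
    while the evaluation graph nests it once ("Discriminator/...").
    """
    prefix = scope + "/"
    if graph_name.startswith(prefix):
        return prefix + graph_name
    return graph_name  # names outside the scope are left unchanged


def build_restore_map(graph_names, checkpoint_keys, scope="Discriminator"):
    """Build {checkpoint_key: graph_name}, raising KeyError on unresolved keys."""
    mapping = {checkpoint_key_for(name, scope): name for name in graph_names}
    unresolved = sorted(set(mapping) - set(checkpoint_keys))
    if unresolved:
        raise KeyError("not found in checkpoint: %s" % unresolved)
    return mapping


graph_names = [
    "Discriminator/conv0/BatchNorm/beta",
    "Discriminator/fully_connected_layer2/weights",
]
# Illustrative checkpoint keys with the doubled scope.
checkpoint_keys = [
    "Discriminator/Discriminator/conv0/BatchNorm/beta",
    "Discriminator/Discriminator/fully_connected_layer2/weights",
]
print(build_restore_map(graph_names, checkpoint_keys))
```

In the TensorFlow code, the dict values would be the variables themselves, which is the dict form `tf.train.Saver` accepts for its `var_list`.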
Step 3: After that, I was able to load the model successfully as follows:

with tf.Session() as sess:
  ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
  restorer.restore(sess, ckpt.model_checkpoint_path)
Hope this helps anyone who runs into the same problem.