使用Tensorflow估计器API仅为验证集图像的子集创建图像摘要

使用Tensorflow估计器API仅为验证集图像的子集创建图像摘要,tensorflow,Tensorflow,我正在尝试添加图像摘要操作,以可视化我的网络如何从验证集重建输入。但是,由于验证集中的图像太多,我只想绘制其中的一小部分 我设法通过手动训练循环实现了这一点,但我很难通过新的Tensorflow估计器/实验/数据集API实现这一点。有人做过类似的事情吗?实验和估计器都是高级TensorFlow API。虽然您可能可以使用钩子解决问题,但如果您希望对培训过程中发生的事情有更多的控制,那么不使用这些API可能会更容易 也就是说,您仍然可以使用Dataset API,它将为您带来许多有用的特性 要解决

我正在尝试添加图像摘要操作,以可视化我的网络如何从验证集重建输入。但是,由于验证集中的图像太多,我只想绘制其中的一小部分


我设法通过手动训练循环实现了这一点,但我很难通过新的Tensorflow估计器/实验/数据集API实现这一点。有人做过类似的事情吗?

实验和估计器都是高级TensorFlow API。虽然您可能可以使用钩子解决问题,但如果您希望对培训过程中发生的事情有更多的控制,那么不使用这些API可能会更容易

也就是说,您仍然可以使用Dataset API,它将为您带来许多有用的特性

要解决Dataset API的问题,需要在训练循环中在训练数据集和验证数据集之间切换

一种方法是使用可反馈迭代器。有关更多详细信息,请参见此处:

您还可以在中看到一个完整的示例,该示例使用Dataset API在培训和验证之间切换

简言之,在创建了train_数据集和val_数据集之后,您的训练循环可以是这样的:

# create TensorFlow Iterator objects
training_iterator = val_dataset.make_initializable_iterator()
val_iterator = val_dataset.make_initializable_iterator()

with tf.Session() as sess:

  # Initialize variables
  init = tf.global_variables_initializer()
  sess.run(init)

  # Create training data and validation data handles
  training_handle = sess.run(training_iterator.string_handle())
  validation_handle = sess.run(val_iterator.string_handle())

  for epoch in range(number_of_epochs):

    # Tell iterator to go to beginning of dataset
    sess.run(training_iterator.initializer)

    print ("Starting epoch: ", epoch)

    # iterate over the training dataset and train
    while True:
        try:
            sess.run(train_op, feed_dict={handle: training_handle})
        except tf.errors.OutOfRangeError:
            # End of epoch
            break              

    # Tell validation iterator to go to beginning of dataset
    sess.run(val_iterator.initializer)

    # run validation on only 10 examples
    for i in range(10):
        my_value = sess.run(my_validation_op, feed_dict={handle: validation_handle}))
        # Do whatever you want with my_value
        ...

我想出了一个使用估计器/实验API的解决方案

首先,您需要修改数据集输入,以便不仅提供标签和功能,而且为每个示例提供某种形式的标识符(在我的示例中,它是一个文件名)。然后在hyperparameters字典(
params
参数)中,您需要指定要打印的验证样本。您还必须在这些参数中传递
model\u dir
。例如:

params = tf.contrib.training.HParams(
        model_dir=model_dir,
        images_to_plot=["100307_EMOTION.nii.gz", "100307_FACE-SHAPE.nii.gz",
                        "100307_GAMBLING.nii.gz", "100307_RELATIONAL.nii.gz",
                        "100307_SOCIAL.nii.gz"]
    )

learn_runner.run(
        experiment_fn=experiment_fn,
        run_config=run_config,
        schedule="train_and_evaluate",
        hparams=params
    )
设置此选项后,您可以在
模型中创建条件摘要操作,并创建一个求值挂钩,将它们包含在输出中

if mode == tf.contrib.learn.ModeKeys.EVAL:
    summaries = []
    for image_to_plot in params.images_to_plot:
        is_to_plot = tf.equal(tf.squeeze(filenames), image_to_plot)

        summary = tf.cond(is_to_plot,
                          lambda: tf.summary.image('predicted', predictions),
                          lambda: tf.summary.histogram("ignore_me", [0]),
                          name="%s_predicted" % image_to_plot)
        summaries.append(summary)

    evaluation_hooks = [tf.train.SummarySaverHook(
        save_steps=1,
        output_dir=os.path.join(params.model_dir, "eval"),
        summary_op=tf.summary.merge(summaries))]
else:
    evaluation_hooks = None
请注意,总结必须是有条件的-我们要么绘制图像(计算成本高),要么保存常量(计算成本低)。我选择在虚拟摘要中使用
直方图
而不是
标量
,以避免弄乱我的tensorboard仪表板

最后,您需要在“model_fn”的返回对象中传递钩子

return tf.estimator.EstimatorSpec(
    mode=mode,
    predictions=predictions,
    loss=loss,
    train_op=train_op,
    evaluation_hooks=evaluation_hooks
)
请注意,在评估模型时,这仅适用于批大小为1的情况(这应该不是问题)