Python tf.train.Saver-在不同的机器上加载最新的检查点

Python tf.train.Saver-在不同的机器上加载最新的检查点,python,tensorflow,Python,Tensorflow,我有一个经过训练的模型,它是使用tf.train.Saver保存的,生成了4个相关文件 检查点 iter-315000型。数据-00000-of-00001 iter-315000型索引 model_iter-315000.meta 因为它是通过docker容器生成的,所以机器本身和docker上的路径是不同的,就像我们在两台不同的机器上工作一样 我正在尝试在容器外部加载保存的模型 当我运行以下命令时 sess = tf.Session() saver = tf.train.import_m

我有一个经过训练的模型,它是使用
tf.train.Saver
保存的,生成了4个相关文件

  • 检查点
  • iter-315000型。数据-00000-of-00001
  • iter-315000型索引
  • model_iter-315000.meta
因为它是通过docker容器生成的,所以机器本身和docker上的路径是不同的,就像我们在两台不同的机器上工作一样

我正在尝试在容器外部加载保存的模型

当我运行以下命令时

sess = tf.Session()
saver = tf.train.import_meta_graph('path_to_.meta_file_on_new_machine')  # Works
saver.restore(sess, tf.train.latest_checkpoint('path_to_ckpt_dir_on_new_machine')  # Fails
错误是

tensorflow.python.framework.errors\u impl.NotFoundError:旧机器上的路径;没有这样的文件或目录

尽管我在调用
tf.train.latest_checkpoint
时提供了新路径,但我还是得到了错误,它显示了旧路径上的路径


如何解决此问题?

如果打开
检查点
文件,您将看到如下内容:

model_checkpoint_path: "/PATH/ON/OLD/MACHINE/model.ckpt-315000"
all_model_checkpoint_paths: "/PATH/ON/OLD/MACHINE/model.ckpt-300000"
all_model_checkpoint_paths: "/PATH/ON/OLD/MACHINE/model.ckpt-285000"
[...]
只需删除
/PATH/ON/OLD/MACHINE/
,或将其替换为
/PATH/ON/NEW/MACHINE/
,就可以开始了

编辑: 将来,在创建
tf.train.Saver
时,应该使用
save\u relative\u paths
选项。引述:

保存相对路径:如果为True,则将相对路径写入 检查点状态文件。如果用户想要复制 检查点目录并从复制的目录重新加载

“检查点”文件是一个索引文件,它本身具有嵌入其中的路径。在文本编辑器中打开它,并将路径更改为正确的新路径

或者,使用加载特定的检查点,而不依赖TensorFlow为您查找最新的检查点。在这种情况下,它不会引用“checkpoint”文件,不同的路径也不会有问题


或者编写一个小脚本来修改“checkpoint”的内容。

这是一种不需要编辑检查点文件或手动查看检查点目录的方法。如果我们知道检查点前缀的名称,我们可以使用正则表达式,并假设tensorflow将最新的检查点写入
checkpoint
文件的第一行:

import tensorflow as tf
import os
import re


def latest_checkpoint(ckpt_dir, ckpt_prefix="model.ckpt", return_relative=True):
    if return_relative:
        with open(os.path.join(ckpt_dir, "checkpoint")) as f:
            text = f.readline()
        pattern = re.compile(re.escape(ckpt_prefix + "-") + r"[0-9]+")
        basename = pattern.findall(text)[0]
        return os.path.join(ckpt_dir, basename)
    else:
        return tf.train.latest_checkpoint(ckpt_dir)

必须有一种更通用的方法来完成,这必须通过代码来完成好的,在我的回答中添加了对tf.train.load_checkpoint()的引用必须有一种更通用的方法来完成,这必须通过代码来完成。好的,你可以使用
sed
,可能类似于
sed-i“s”/PATH/to/OLD/MACHINE/??g“checkpoint
,或者在python中使用正则表达式执行。