C++ TensorFlow 0.12模型文件
我训练模型并使用以下方法保存:C++ TensorFlow 0.12模型文件,c++,tensorflow,artificial-intelligence,deep-learning,tensorflow-serving,C++,Tensorflow,Artificial Intelligence,Deep Learning,Tensorflow Serving,我训练模型并使用以下方法保存: saver = tf.train.Saver() saver.save(session, './my_model_name') 除了检查点文件(仅包含指向模型最新检查点的指针)之外,这将在当前路径中创建以下3个文件: my_model_name.meta my_model_name.index my_model_name.data-00000-of-00001 我想知道每个文件都包含什么 我想把这个模型加载到C++中并运行推理。该示例使用ReadBinaryPr
saver = tf.train.Saver()
saver.save(session, './my_model_name')
除了检查点文件(仅包含指向模型最新检查点的指针)之外,这将在当前路径中创建以下3个文件:
ReadBinaryProto()
从单个.bp文件加载模型。我想知道如何从这3个文件中加载它。下面的C++等价物是什么?
new_saver = tf.train.import_meta_graph('./my_model_name.meta')
new_saver.restore(session, './my_model_name')
我现在自己也在努力解决这个问题,我发现现在做起来并不简单。关于这一主题,最常被引用的两个教程是: 和 相当于
new_saver = tf.train.import_meta_graph('./my_model_name.meta')
new_saver.restore(session, './my_model_name')
只是
Status load_graph_status = LoadGraph(graph_path, &session);
假设您“冻结了图形”(使用脚本将图形文件与检查点值组合)。
另外,请参见此处的讨论:您的保护程序创建的内容称为“Checkpoint V2”,并在TF 0.12中引入 <>我把它做得很好(虽然C++部分的文档很糟糕,所以我花了一天时间来解决)。有些人建议或,但实际上并不需要这些 Python部分(保存) 如果您使用
tf.trainable_variables()
创建Saver
,您可以为自己节省一些麻烦和存储空间。但是,可能一些更复杂的模型需要保存所有数据,然后将此参数删除到保存程序
,只需确保在创建图形后创建保存程序。为所有变量/层指定唯一的名称也是非常明智的,否则您可能会遇到不同的问题
C++部分(推理)
请注意,checkpointPath
不是任何现有文件的路径,只是它们的公共前缀。如果您错误地将.index
文件的路径放在那里,TF不会告诉您这是错误的,但由于未初始化变量,它会在推断过程中死亡
#include <tensorflow/core/public/session.h>
#include <tensorflow/core/protobuf/meta_graph.pb.h>
using namespace std;
using namespace tensorflow;
...
// set up your input paths
const string pathToGraph = "models/my-model.meta"
const string checkpointPath = "models/my-model";
...
auto session = NewSession(SessionOptions());
if (session == nullptr) {
throw runtime_error("Could not create Tensorflow session.");
}
Status status;
// Read in the protobuf graph we exported
MetaGraphDef graph_def;
status = ReadBinaryProto(Env::Default(), pathToGraph, &graph_def);
if (!status.ok()) {
throw runtime_error("Error reading graph definition from " + pathToGraph + ": " + status.ToString());
}
// Add the graph to the session
status = session->Create(graph_def.graph_def());
if (!status.ok()) {
throw runtime_error("Error creating graph: " + status.ToString());
}
// Read weights from the saved checkpoint
Tensor checkpointPathTensor(DT_STRING, TensorShape());
checkpointPathTensor.scalar<std::string>()() = checkpointPath;
status = session->Run(
{{ graph_def.saver_def().filename_tensor_name(), checkpointPathTensor },},
{},
{graph_def.saver_def().restore_op_name()},
nullptr);
if (!status.ok()) {
throw runtime_error("Error loading checkpoint from " + checkpointPath + ": " + status.ToString());
}
// and run the inference to your liking
auto feedDict = ...
auto outputOps = ...
std::vector<tensorflow::Tensor> outputTensors;
status = session->Run(feedDict, outputOps, {}, &outputTensors);
谢谢你,伊恩。我也发现了这一点:好发现。似乎他们做了我们正在做的事情,但是只用Python而不是C++。我现在正在研究这个问题:是什么引发了这个问题:你们最终成功了吗?我也在努力,我尝试了很多不同的方法,但是大多数都无法保存变量值,其他的崩溃或者给我一个常数输出在C++中…我终于做到了。当我有机会的时候,我会试着给你一个答案,除非有人比我快。@Ianonway抱歉,我更快:)我只是想知道你是否收敛到了同一点=)
#include <tensorflow/core/public/session.h>
#include <tensorflow/core/protobuf/meta_graph.pb.h>
using namespace std;
using namespace tensorflow;
...
// set up your input paths
const string pathToGraph = "models/my-model.meta"
const string checkpointPath = "models/my-model";
...
auto session = NewSession(SessionOptions());
if (session == nullptr) {
throw runtime_error("Could not create Tensorflow session.");
}
Status status;
// Read in the protobuf graph we exported
MetaGraphDef graph_def;
status = ReadBinaryProto(Env::Default(), pathToGraph, &graph_def);
if (!status.ok()) {
throw runtime_error("Error reading graph definition from " + pathToGraph + ": " + status.ToString());
}
// Add the graph to the session
status = session->Create(graph_def.graph_def());
if (!status.ok()) {
throw runtime_error("Error creating graph: " + status.ToString());
}
// Read weights from the saved checkpoint
Tensor checkpointPathTensor(DT_STRING, TensorShape());
checkpointPathTensor.scalar<std::string>()() = checkpointPath;
status = session->Run(
{{ graph_def.saver_def().filename_tensor_name(), checkpointPathTensor },},
{},
{graph_def.saver_def().restore_op_name()},
nullptr);
if (!status.ok()) {
throw runtime_error("Error loading checkpoint from " + checkpointPath + ": " + status.ToString());
}
// and run the inference to your liking
auto feedDict = ...
auto outputOps = ...
std::vector<tensorflow::Tensor> outputTensors;
status = session->Run(feedDict, outputOps, {}, &outputTensors);
with tf.Session() as sess:
saver = tf.train.import_meta_graph('models/my-model.meta')
saver.restore(sess, tf.train.latest_checkpoint('models/'))
outputTensors = sess.run(outputOps, feed_dict=feedDict)