
C++ TensorFlow 0.12 model files


I train a model and save it using:

saver = tf.train.Saver()
saver.save(session, './my_model_name')
Besides the checkpoint file (which merely contains pointers to the model's most recent checkpoints), this creates the following three files in the current path:

  • my_model_name.meta
  • my_model_name.index
  • my_model_name.data-00000-of-00001

I would like to know what each of these files contains.

    I would also like to load this model into C++ and run inference. The example uses
    ReadBinaryProto()
    to load the model from a single .pb file. I would like to know how to load it from these three files instead. What is the C++ equivalent of the following?

    new_saver = tf.train.import_meta_graph('./my_model_name.meta')
    new_saver.restore(session, './my_model_name')
    

    I am currently struggling with this myself, and I have found that it is not straightforward to do right now. The two most frequently cited tutorials on this subject are: and

    The equivalent of

    new_saver = tf.train.import_meta_graph('./my_model_name.meta')
    new_saver.restore(session, './my_model_name')
    
    is just

    Status load_graph_status = LoadGraph(graph_path, &session);
    
    assuming that you "froze the graph" (used a script that combines the graph file with the checkpoint values).
    Also, see the discussion here:
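As a rough sketch of that freezing step (the paths and the output node name below are placeholders of my own, not from this post, and the `--input_meta_graph` flag needs a reasonably recent TF 1.x rather than 0.12), the stock `freeze_graph` tool can be invoked like this:

```shell
# Combine the GraphDef from the .meta file with the checkpoint values
# into a single .pb file that ReadBinaryProto() can load directly.
# All paths and the output node name are hypothetical examples.
python -m tensorflow.python.tools.freeze_graph \
  --input_meta_graph=./my_model_name.meta \
  --input_checkpoint=./my_model_name \
  --input_binary=true \
  --output_node_names=output \
  --output_graph=./frozen_model.pb
```

Note that `--input_checkpoint` takes the common prefix of the checkpoint files, not any one of the files themselves.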

    What your Saver creates is called a "Checkpoint V2", and was introduced in TF 0.12.

    I got this working nicely (although the documentation for the C++ part is awful, so it took me a day to figure it out). Some people suggest or , but those are not actually needed

    Python part (saving)

    If you create the
    Saver
    with
    tf.trainable_variables()
    , you can save yourself some hassle and storage space. However, some more complicated models may need all of the data saved; in that case, drop this argument to the
    Saver
    and just make sure to create the
    Saver
    after the graph has been created. It is also very wise to give all variables/layers unique names, otherwise you can run into various problems.
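The unique-naming advice can be checked mechanically. Below is a tiny helper of my own (not from the answer) that flags repeated names in a list; in a real session you would feed it something like `[v.name for v in tf.global_variables()]`:

```python
def find_duplicate_names(names):
    """Return the names that occur more than once, in first-repeat order."""
    seen = set()
    dups = []
    for name in names:
        if name in seen and name not in dups:
            dups.append(name)
        seen.add(name)
    return dups

# Hypothetical variable names; "dense/kernel:0" repeats.
print(find_duplicate_names(["dense/kernel:0", "dense/bias:0", "dense/kernel:0"]))
# prints ['dense/kernel:0']
```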

    C++ part (inference)

    Note that
    checkpointPath
    is not a path to any existing file, just their common prefix. If you mistakenly put the path to the
    .index
    file there, TF won't tell you that this is wrong, but it will die during inference due to uninitialized variables.

    #include <tensorflow/core/public/session.h>
    #include <tensorflow/core/protobuf/meta_graph.pb.h>
    
    using namespace std;
    using namespace tensorflow;
    
    ...
    // set up your input paths
    const string pathToGraph = "models/my-model.meta";
    const string checkpointPath = "models/my-model";
    ...
    
    auto session = NewSession(SessionOptions());
    if (session == nullptr) {
        throw runtime_error("Could not create Tensorflow session.");
    }
    
    Status status;
    
    // Read in the protobuf graph we exported
    MetaGraphDef graph_def;
    status = ReadBinaryProto(Env::Default(), pathToGraph, &graph_def);
    if (!status.ok()) {
        throw runtime_error("Error reading graph definition from " + pathToGraph + ": " + status.ToString());
    }
    
    // Add the graph to the session
    status = session->Create(graph_def.graph_def());
    if (!status.ok()) {
        throw runtime_error("Error creating graph: " + status.ToString());
    }
    
    // Read weights from the saved checkpoint
    Tensor checkpointPathTensor(DT_STRING, TensorShape());
    checkpointPathTensor.scalar<std::string>()() = checkpointPath;
    status = session->Run(
            {{ graph_def.saver_def().filename_tensor_name(), checkpointPathTensor },},
            {},
            {graph_def.saver_def().restore_op_name()},
            nullptr);
    if (!status.ok()) {
        throw runtime_error("Error loading checkpoint from " + checkpointPath + ": " + status.ToString());
    }
    
    // and run the inference to your liking
    auto feedDict = ...
    auto outputOps = ...
    std::vector<tensorflow::Tensor> outputTensors;
    status = session->Run(feedDict, outputOps, {}, &outputTensors);
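To guard against the .index-path mistake described above, a small helper (a sketch of my own; the suffix patterns follow the three file names listed in the question) can strip a known checkpoint suffix so you always pass the common prefix:

```python
import re

# Suffixes produced by a V2 checkpoint save: .meta, .index,
# and .data-XXXXX-of-YYYYY shards.
_CKPT_SUFFIX = re.compile(r"\.(meta|index|data-\d{5}-of-\d{5})$")

def checkpoint_prefix(path):
    """Strip a checkpoint file suffix so the result is the common prefix
    expected by the restore op (e.g. 'models/my-model')."""
    return _CKPT_SUFFIX.sub("", path)

print(checkpoint_prefix("models/my-model.index"))                # models/my-model
print(checkpoint_prefix("models/my-model.data-00000-of-00001"))  # models/my-model
print(checkpoint_prefix("models/my-model"))                      # models/my-model
```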
    

    Thank you, Ian. I also found this: Good find. It seems they do what we are doing, but only in Python instead of C++. I'm now looking into this: What prompted this question: Did you guys succeed in the end? I'm struggling with it too; I've tried many different approaches, but most of them fail to restore the variable values, and the others crash or give me a constant output in C++... I finally made it work. When I get a chance, I'll try to write you an answer, unless someone beats me to it. @Ianonway Sorry, I was faster :) I was just wondering whether you had converged to the same point =)
    For reference, the Python equivalent of the restore-and-run code above is:
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph('models/my-model.meta')
        saver.restore(sess, tf.train.latest_checkpoint('models/'))
        outputTensors = sess.run(outputOps, feed_dict=feedDict)