Scikit learn Tensorflow Scikit Flow get GraphDef for Android(保存*.pb文件)

Scikit learn Tensorflow Scikit Flow get GraphDef for Android(保存*.pb文件),scikit-learn,tensorflow,tensorflow-serving,Scikit Learn,Tensorflow,Tensorflow Serving,我想在Android应用程序中使用我的Tensorflow算法。Tensorflow Android示例首先下载包含模型定义和权重的GraphDef(在*.pb文件中)。现在,这应该来自我的Scikit流算法(Tensorflow的一部分) 乍一看,您只需说classifier.save('model/')就很容易了,但保存到该文件夹中的文件不是*.ckpt、*.def,当然也不是*.pb。相反,您必须处理一个*.pbtxt和一个检查点(无结尾)文件 我在那里呆了很久了。下面是导出某些内容的代码

我想在Android应用程序中使用我的Tensorflow算法。Tensorflow Android示例首先下载包含模型定义和权重的GraphDef(在*.pb文件中)。现在,这应该来自我的Scikit流算法(Tensorflow的一部分)

乍一看,您只需说classifier.save('model/')就很容易了,但保存到该文件夹中的文件不是*.ckpt、*.def,当然也不是*.pb。相反,您必须处理一个*.pbtxt和一个检查点(无结尾)文件

我在那里呆了很久了。下面是导出某些内容的代码示例:

#imports
import tensorflow as tf
import tensorflow.contrib.learn as skflow
import tensorflow.contrib.learn.python.learn as learn
from sklearn import datasets, metrics

#skflow example
iris = datasets.load_iris()
feature_columns = learn.infer_real_valued_columns_from_input(iris.data)
classifier = learn.LinearClassifier(n_classes=3, feature_columns=feature_columns,model_dir="modeltest")
classifier.fit(iris.data, iris.target, steps=200, batch_size=32)
iris_predictions = list(classifier.predict(iris.data, as_iterable=True))
score = metrics.accuracy_score(iris.target, iris_predictions)
print("Accuracy: %f" % score)
您获得的文件是:

  • 检查站
  • graph.pbtxt
  • model.ckpt-1.meta
  • 型号:ckpt-1-00000-of-00001
  • model.ckpt-200.meta
  • 型号:ckpt-200-00000-of-00001

我发现许多可能的解决方法都需要在变量中包含GraphDef(不知道如何使用Scikit流)。或者使用Scikit Flow似乎不需要的Tensorflow会话。

要另存为pb文件,需要从构造的图形中提取图形定义。你也可以这样做--

如果要将经过训练的变量转换为常量(以避免使用ckpt文件加载权重),可以使用:

output_graph_def = graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), [final_tensor_name])

希望这有帮助

表示“未定义名称‘标志’”。如果我将目标路径改为*.pb文件,它似乎可以工作。我也不知道你所说的[网络前缀+最终张量名称]是什么意思。似乎我可以把输出占位符放在那里(因为字符串似乎需要一个名称)@CodingYourLife:yes。我编辑了更改。用张量(输出节点)的张量或名称替换最终的张量或名称,它将正确导出输出图。您找到解决方案了吗?我决定使用Scikit Flow进行实验(我的NN需要多少层等),然后用纯张量流重新创建模型。然后,我创建了第二个模型,将已经训练好的权重作为常量(切换到iOS,但可能与Android相同),从而避免了整个冻结图bazel。这不是一个真正的建议,只是我走的路
output_graph_def = graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), [final_tensor_name])