Python中的Tensorflow Java Api `toGraphDef`等价物是什么?
我正在使用TensorFlowJavaAPI将已经创建的Tensorflow模型加载到JVM中。 我以此为例: 下面是我的简单scala代码:Python中的Tensorflow Java Api `toGraphDef`等价物是什么?,java,scala,tensorflow,java-native-interface,tensorflow-serving,Java,Scala,Tensorflow,Java Native Interface,Tensorflow Serving,我正在使用TensorFlowJavaAPI将已经创建的Tensorflow模型加载到JVM中。 我以此为例: 下面是我的简单scala代码: import java.nio.file.{Files, Path, Paths} import org.tensorflow.{Graph, Session, Tensor} def readAllBytesOrExit(path: Path): Array[Byte] = Files.readAllBytes(path) val graphDef
import java.nio.file.{Files, Path, Paths}
import org.tensorflow.{Graph, Session, Tensor}
def readAllBytesOrExit(path: Path): Array[Byte] = Files.readAllBytes(path)
val graphDef = readAllBytesOrExit(Paths.get("PATH_TO_A_SINGLE_FILE_DESCRIBING_TF_MODEL.pb"))
val g = new Graph()
g.importGraphDef(graphDef)
val session = new Session(g)
val result: Tensor = session.runner().feed("input", image).fetch("output").run().get(0))
如何保存模型以将会话和图形存储在同一文件中。如上面的“路径到描述模型的单个文件”中所述
它提到:
图形的序列化表示,通常称为
GraphDef,可以由toGraphDef()和其他版本中的等效项生成
语言API
其他语言API中的等价物是什么?我不觉得这很明显
注意:我已经查看了tensorflow下的mnist_saved_model.py,但是通过该过程保存它会给我一个
.pb
文件和一个variables
文件夹。当尝试加载.pb
文件时,我得到:java.lang.IllegalArgumentException:Invalid GraphDef
当前使用tensorflow的java API,我只找到了如何将图形保存为GraphDef(即没有其变量和元数据)。只需将数组[Byte]写入文件即可:
Files.write(Paths.get(modelDir, modelName), myGraph.toGraphDef)
这里的myGraph
是来自的java对象
我建议使用这里定义的API从Python API保存您的模型。它会将您的模型保存在一个文件夹中,其中包含.pb文件中的序列化图形和文件夹中的变量。请注意您在scala/java代码中使用的tag_常量,以加载带有变量的模型。然后,使用java api中的java类轻松地加载带有变量的图形和会话。它返回一个包装器,其中包含包含变量值的图形和会话:
val model = SavedModelBundle.load(modelDir, modelTag)
如果您已经尝试过这个方法,也许可以共享您的代码,看看它为什么返回无效的GraphDef
另一个选项是冻结图形,即将变量节点转换为常量节点,使所有内容都包含在.pb文件中。关于冻结部分的更多信息我曾尝试使用,将图形加载到会话中是有意义的,但在运行会话时,变量不在那里。
SavedModelBundle
尚未发布,但我可以尝试重新编译并使用它。我将尝试冻结我的模型,看看是否有效。谢谢你说得对@DanielHasegan,你只需要从源代码构建JavaAPI依赖关系,我设法为冻结的模型提供服务,还需要将我的本机库重新编译到最新的1.1版本。谢谢你的帮助