如何使用TensorFlow java api删除预训练模型的输出层?

如何使用TensorFlow java api删除预训练模型的输出层?,java,python,tensorflow,deep-learning,keras,Java,Python,Tensorflow,Deep Learning,Keras,我有像Inception-v3这样的预先训练过的模型。我想移除输出层,并将其用于图像认知。以下是tensorflow给出的示例: 就像python框架Keras一样,它有一个类似于model.layers.pop()的方法。我试着用tensorflow java api来实现它。首先,我尝试使用dl4j,但当我导入keras模型时,我得到了如下错误: 2017-06-15 21:15:43 INFO KerasInceptionV3Net:52 - Importing Inception mo

我有像Inception-v3这样的预先训练过的模型。我想移除输出层,并将其用于图像认知。以下是tensorflow给出的示例:

就像python框架Keras一样,它有一个类似于
model.layers.pop()
的方法。我试着用tensorflow java api来实现它。首先,我尝试使用dl4j,但当我导入keras模型时,我得到了如下错误:

2017-06-15 21:15:43 INFO  KerasInceptionV3Net:52 - Importing Inception model from data/inception-model.json
2017-06-15 21:15:43 INFO  KerasInceptionV3Net:53 - Importing Weights model from data/inception_v3_complete
Exception in thread "main" java.lang.RuntimeException: Unknown exception.
at org.bytedeco.javacpp.hdf5$H5File.allocate(Native Method)
at org.bytedeco.javacpp.hdf5$H5File.<init>(hdf5.java:12713)
at org.deeplearning4j.nn.modelimport.keras.Hdf5Archive.<init>(Hdf5Archive.java:61)
at org.deeplearning4j.nn.modelimport.keras.KerasModel$ModelBuilder.weightsHdf5Filename(KerasModel.java:603)
at org.deeplearning4j.nn.modelimport.keras.KerasModelImport.importKerasModelAndWeights(KerasModelImport.java:176)
at edu.usc.irds.dl.dl4j.examples.KerasInceptionV3Net.<init>(KerasInceptionV3Net.java:55)
at edu.usc.irds.dl.dl4j.examples.KerasInceptionV3Net.main(KerasInceptionV3Net.java:108)
HDF5-DIAG: Error detected in HDF5 (1.10.0-patch1) thread 0:
#000: C:\autotest\HDF5110ReleaseRWDITAR\src\H5F.c line 579 in H5Fopen(): unable to open file
major: File accessibilty
minor: Unable to open file
#001: C:\autotest\HDF5110ReleaseRWDITAR\src\H5Fint.c line 1100 in H5F_open(): unable to open file: time = Thu Jun 15 21:15:44 2017,name = 'data/inception_v3_complete', tent_flags = 0
major: File accessibilty
minor: Unable to open file
#002: C:\autotest\HDF5110ReleaseRWDITAR\src\H5FD.c line 812 in H5FD_open(): open failed
major: Virtual File Layer
minor: Unable to initialize object
#003: C:\autotest\HDF5110ReleaseRWDITAR\src\H5FDsec2.c line 348 in H5FD_sec2_open(): unable to open file: name = 'data/inception_v3_complete', errno = 2, error message = 'No such file or directory', flags = 0, o_flags = 0
major: File accessibilty
minor: Unable to open file
我得到的模型是.pb文件,但当我把它放到张量示例中时,我得到了以下错误:

Exception in thread "main" java.lang.IllegalArgumentException: You must feed a value for placeholder tensor 'batch_normalization_1/keras_learning_phase' with dtype bool
 [[Node: batch_normalization_1/keras_learning_phase = Placeholder[dtype=DT_BOOL, shape=<unknown>, _device="/job:localhost/replica:0/task:0/cpu:0"]()]]
at org.tensorflow.Session.run(Native Method)
at org.tensorflow.Session.access$100(Session.java:48)
at org.tensorflow.Session$Runner.runHelper(Session.java:285)
at org.tensorflow.Session$Runner.run(Session.java:235)
at com.dlut.cmh.sheng.LabelImage.executeInceptionGraph(LabelImage.java:98)
at com.dlut.cmh.sheng.LabelImage.main(LabelImage.java:51)
线程“main”java.lang.IllegalArgumentException中的异常:必须为带有数据类型bool的占位符张量“批处理规范化阶段/keras学习阶段”提供一个值
[[Node:batch\u normalization\u 1/keras\u learning\u phase=占位符[dtype=DT\u BOOL,shape=,\u device=“/job:localhost/replica:0/task:0/cpu:0”]()]
在org.tensorflow.Session.run(本机方法)
访问org.tensorflow.Session.access$100(Session.java:48)
位于org.tensorflow.Session$Runner.runHelper(Session.java:285)
位于org.tensorflow.Session$Runner.run(Session.java:235)
位于com.dlut.cmh.sheng.LabelImage.executeInceptionGraph(LabelImage.java:98)
位于com.dlut.cmh.sheng.LabelImage.main(LabelImage.java:51)

我不知道怎么解决这个问题。有人能帮我吗?或者您有其他方法来执行此操作?

您从TensorFlow Java API获得的错误消息:


线程“main”java.lang.IllegalArgumentException中的异常:您必须使用dtype bool为占位符张量“批处理\规范化\ 1/keras\学习\阶段”提供一个值
[[Node:batch\u normalization\u 1/keras\u learning\u phase=占位符[dtype=DT\u BOOL,shape=,\u device=“/job:localhost/replica:0/task:0/cpu:0”]()]

建议该模型的构造方式要求您为名为
batch\u normalization\u 1/keras\u learning\u phase
的张量输入布尔值

因此,您必须通过更改以下内容将其包括在您的通话中:

try (Session s = new Session(g);
     Tensor result = s.runner().feed("input",image).fetch("output").run().get(0)) {
例如:

try (Session s = new Session(g);
     Tensor learning_phase = Tensor.create(false);
     Tensor result = s.runner().feed("input", image).feed("batch_normalization_1/keras_learning_phase", learning_phase).fetch("output").run().get(0)) {
馈送和获取的节点的名称取决于模型,因此“输入”和“输出”节点的名称也可能不同

您可能还想考虑使用(参见)


希望这有帮助

我不会将其作为答案添加,但会对您使用的deeplearning4j进行评论。您使用了什么版本的deeplearning4j?在最新版本中,我们修复了大量有关模型导入的问题。ApacheTika项目很好地使用了它。你能在一个问题上给我们一些反馈,而不仅仅是切换吗?谢谢如果你看看我们的例子,我们的迁移学习api可以很好地处理这个问题。@AdamGibson我在GITTER中问过这个问题,并提出了一个问题。我按照建议,下载了他提供的新演示。但是当我运行演示时,我遇到了一个例外。qq群里的人说很多演示都因为keras2模型而无法成功运行。但我确实在演示中使用了这个模型,我不知道keras1模型和keras2模型之间的区别。下载keras1.x并在那里试用,它应该可以正常运行。
try (Session s = new Session(g);
     Tensor learning_phase = Tensor.create(false);
     Tensor result = s.runner().feed("input", image).feed("batch_normalization_1/keras_learning_phase", learning_phase).fetch("output").run().get(0)) {