Warning: file_get_contents(/data/phpspider/zhask/data//catemap/9/java/389.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
运行用python编写的tensorflow模型,用于从java进行培训和预测_Java_Python_Tensorflow_Deep Learning_Image Recognition - Fatal编程技术网

运行用python编写的tensorflow模型,用于从java进行培训和预测

运行用python编写的tensorflow模型,用于从java进行培训和预测,java,python,tensorflow,deep-learning,image-recognition,Java,Python,Tensorflow,Deep Learning,Image Recognition,我已经为自己的数据集重新训练了初始模型。这个模型是用python构建的,我现在将图形保存为.pb文件,标签文件保存为.txt。现在,我需要通过java对图像使用此模型进行预测。谁能帮我一下吗?TensorFlow团队正在开发Java接口,但它还不稳定。您可以在此处找到现有代码:并在此处跟踪其开发的更新。您可以看一看,看看它目前是如何使用的(尽管如前所述,这在将来可能会改变)。基本上,您需要将二进制保存的图形加载到graph对象中,用它创建一个Session,并使用适当的值(如Tensors)运行

我已经为自己的数据集重新训练了初始模型。这个模型是用python构建的,我现在将图形保存为.pb文件,标签文件保存为.txt。现在,我需要通过java对图像使用此模型进行预测。谁能帮我一下吗?

TensorFlow团队正在开发Java接口,但它还不稳定。您可以在此处找到现有代码:并在此处跟踪其开发的更新。您可以看一看,看看它目前是如何使用的(尽管如前所述,这在将来可能会改变)。基本上,您需要将二进制保存的图形加载到
graph
对象中,用它创建一个
Session
,并使用适当的值(如
Tensor
s)运行它,以接收带有输出的
列表。将源代码中的示例组合在一起:

import java.nio.file.Files;
import java.nio.file.Paths;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

try (Graph graph = new Graph()) {
    graph.importGraphDef(Files.readAllBytes(Paths.get("saved_model.pb"));
    try (Session sess = new Session(graph)) {
        try (Tensor x = Tensor.create(1.0f);
             Tensor y = s.runner().feed("x", x).fetch("y").run().get(0)) {
            System.out.println(y.floatValue());
        }
    }
}

我使用的代码读取一个
protobuf
文件,以
.pb
结尾

try (SavedModelBundle b = SavedModelBundle.load("/tmp/model", "serve")) {
    Session sess = b.session();
    ...
    float[][]matrix = sess.runner()
        .feed("x", input)
        .feed("keep_prob", keep_prob)
        .fetch("y_conv")
        .run()
        .get(0)
        .copyTo(new float[1][10]);
    ...
}
我用来保存它的python代码是:

  signature = tf.saved_model.signature_def_utils.build_signature_def(
    inputs = {'x': tf.saved_model.utils.build_tensor_info(x)},
    outputs = {'y_conv': tf.saved_model.utils.build_tensor_info(y_conv)},
  )
  builder = tf.saved_model.builder.SavedModelBuilder("/tmp/model" )
  builder.add_meta_graph_and_variables(sess, 
       [tf.saved_model.tag_constants.SERVING],
       signature_def_map={tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature}
   )
  builder.save()