无法在Java API中运行Tensorflow预测

无法在Java API中运行Tensorflow预测,java,python,tensorflow,Java,Python,Tensorflow,我试图在一个模型上执行一个预测,我使用TensorFlow对AlexNet进行了微调 我在Python中使用tf.saved_model.builder.SavedModelBuilder保存了模型,并使用SavedModelBundle.load在Java中加载了模型。 守则的主要部分是: SavedModelBundle smb = SavedModelBundle.load(path, "serve"); Session s = smb.session(); by

我试图在一个模型上执行一个预测,我使用TensorFlow对AlexNet进行了微调

我在Python中使用tf.saved_model.builder.SavedModelBuilder保存了模型,并使用SavedModelBundle.load在Java中加载了模型。 守则的主要部分是:

    SavedModelBundle smb = SavedModelBundle.load(path, "serve");
    Session s = smb.session();
    byte[] imageBytes = readAllBytesOrExit(Paths.get(path));
    Tensor image = constructAndExecuteGraphToNormalizeImage(imageBytes);
    Tensor result = s.runner().feed("input_tensor", image).fetch("fc8/fc8").run().get(0);
    final long[] rshape = result.shape();
    if (result.numDimensions() != 2 || rshape[0] != 1) {
        throw new RuntimeException(
                String.format(
                        "Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s",
                        Arrays.toString(rshape)));
    }
    int nlabels = (int) rshape[1];
    float [] a =  result.copyTo(new float[1][nlabels])[0];`
我得到一个例外:

线程主java.lang.IllegalArgumentException中出现异常:必须为带有dtype float的占位符张量“占位符_1”提供一个值 [[Node:Placeholder_1=Placeholder_output_shapes=[[]],dtype=DT_FLOAT,shape=[],_device=/job:localhost/replica:0/task:0/cpu:0]]

我看到上面的代码对一些人有用,但我不知道这里缺少了什么。
请注意,网络熟悉节点输入\u张量和fc8/fc8,因为它没有说它不知道它们。

从错误消息中,您使用的模型似乎希望输入另一个值,该值在图中的节点名为占位符\u 1,预期类型为浮点标量张量

似乎您已经定制了您的模型,而不是按照链接到逐字记录的文章进行操作。这就是说,本文显示了需要输入的多个占位符,一个用于图像,另一个用于控制退出。在本条中定义为:

keep_prob = tf.placeholder(tf.float32)
需要输入这个占位符的值。如果您正在进行推理,则需要将keep_prob设置为1.0。比如:

Tensor keep_prob = Tensor.create(1.0f);
Tensor result = s.runner()
  .feed("input_tensor", image)
  .feed("Placeholder_1", keep_prob)
  .fetch("fc8/fc8")
  .run()
  .get(0);
希望有帮助