Java中向Tensorflow模型传递数据

Java中向Tensorflow模型传递数据,java,tensorflow,Java,Tensorflow,我试图使用一个我在python中训练过的Tensorflow模型在Scala中对数据进行评分(使用TFJavaAPI)。对于模型,我使用了这个,唯一的变化是我从export\u savedmodel中删除了asText=True 我的Scala片段: val b = SavedModelBundle.load("/tensorflow/tf-estimator-tutorials/trained_models/reg-model-01/export/1531933435/", "serve"

我试图使用一个我在python中训练过的Tensorflow模型在Scala中对数据进行评分(使用TFJavaAPI)。对于模型,我使用了这个,唯一的变化是我从
export\u savedmodel
中删除了
asText=True

我的Scala片段:

  val b = SavedModelBundle.load("/tensorflow/tf-estimator-tutorials/trained_models/reg-model-01/export/1531933435/", "serve")
  val s = b.session()

  // output = predictor_fn({'csv_rows': ["0.5,1,ax01,bx02", "-0.5,-1,ax02,bx02"]})
  val input = "0.5,1,ax01,bx02"

  val inputTensor = Tensor.create(input.getBytes("UTF-8"))

  val result = s.runner()
    .feed("csv_rows", inputTensor)
    .fetch("dnn/logits/BiasAdd")
    .run()
    .get(0)
运行时,出现以下错误:

Exception in thread "main" java.lang.IllegalArgumentException: Input to reshape is a tensor with 2 values, but the requested shape has 4
 [[Node: dnn/input_from_feature_columns/input_layer/alpha_indicator/Reshape = Reshape[T=DT_FLOAT, Tshape=DT_INT32, _output_shapes=[[?,2]], _device="/job:localhost/replica:0/task:0/device:CPU:0"](dnn/input_from_feature_columns/input_layer/alpha_indicator/Sum, dnn/input_from_feature_columns/input_layer/alpha_indicator/Reshape/shape)]]
at org.tensorflow.Session.run(Native Method)
at org.tensorflow.Session.access$100(Session.java:48)
at org.tensorflow.Session$Runner.runHelper(Session.java:298)
at org.tensorflow.Session$Runner.run(Session.java:248)

我觉得如何准备输入张量存在问题,但我一直在思考如何最好地调试它

错误消息表明,某些操作中输入张量的形状与预期不同

查看您链接到的Python笔记本(特别是第8a和8c节),输入张量似乎应该是一批字符串张量,而不是单个字符串张量

您可以通过比较Scala和Python程序中张量的形状来观察这一点(Scala与Python笔记本中提供给
predict\fn
csv\u行的形状)

由此看来,您希望的是
inputTensor
成为字符串向量,而不是单个标量字符串。为此,您需要执行以下操作:

val input = Array("0.5,1,ax01,bx02")
val inputTensor = Tensor.create(input.map(x => x.getBytes("UTF-8"))

希望有帮助

是的,谢谢!输入必须作为数组传递,输出返回,在本例中,作为数组[array[Float]]。