Java中向Tensorflow模型传递数据
我试图使用一个我在python中训练过的Tensorflow模型在Scala中对数据进行评分(使用TFJavaAPI)。对于模型,我使用了这个,唯一的变化是我从Java中向Tensorflow模型传递数据,java,tensorflow,Java,Tensorflow,我试图使用一个我在python中训练过的Tensorflow模型在Scala中对数据进行评分(使用TFJavaAPI)。对于模型,我使用了这个,唯一的变化是我从export\u savedmodel中删除了asText=True 我的Scala片段: val b = SavedModelBundle.load("/tensorflow/tf-estimator-tutorials/trained_models/reg-model-01/export/1531933435/", "serve"
export\u savedmodel
中删除了asText=True
我的Scala片段:
val b = SavedModelBundle.load("/tensorflow/tf-estimator-tutorials/trained_models/reg-model-01/export/1531933435/", "serve")
val s = b.session()
// output = predictor_fn({'csv_rows': ["0.5,1,ax01,bx02", "-0.5,-1,ax02,bx02"]})
val input = "0.5,1,ax01,bx02"
val inputTensor = Tensor.create(input.getBytes("UTF-8"))
val result = s.runner()
.feed("csv_rows", inputTensor)
.fetch("dnn/logits/BiasAdd")
.run()
.get(0)
运行时,出现以下错误:
Exception in thread "main" java.lang.IllegalArgumentException: Input to reshape is a tensor with 2 values, but the requested shape has 4
[[Node: dnn/input_from_feature_columns/input_layer/alpha_indicator/Reshape = Reshape[T=DT_FLOAT, Tshape=DT_INT32, _output_shapes=[[?,2]], _device="/job:localhost/replica:0/task:0/device:CPU:0"](dnn/input_from_feature_columns/input_layer/alpha_indicator/Sum, dnn/input_from_feature_columns/input_layer/alpha_indicator/Reshape/shape)]]
at org.tensorflow.Session.run(Native Method)
at org.tensorflow.Session.access$100(Session.java:48)
at org.tensorflow.Session$Runner.runHelper(Session.java:298)
at org.tensorflow.Session$Runner.run(Session.java:248)
我觉得如何准备输入张量存在问题,但我一直在思考如何最好地调试它 错误消息表明,某些操作中输入张量的形状与预期不同 查看您链接到的Python笔记本(特别是第8a和8c节),输入张量似乎应该是一批字符串张量,而不是单个字符串张量 您可以通过比较Scala和Python程序中张量的形状来观察这一点(Scala与Python笔记本中提供给
predict\fn
的csv\u行的形状)
由此看来,您希望的是inputTensor
成为字符串向量,而不是单个标量字符串。为此,您需要执行以下操作:
val input = Array("0.5,1,ax01,bx02")
val inputTensor = Tensor.create(input.map(x => x.getBytes("UTF-8"))
希望有帮助是的,谢谢!输入必须作为数组传递,输出返回,在本例中,作为数组[array[Float]]。