Java o=savedModel.metaGraphDef().getSignatureDefMap().get("serving_default")
Java o=savedModel.metaGraphDef().getSignatureDefMap().get("serving_default"),java,tensorflow,Java,Tensorflow,o=savedModel.metaGraphDef().getSignatureDefMap().get("serving_default")
o=savedModel.metaGraphDef().getSignatureDefMap().get("serving_default")
对不起，我刚刚意识到我用 Kotlin 写了原始答案 :)
try (SavedModelBundle model = SavedModelBundle.load("./house_price_median_income", "serve")) {
try (Session session = model.session()) {
Session.Runner runner = session.runner();
float[][] in = new float[][]{ {2.1518f} } ;
Tensor<?> jack = Tensor.create(in);
runner.feed("serving_default_layer1_input", jack);
float[][] probabilities = runner.fetch("StatefulPartitionedCall").run().get(0).copyTo(new float[1][1]);
for (int i = 0; i < probabilities.length; ++i) {
System.out.println(String.format("-- Input #%d", i));
for (int j = 0; j < probabilities[i].length; ++j) {
System.out.println(String.format("Class %d - %f", i, probabilities[i][j]));
}
}
}
}
// Walks the operations of the model's graph and prints each one to stdout.
// NOTE(review): `model` is not declared in this fragment — presumably the
// SavedModelBundle loaded elsewhere; confirm against the surrounding code.
Graph graph = model.graph();
Iterator<Operation> itr = graph.operations();
// NOTE(review): the closing brace of this loop is not present in this excerpt.
while (itr.hasNext()) {
// Each element is cast to GraphOperation before it is printed.
GraphOperation e = (GraphOperation)itr.next();
System.out.println(e);
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.exceptions.TensorFlowException;
import org.tensorflow.Session.Run;
import org.tensorflow.Graph;
import org.tensorflow.Operation;
import org.tensorflow.Output;
import org.tensorflow.GraphOperation;
import org.tensorflow.proto.framework.SignatureDef;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.tensorflow.proto.framework.MetaGraphDef;
import java.util.Map;
import org.tensorflow.proto.framework.TensorInfo;
import org.tensorflow.types.TFloat32;
import org.tensorflow.tools.Shape;
import java.nio.FloatBuffer;
import org.tensorflow.tools.buffer.DataBuffers;
import org.tensorflow.tools.ndarray.FloatNdArray;
import org.tensorflow.tools.ndarray.StdArrays;
import org.tensorflow.proto.framework.TensorInfo;
/**
 * Loads a TensorFlow SavedModel, looks up its {@code serving_default}
 * signature, and runs a single inference using the first input and the first
 * output tensor that the signature declares.
 *
 * <p>Every input and output entry of the signature is printed to standard
 * output, followed by each scalar of the fetched result.
 */
public class v2tensor {

    /** Directory of the exported SavedModel, relative to the working directory. */
    private static final String MODEL_DIR = "./house_price_median_income";

    /** Tag of the meta graph to load from the SavedModel. */
    private static final String MODEL_TAG = "serve";

    /** Key of the signature whose input and output tensors are used. */
    private static final String SIGNATURE_KEY = "serving_default";

    /**
     * Runs the model once on a fixed 1x1 float matrix and prints the result.
     *
     * @param args ignored
     * @throws IllegalStateException if the signature is missing, or if it
     *     declares no inputs or no outputs
     */
    public static void main(String[] args) {
        try (SavedModelBundle savedModel = SavedModelBundle.load(MODEL_DIR, MODEL_TAG)) {
            Map<String, SignatureDef> signatures =
                    savedModel.metaGraphDef().getSignatureDefMap();
            SignatureDef modelInfo = signatures.get(SIGNATURE_KEY);
            if (modelInfo == null) {
                // Fail with the list of known signatures instead of a bare NullPointerException.
                throw new IllegalStateException(
                        "Signature '" + SIGNATURE_KEY + "' not found in " + MODEL_DIR
                                + "; available signatures: " + signatures.keySet());
            }

            // Remember the first declared input while printing all of them.
            TensorInfo input1 = null;
            for (Map.Entry<String, TensorInfo> input : modelInfo.getInputsMap().entrySet()) {
                if (input1 == null) {
                    input1 = input.getValue();
                    System.out.println(input1.getName());
                }
                System.out.println(input);
            }
            if (input1 == null) {
                throw new IllegalStateException(
                        "Signature '" + SIGNATURE_KEY + "' declares no inputs");
            }

            // Remember the first declared output while printing all of them.
            TensorInfo output1 = null;
            for (Map.Entry<String, TensorInfo> output : modelInfo.getOutputsMap().entrySet()) {
                if (output1 == null) {
                    output1 = output.getValue();
                }
                System.out.println(output);
            }
            if (output1 == null) {
                throw new IllegalStateException(
                        "Signature '" + SIGNATURE_KEY + "' declares no outputs");
            }

            try (Session session = savedModel.session()) {
                Session.Runner runner = session.runner();
                FloatNdArray matrix = StdArrays.ndCopyOf(new float[][]{ { 2.1518f } });
                // Both tensors are opened in try-with-resources so they are closed after use.
                try (Tensor<TFloat32> jack = TFloat32.tensorOf(matrix)) {
                    runner.feed(input1.getName(), jack);
                    try (Tensor<TFloat32> rezz =
                                 runner.fetch(output1.getName()).run().get(0).expect(TFloat32.DTYPE)) {
                        TFloat32 data = rezz.data();
                        // Print every scalar of the result; the coordinates are not needed.
                        data.scalars().forEachIndexed((coords, scalar) ->
                                System.out.println(scalar.getFloat()));
                    }
                }
            }
        } catch (TensorFlowException ex) {
            ex.printStackTrace();
        }
    }
}