Warning: file_get_contents(/data/phpspider/zhask/data//catemap/2/tensorflow/5.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
如何在Java中使用TensorFlow LinearClassifier_Java_Tensorflow - Fatal编程技术网

如何在Java中使用TensorFlow LinearClassifier

如何在Java中使用TensorFlow LinearClassifier,java,tensorflow,Java,Tensorflow,在Python中,我训练了TensorFlow LinearClassifier,并将其保存为: model = tf.contrib.learn.LinearClassifier(feature_columns=columns) model.fit(input_fn=train_input_fn, steps=100) model.export_savedmodel(export_dir, parsing_serving_input_fn) 通过使用TensorFlow Java API,我

在Python中,我训练了TensorFlow LinearClassifier,并将其保存为:

model = tf.contrib.learn.LinearClassifier(feature_columns=columns)
model.fit(input_fn=train_input_fn, steps=100)
model.export_savedmodel(export_dir, parsing_serving_input_fn)
通过使用TensorFlow Java API,我能够使用以下方式在Java中加载此模型:

model = SavedModelBundle.load(export_dir, "serve");
看起来我应该能够使用类似于

model.session().runner().feed(???, ???).fetch(???, ???).run()

但是,我应该向图中输入/获取哪些变量名/数据来提供它的特性和获取类的概率呢?据我所知,Java文档缺少这些信息。

要馈送的节点的名称取决于
解析\u服务\u输入\u fn
所做的事情,特别是它们应该是
解析\u服务\u输入\u fn
返回的
张量
对象的名称。要获取的节点的名称将取决于您所预测的内容(如果使用Python中的模型,
model.predict()
的参数)

也就是说,TensorFlow保存的模型格式确实包含模型的“签名”(即,可以馈送或获取的所有张量的名称)作为可以提供提示的元数据

从Python中,您可以加载保存的模型,并使用以下方法列出其签名:

with tf.Session() as sess:
  md = tf.saved_model.loader.load(sess, ['serve'], export_dir)
  sig = md.signature_def[tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
  print(sig)
inputs {
  key: "inputs"
  value {
    name: "input_example_tensor:0"
    dtype: DT_STRING
    tensor_shape {
      dim {
        size: -1
      }
    }
  }
}
outputs {
  key: "scores"
  value {
    name: "linear/binary_logistic_head/predictions/probabilities:0"
    dtype: DT_FLOAT
    tensor_shape {
      dim {
        size: -1
      }
      dim {
        size: 2
      }
    }
  }
}
method_name: "tensorflow/serving/classify"
它将打印如下内容:

with tf.Session() as sess:
  md = tf.saved_model.loader.load(sess, ['serve'], export_dir)
  sig = md.signature_def[tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
  print(sig)
inputs {
  key: "inputs"
  value {
    name: "input_example_tensor:0"
    dtype: DT_STRING
    tensor_shape {
      dim {
        size: -1
      }
    }
  }
}
outputs {
  key: "scores"
  value {
    name: "linear/binary_logistic_head/predictions/probabilities:0"
    dtype: DT_FLOAT
    tensor_shape {
      dim {
        size: -1
      }
      dim {
        size: 2
      }
    }
  }
}
method_name: "tensorflow/serving/classify"
建议您在Java中要做的是:

Tensor t = /* Tensor object to be fed */
model.session().runner().feed("input_example_tensor", t).fetch("linear/binary_logistic_head/predictions/probabilities").run()
如果您的程序包含为TensorFlow协议缓冲区(打包在中)生成的Java代码,则您也可以在Java中提取这些信息,使用如下内容:

// Same as tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
// in Python. Perhaps this should be an exported constant in TensorFlow's Java API.
final String DEFAULT_SERVING_SIGNATURE_DEF_KEY = "serving_default"; 

final SignatureDef sig =
      MetaGraphDef.parseFrom(model.metaGraphDef())
          .getSignatureDefOrThrow(DEFAULT_SERVING_SIGNATURE_DEF_KEY);
您必须添加:

import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.SignatureDef;
由于JavaAPI和保存的模型格式有些新,因此文档中有很大的改进空间


希望有帮助。

要提供的节点的名称将取决于
解析\u服务\u输入\u fn
所做的工作,特别是它们应该是
解析\u服务\u输入\u fn
返回的
张量
对象的名称。要获取的节点的名称将取决于您所预测的内容(如果使用Python中的模型,
model.predict()
的参数)

也就是说,TensorFlow保存的模型格式确实包含模型的“签名”(即,可以馈送或获取的所有张量的名称)作为可以提供提示的元数据

从Python中,您可以加载保存的模型,并使用以下方法列出其签名:

with tf.Session() as sess:
  md = tf.saved_model.loader.load(sess, ['serve'], export_dir)
  sig = md.signature_def[tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
  print(sig)
inputs {
  key: "inputs"
  value {
    name: "input_example_tensor:0"
    dtype: DT_STRING
    tensor_shape {
      dim {
        size: -1
      }
    }
  }
}
outputs {
  key: "scores"
  value {
    name: "linear/binary_logistic_head/predictions/probabilities:0"
    dtype: DT_FLOAT
    tensor_shape {
      dim {
        size: -1
      }
      dim {
        size: 2
      }
    }
  }
}
method_name: "tensorflow/serving/classify"
它将打印如下内容:

with tf.Session() as sess:
  md = tf.saved_model.loader.load(sess, ['serve'], export_dir)
  sig = md.signature_def[tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
  print(sig)
inputs {
  key: "inputs"
  value {
    name: "input_example_tensor:0"
    dtype: DT_STRING
    tensor_shape {
      dim {
        size: -1
      }
    }
  }
}
outputs {
  key: "scores"
  value {
    name: "linear/binary_logistic_head/predictions/probabilities:0"
    dtype: DT_FLOAT
    tensor_shape {
      dim {
        size: -1
      }
      dim {
        size: 2
      }
    }
  }
}
method_name: "tensorflow/serving/classify"
建议您在Java中要做的是:

Tensor t = /* Tensor object to be fed */
model.session().runner().feed("input_example_tensor", t).fetch("linear/binary_logistic_head/predictions/probabilities").run()
如果您的程序包含为TensorFlow协议缓冲区(打包在中)生成的Java代码,则您也可以在Java中提取这些信息,使用如下内容:

// Same as tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
// in Python. Perhaps this should be an exported constant in TensorFlow's Java API.
final String DEFAULT_SERVING_SIGNATURE_DEF_KEY = "serving_default"; 

final SignatureDef sig =
      MetaGraphDef.parseFrom(model.metaGraphDef())
          .getSignatureDefOrThrow(DEFAULT_SERVING_SIGNATURE_DEF_KEY);
您必须添加:

import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.SignatureDef;
由于JavaAPI和保存的模型格式有些新,因此文档中有很大的改进空间


希望这能有所帮助。

谢谢你的回答!这看起来很有希望。但是,我必须为输入张量提供什么?例如,考虑导出:模型导出与您提供的签名相同的输入(输入,dType:dtyScript),但我需要以某种方式提供这个模型4个数字。现在我理解,模型需要一个序列化的示例协议缓冲区,但是在这个时候(1)协议缓冲区在java和(2)中不可用。还不支持使用数据类型字符串(序列化示例需要使用该字符串)创建张量:(仅供参考:Java中的maven artifact()数据类型中提供了协议缓冲区。标量(即单个字符串)支持字符串张量,但还不支持多维数组()希望这有帮助。再次感谢您的反馈。很高兴知道protos也可以在Java中使用。关于带字符串的张量:我需要输入字符串向量来输入\u示例\u张量,对吗?所以字符串标量目前没有帮助。或者我可以解决这个问题吗?谢谢您的回答!这看起来很有希望。但是,我该怎么办例如,导出:模型导出与您提供的签名相同的输入(输入,dType:dtyStand),但我需要以某种方式提供这个模型4个数字。现在我理解,模型需要一个序列化的示例协议缓冲区,但此时(1)Java中没有可用的协议缓冲区,(2)还不支持使用数据类型字符串创建张量(序列化示例需要此张量)。:(仅供参考:Java中的maven artifact()数据类型中有可用的协议缓冲区。标量(即单个字符串)支持字符串张量,但还不支持多维数组()希望这有帮助。再次感谢您的反馈。很高兴知道protos也可以在Java中使用。关于带字符串的张量:我需要输入字符串向量来输入\u示例\u张量,对吗?所以字符串标量目前没有帮助。或者我可以解决这个问题吗?