Python TF Serving of an exported Keras classification model gives: Expects arg[0] to be float but string is provided

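I have trained a classification model in Keras (latest versions of Keras and TF as of this writing) that is similar to CIFAR10 in terms of inputs and outputs. To serve this model, I export it as a classification model (see the type parameter) using the following code: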

# Assumed imports (standalone Keras on a TF 1.x backend)
import os
import shutil

import tensorflow as tf
from keras import backend as K
from keras.models import Sequential, load_model
from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import signature_constants, tag_constants, utils


def keras_model_to_tf_serve(saved_keras_model,
                            local_version_dir,
                            type='classification',
                            save_model_version=1):

    sess = tf.Session()
    K.set_session(sess)
    K.set_learning_phase(0)  # inference mode (e.g. dropout disabled)

    # Reload the trained model and rebuild it inside the new session
    old_model = load_model(saved_keras_model)
    config = old_model.get_config()
    weights = old_model.get_weights()

    new_model = Sequential.from_config(config)
    new_model.set_weights(weights)

    classification_inputs = utils.build_tensor_info(new_model.input)
    classification_outputs_classes = utils.build_tensor_info(new_model.output)

    # The classification signature
    classification_signature = tf.saved_model.signature_def_utils.build_signature_def(
        inputs={signature_constants.CLASSIFY_INPUTS: classification_inputs},
        outputs={
            signature_constants.CLASSIFY_OUTPUT_CLASSES:
                classification_outputs_classes
        },
        method_name=signature_constants.CLASSIFY_METHOD_NAME)
    #print(classification_signature)
    # The prediction signature
    tensor_info_x = utils.build_tensor_info(new_model.input)
    tensor_info_y = utils.build_tensor_info(new_model.output)
    prediction_signature = tf.saved_model.signature_def_utils.build_signature_def(
        inputs={'inputs': tensor_info_x},
        outputs={'outputs': tensor_info_y},
        method_name=signature_constants.PREDICT_METHOD_NAME)
    legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')

    print(prediction_signature)
    # Overwrite any existing export for this version
    save_model_dir = os.path.join(local_version_dir, str(save_model_version))
    if os.path.exists(save_model_dir) and os.path.isdir(save_model_dir):
        shutil.rmtree(save_model_dir)

    builder = saved_model_builder.SavedModelBuilder(save_model_dir)
    with K.get_session() as sess:
        if type == 'classification':
            builder.add_meta_graph_and_variables(
                sess, [tag_constants.SERVING],
                signature_def_map={
                    signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                        classification_signature,
                },
                clear_devices=True, legacy_init_op=legacy_init_op)
        elif type == 'prediction':
            builder.add_meta_graph_and_variables(
                sess, [tag_constants.SERVING],
                signature_def_map={
                    # Uncomment the first two lines below and comment out the subsequent four to reset.
                    # signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                    #    classification_signature,
                    'predict_results':
                        prediction_signature,
                },
                clear_devices=True, legacy_init_op=legacy_init_op)
        else:
            builder.add_meta_graph_and_variables(
                sess, [tag_constants.SERVING],
                signature_def_map={
                    # Uncomment the first two lines below and comment out the subsequent four to reset.
                    # signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                    #    classification_signature,
                    'predict_results':
                        prediction_signature,
                    signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                        classification_signature
                },
                clear_devices=True, legacy_init_op=legacy_init_op)
        builder.save()
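The three branches on type differ only in which SignatureDefs end up in signature_def_map. As a reading aid, here is a small stdlib-only sketch that tabulates them. exported_signatures is a hypothetical helper, not part of the export code, and the string values are assumed to mirror the TF 1.x constants named in the comments.

```python
# Hypothetical summary (not part of the original export code) of which
# SignatureDef keys keras_model_to_tf_serve registers for each `type`.
from typing import Dict

# Assumed to equal signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
DEFAULT_SERVING_SIGNATURE_DEF_KEY = "serving_default"
# Assumed to equal signature_constants.CLASSIFY_METHOD_NAME / PREDICT_METHOD_NAME
CLASSIFY_METHOD_NAME = "tensorflow/serving/classify"
PREDICT_METHOD_NAME = "tensorflow/serving/predict"


def exported_signatures(export_type: str) -> Dict[str, str]:
    """Map each exported signature key to the method name it is built with."""
    if export_type == "classification":
        return {DEFAULT_SERVING_SIGNATURE_DEF_KEY: CLASSIFY_METHOD_NAME}
    if export_type == "prediction":
        return {"predict_results": PREDICT_METHOD_NAME}
    # Any other value exports both signatures
    return {
        "predict_results": PREDICT_METHOD_NAME,
        DEFAULT_SERVING_SIGNATURE_DEF_KEY: CLASSIFY_METHOD_NAME,
    }


print(exported_signatures("classification"))
# {'serving_default': 'tensorflow/serving/classify'}
```

With the default type='classification', the only exported signature is serving_default, built with the classify method name; this matches the "Method name is: tensorflow/serving/classify" line that saved_model_cli reports for this export.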
The export runs without errors, and saved_model_cli shows the following:

saved_model_cli show --dir /develop/1/ --tag_set serve --signature_def serving_default
The given SavedModel SignatureDef contains the following input(s):
  inputs['inputs'] tensor_info:
     dtype: DT_FLOAT
     shape: (-1, 32, 32, 3)
     name: conv2d_1_input_1:0
The given SavedModel SignatureDef contains the following output(s):
  outputs['classes'] tensor_info:
     dtype: DT_FLOAT
     shape: (-1, 10)
     name: activation_6_1/Softmax:0
Method name is: tensorflow/serving/classify
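In this output, -1 marks a dimension of unknown size (here the batch dimension), so any number of 32×32×3 images fits the shape. A minimal stdlib sketch of that shape rule follows; shape_matches_signature is a hypothetical helper for illustration only. It checks shape alone: the dtype (DT_FLOAT) must match as well, and a dtype mismatch is what the error message further down reports.

```python
# Hypothetical helper (not part of TensorFlow or TF Serving): checks whether a
# concrete tensor shape fits a SignatureDef shape in which -1 marks a
# dimension of unknown size, such as the batch dimension.
from typing import Sequence


def shape_matches_signature(shape: Sequence[int], signature_shape: Sequence[int]) -> bool:
    """Return True if `shape` is compatible with `signature_shape`."""
    if len(shape) != len(signature_shape):
        return False  # the rank must agree
    return all(expected == -1 or actual == expected
               for actual, expected in zip(shape, signature_shape))


# Shape reported by saved_model_cli for inputs['inputs']
SIGNATURE_INPUT_SHAPE = (-1, 32, 32, 3)

print(shape_matches_signature((1, 32, 32, 3), SIGNATURE_INPUT_SHAPE))  # True: one image
print(shape_matches_signature((8, 32, 32, 3), SIGNATURE_INPUT_SHAPE))  # True: batch of 8
print(shape_matches_signature((32, 32, 3), SIGNATURE_INPUT_SHAPE))     # False: no batch dimension
print(shape_matches_signature((1, 3072), SIGNATURE_INPUT_SHAPE))       # False: flattened image
```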
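So the model expects a DT_FLOAT tensor of shape (-1, 32, 32, 3). Since this is a classification model (which, for some reason, is /very/ different in usage from a prediction model), I used @sdcbr's code () with some minor modifications: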

import tensorflow as tf
import numpy as np
from tensorflow_serving.apis import classification_pb2, input_pb2
from grpc.beta import implementations
from tensorflow_serving.apis import prediction_service_pb2
from tensorflow.python.saved_model import signature_constants  # for DEFAULT_SERVING_SIGNATURE_DEF_KEY

image = np.random.rand(32,32,3)

def _float_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))

request = classification_pb2.ClassificationRequest()
request.model_spec.name = 'model'
request.model_spec.signature_name = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY

image = image.flatten().tolist()
image = [float(x) for x in image]
example = tf.train.Example(features=tf.train.Features(feature={'image': _float_feature(image)}))

inp = input_pb2.Input()
inp.example_list.examples.extend([example])

request.input.CopyFrom(inp)

channel = implementations.insecure_channel('localhost', 5005)
stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)
response = stub.Classify(request, 10.0)
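The client flattens the 32×32×3 image into a single list of 3072 floats for the FloatList feature, so whatever consumes that feature has to reshape it back to match (-1, 32, 32, 3). The stdlib-only sketch below shows the row-major (C-order) layout that NumPy's flatten() uses by default, together with its inverse; flatten_image and unflatten_image are hypothetical helpers, not part of TensorFlow Serving.

```python
# Hypothetical helpers illustrating the row-major (C-order) layout that
# numpy's ndarray.flatten() uses by default: the channel axis varies fastest.
import random
from typing import List

HEIGHT, WIDTH, CHANNELS = 32, 32, 3  # CIFAR10-like input, as in the signature

Image = List[List[List[float]]]


def flatten_image(image: Image) -> List[float]:
    """Flatten an H x W x C nested list into one flat list, row-major."""
    return [float(value) for row in image for pixel in row for value in pixel]


def unflatten_image(flat: List[float], height: int = HEIGHT,
                    width: int = WIDTH, channels: int = CHANNELS) -> Image:
    """Inverse of flatten_image: rebuild the H x W x C nesting."""
    if len(flat) != height * width * channels:
        raise ValueError(f"expected {height * width * channels} values, got {len(flat)}")
    return [[flat[(h * width + w) * channels:(h * width + w + 1) * channels]
             for w in range(width)]
            for h in range(height)]


# Random image standing in for np.random.rand(32, 32, 3) in the client code
image = [[[random.random() for _ in range(CHANNELS)] for _ in range(WIDTH)]
         for _ in range(HEIGHT)]
flat = flatten_image(image)
print(len(flat))                       # 3072 values go into the FloatList feature
print(unflatten_image(flat) == image)  # True: the round trip restores the layout
```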
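Here TF Serving is running locally on my machine on port 5005 and was given the model_spec name at startup. As far as I can tell this should work, but when I run it I get the following error (shortened here for brevity):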

grpc._channel._Rendezvous: <_Rendezvous of RPC that terminated with:
   status = StatusCode.INVALID_ARGUMENT
   details = "Expects arg[0] to be float but string is provided"
   debug_error_string = "{"created":"@1533046733.211573219","description":"Error received from peer","file":"src/core/lib/surface/call.cc","file_line":1083,"grpc_message":"Expects arg[0] to be float but string is provided","grpc_status":3}"

Try using
prediction_service_pb2_grpc.PredictionServiceStub(channel)
instead of
prediction_service_pb2.beta_create_PredictionService_stub(channel)
. Apparently this was recently moved out of beta. You can refer to the example.

Interesting, let me give it a try.

You can see how it works by looking at the link I referred to and the source code. I've posted it as an answer.