Warning: file_get_contents(/data/phpspider/zhask/data//catemap/2/python/323.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181

Warning: file_get_contents(/data/phpspider/zhask/data//catemap/2/tensorflow/5.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
Python TensorFlow已保存模型导出转换为tflite_Python_Tensorflow_Tensorflow Serving_Tensorflow Lite - Fatal编程技术网

Python TensorFlow已保存模型导出转换为tflite

Python TensorFlow已保存模型导出转换为tflite,python,tensorflow,tensorflow-serving,tensorflow-lite,Python,Tensorflow,Tensorflow Serving,Tensorflow Lite,TLDR: 我在运行时得到一个ValueError: tf.contrib.lite.TocoConverter.from_saved_model() 目标:我正在尝试将TensorFlow保存的模型转换为tflite,以便通过Firebase部署到移动设备上。我可以训练模型并输出一个保存的模型,但在使用python ToCo接口将其转换为.tflite时遇到了问题。任何帮助都将不胜感激。另外,如果有人可以评论tflite转换是否将捕获我所依赖的hub.text\u embedding\u c

TLDR: 我在运行时得到一个
ValueError:

tf.contrib.lite.TocoConverter.from_saved_model()
目标:我正在尝试将TensorFlow保存的模型转换为tflite,以便通过Firebase部署到移动设备上。我可以训练模型并输出一个保存的模型,但在使用python ToCo接口将其转换为
.tflite
时遇到了问题。任何帮助都将不胜感激。另外,如果有人可以评论tflite转换是否将捕获我所依赖的
hub.text\u embedding\u column()
输入进程。移动部署是使用原始输入文本执行此操作,还是需要单独部署该部分

问题:以下是我正在运行的代码:

投入:

train_input_fn = tf.estimator.inputs.pandas_input_fn(
    train_df, train_df["target_var"], num_epochs=None, shuffle=True
)

predict_train_input_fn = tf.estimator.inputs.pandas_input_fn(
    train_df, train_df["target_var"], shuffle=False
)

predict_test_input_fn = tf.estimator.inputs.pandas_input_fn(
    test_df, test_df["target_var"], shuffle=False)

embedded_text_feature_column = hub.text_embedding_column(
    key="text", 
    module_spec="https://tfhub.dev/google/nnlm-en-dim128/1"
)
培训和评估:

estimator = tf.estimator.DNNClassifier(
    hidden_units=[500, 100],
    feature_columns=[embedded_text_feature_column],
    n_classes=2,
    optimizer=tf.train.AdagradOptimizer(learning_rate=0.003),
    model_dir="my-model"
)

estimator.train(input_fn=train_input_fn, steps=1000)

train_eval_result = estimator.evaluate(input_fn=predict_train_input_fn)
test_eval_result = estimator.evaluate(input_fn=predict_test_input_fn)
保存模型:

feature_spec = tf.feature_column.make_parse_example_spec([embedded_text_feature_column])

serve_input_fun = tf.estimator.export.build_parsing_serving_input_receiver_fn(
    feature_spec,
    default_batch_size=None
)

estimator.export_savedmodel(
    export_dir_base = "my-model",
    serving_input_receiver_fn = serve_input_fun,
    as_text=False,
    checkpoint_path="my-model/model.ckpt-1000",
)
转换模型:

converter = tf.contrib.lite.TocoConverter.from_saved_model("my-model/1529320265/") 
tflite_model = converter.convert()
错误

运行最后一行时,出现以下错误:

ValueError:张量输入\u示例\u张量:0未知类型tf.string

完整的跟踪是:

ValueError回溯(最近一次呼叫上次)
in()
1 converter=tf.contrib.lite.TocoConverter.from_saved_model(“my model/1529320265/”)
---->2 tflite_model=converter.convert()

/转换(self)中的media/rmn/data/projects/anaconda3/envs/monly_tf19/lib/python3.6/site-packages/tensorflow/contrib/lite/python/lite.py
307在伪数量上重新排序=自我。在伪数量上重新排序,
308更改输入范围=自身。更改输入范围,
-->309允许自定义操作=自允许自定义操作)
310返回结果
311
/toco_convert中的media/rmn/data/projects/anaconda3/envs/monily_tf19/lib/python3.6/site-packages/tensorflow/contrib/lite/python/convert.py(输入数据、输入张量、输出张量、推理类型、推理输入类型、输入格式、输出格式、量化输入统计、默认输入范围统计、删除控制依赖、跨伪数量重新排序、允许自定义操作、更改输入范围)
204其他:
205 raise VALUERROR(“张量%s未知类型%r”%”(输入_tensor.name, -->206输入(张量.dtype))
207
208 input_array=model.input_array.add()

ValueError:张量输入\u示例\u张量:0未知类型tf.string

详细信息


train_df
test_df
是由一个输入文本列和一个二进制目标变量组成的数据帧。我使用的是Python 3.6.5和TensorFlow r1.9。

这个问题在TensorFlow的
master
分支上得到了解决(在提交中)。请参考TensorFlow网站上的以下文档以从GitHub构建pip安装:。

此问题已在TensorFlow的
主分支(提交中)上解决。请参考TensorFlow网站上的以下文档以从GitHub构建pip安装: