Python Keras load_model()仅在Flask请求上下文中失败

Python Keras load_model()仅在Flask请求上下文中失败,python,flask,keras,Python,Flask,Keras,我有一个正在工作的Keras模型,它可以在repl中做出很好的预测,但无法在Flask应用程序中加载。这是一个Keras bug还是我缺少一些基本的Python变量范围理解 烧瓶应用程序: # app.py import os import sys from train import train_model from predict import predict_image import requests from flask import Flask, request, jsonify s

我有一个正在工作的Keras模型,它可以在repl中做出很好的预测,但无法在Flask应用程序中加载。这是一个Keras bug还是我缺少一些基本的Python变量范围理解

烧瓶应用程序:

# app.py

import os
import sys
from train import train_model
from predict import predict_image
import requests
from flask import Flask, request, jsonify

sys.dont_write_bytecode = True

app = Flask(__name__)

@app.route('/')
def hello():
    return jsonify({'status': 'Service available.'}), 200

@app.route('/predict') # ?company=<company_id>&image_url=<image_url>
def predict_route():
    company_id = request.args.get('company')
    image_url = str(request.args.get('image_url'))
    result = predict_image(company_id, url=image_url)
    return jsonify(result), 200
它通过repl工作:

import predict
predict.predict_image(...)
# model loads and returns expected result
但如果我通过烧瓶应用程序尝试,我会

# curl ml:5000/predict?company=1&image_url=<image_url>

[top of traceback omitted for brevity]
  File "/code/app.py", line 22, in predict_route
    result = predict_image(company_id, url=image_url)
  File "/code/predict.py", line 34, in predict_image
    model_graph, class_names = load_classification_model(company_id)
  File "/code/predict.py", line 21, in load_classification_model
    model = load_model(model_path)
  File "/usr/local/lib/python2.7/site-packages/keras/models.py", line 242, in load_model
    topology.load_weights_from_hdf5_group(f['model_weights'], model.layers)
  File "/usr/local/lib/python2.7/site-packages/keras/engine/topology.py", line 3095, in load_weights_from_hdf5_group
    K.batch_set_value(weight_value_tuples)
  File "/usr/local/lib/python2.7/site-packages/keras/backend/tensorflow_backend.py", line 2193, in batch_set_value
    get_session().run(assign_ops, feed_dict=feed_dict)
  File "/usr/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 895, in run
    run_metadata_ptr)
  File "/usr/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1071, in _run
    + e.args[0])
TypeError: Cannot interpret feed_dict key as Tensor: Tensor Tensor("Placeholder:0", shape=(2048, 64), dtype=float32) is not an element of this graph.
#curl ml:5000/predict?company=1&image\u url=
[为简洁起见,省略了回溯的顶部]
文件“/code/app.py”,第22行,路径中
结果=预测图片(公司id,url=图片url)
文件“/code/predict.py”,第34行,在predict\u图像中
模型图,类别名称=负载分类模型(公司id)
文件“/code/predict.py”,第21行,在负荷分类模型中
模型=加载模型(模型路径)
文件“/usr/local/lib/python2.7/site packages/keras/models.py”,第242行,在load_模型中
拓扑。从组(f['model\u weights'],model.layers)加载权重
文件“/usr/local/lib/python2.7/site-packages/keras/engine/topology.py”,第3095行,从hdf5组加载权重
K.批量设置值(权重值元组)
文件“/usr/local/lib/python2.7/site packages/keras/backend/tensorflow\u backend.py”,第2193行,在批处理设置值中
获取会话()
文件“/usr/local/lib/python2.7/site packages/tensorflow/python/client/session.py”,第895行,正在运行
运行_元数据_ptr)
文件“/usr/local/lib/python2.7/site packages/tensorflow/python/client/session.py”,第1071行,在运行中
+e.args[0])
TypeError:无法将提要索引键解释为张量:张量张量(“占位符:0”,shape=(2048,64),dtype=float32)不是此图形的元素。

使用tensorflow图交叉线程时,Keras中似乎存在一个缺陷。要解决此问题,请执行以下操作:

    # Right after loading or constructing your model, save the TensorFlow graph:
    import tensorflow as tf
    graph = tf.get_default_graph()

    # In the other thread (or perhaps in an asynchronous event handler), do:
    global graph
    with graph.as_default():
        (...
        do
        inference
        here...)

尝试在烧瓶中设置
debug=False

在多次失败的tensorflow保存/加载尝试后为我工作

(感谢shafy@github)

对我来说,在我的烧瓶应用程序的底部,它看起来是这样的:

if __name__ == '__main__':
    app.run(debug=False)
也看到

在REPL中,您将什么传递给
predict.predict_image(…)
呢?与Flask传递的参数相同,我已经打印了它们以供检查。错误回溯也表明了不同类型的问题,所以我认为这不太可能。感谢您的意见。在Flask中,我甚至无法加载图形,因此在加载后无法从tf获取默认图形。这似乎是合理的,可能是在正确的轨道上,但我觉得这更像是Keras从其保存的模型格式加载模型,在Flask中发生的情况与REPL不同。我认为在初始化Flask
app
时最好加载Keras模型,而不是在调用API时一次又一次地加载模型,这样会非常慢。如果以这种方式加载模型,则可以获得图形,并且在调用API时,只需使用初始化的模型来预测并用
和graph将其包装。默认值为():
if __name__ == '__main__':
    app.run(debug=False)