Tensorflow 使用convert_variables_to_constants保存tf.trainable_variables()

Tensorflow 使用convert_variables_to_constants保存tf.trainable_variables(),tensorflow,keras,Tensorflow,Keras,我有一个Keras模型,我想将其转换为Tensorflow协议(例如保存的\u model.pb) 该模型来自vgg-19网络上的转移学习,其中,头部被切断,并使用完全连接的+softmax层进行训练,而vgg-19网络的其余部分被冻结 我可以在Keras中加载模型,然后使用Keras.backend.get_session()在tensorflow中运行模型,生成正确的预测: frame = preprocess(cv2.imread("path/to/img.jpg") keras_mode

我有一个Keras模型,我想将其转换为Tensorflow协议(例如
保存的\u model.pb

该模型来自vgg-19网络上的转移学习,其中,头部被切断,并使用完全连接的+softmax层进行训练,而vgg-19网络的其余部分被冻结

我可以在Keras中加载模型,然后使用
Keras.backend.get_session()
在tensorflow中运行模型,生成正确的预测:

frame = preprocess(cv2.imread("path/to/img.jpg")
keras_model = keras.models.load_model("path/to/keras/model.h5")

keras_prediction = keras_model.predict(frame)

print(keras_prediction)

with keras.backend.get_session() as sess:

    tvars = tf.trainable_variables()

    output = sess.graph.get_tensor_by_name('Softmax:0')
    input_tensor = sess.graph.get_tensor_by_name('input_1:0')

    tf_prediction = sess.run(output, {input_tensor: frame})
    print(tf_prediction) # this matches keras_prediction exactly
如果我没有包括行
tvars=tf.trainable_variables()
,那么
tf_prediction
变量是完全错误的,并且与
keras_prediction
的输出根本不匹配。事实上,输出中的所有值(具有4个概率值的单个数组)都是完全相同的(~0.25,全部加1)。这使我怀疑,如果
tf,头部的权重刚刚初始化为0。首先没有调用trainable_variables()
,这是在检查模型变量后确认的。在任何情况下,调用
tf.trainable_variables()
都会导致tensorflow预测正确

问题是,当我尝试保存此模型时,来自
tf.trainable_variables()
的变量实际上没有保存到
.pb
文件:

with keras.backend.get_session() as sess:
    tvars = tf.trainable_variables()

    constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), ['Softmax'])
    graph_io.write_graph(constant_graph, './', 'saved_model.pb', as_text=False)
我想问的是,如何将Keras模型保存为完整的
tf.training_variables()
的Tensorflow协议


非常感谢

因此,冻结图形中的变量(转换为常量)的方法应该是可行的,但不是必需的,而且比其他方法更复杂。(更多信息请参见下文)。如果您出于某种原因(例如导出到移动设备)希望图形冻结,我需要更多详细信息来帮助调试,因为我不确定Keras在幕后对您的图形做了什么。但是,如果您想稍后保存并加载一个图形,我可以解释如何进行(尽管不能保证Keras所做的任何事情都不会把它搞砸…,我很乐意帮助调试)

所以这里实际上有两种格式。一种是用于检查点的
GraphDef
,因为它不包含关于输入和输出的元数据。另一个是包含元数据和图形定义的
MetaGraphDef
,元数据可用于预测和运行
ModelServer
(来自tensorflow/serving)

在这两种情况下,您需要做的不仅仅是调用
graph\u io.write\u graph
,因为变量通常存储在graphdef之外

这两种用例都有包装器库。主要用于保存和恢复检查点

但是,由于您需要预测,我建议使用来构建SavedModel二进制文件。我在下面提供了一些锅炉板:

from tensorflow.python.saved_model.signature_constants import DEFAULT_SERVING_SIGNATURE_DEF_KEY as DEFAULT_SIG_DEF
builder = tf.saved_model.builder.SavedModelBuilder('./mymodel')
with keras.backend.get_session() as sess:
  output = sess.graph.get_tensor_by_name('Softmax:0')
  input_tensor = sess.graph.get_tensor_by_name('input_1:0')
  sig_def = tf.saved_model.signature_def_utils.predict_signature_def(
    {'input': input_tensor},
    {'output': output}
  )
  builder.add_meta_graph_and_variables(
      sess, tf.saved_model.tag_constants.SERVING,
      signature_def_map={
        DEFAULT_SIG_DEF: sig_def
      }
  )
builder.save()
运行此代码后,您应该有一个
mymodel/saved_model.pb
文件以及一个目录
mymodel/variables/
,其中protobufs对应于变量值

然后,要再次加载模型,只需使用
tf.saved\u model.loader

# Does Keras give you the ability to start with a fresh graph?
# If not you'll need to do this in a separate program to avoid
# conflicts with the old default graph
with tf.Session(graph=tf.Graph()):
  meta_graph_def = tf.saved_model.loader.load(
      sess, 
      tf.saved_model.tag_constants.SERVING,
      './mymodel'
  )
  # From this point variables and graph structure are restored

  sig_def = meta_graph_def.signature_def[DEFAULT_SIG_DEF]
  print(sess.run(sig_def.outputs['output'], feed_dict={sig_def.inputs['input']: frame}))
显然,通过tensorflow/serving或Cloud ML引擎,这段代码可以实现更有效的预测,但这应该可以实现。 有可能Keras在幕后做了一些事情,这也会干扰这个过程,如果是这样的话,我们想听听(我想确保Keras用户也能够冻结图形,因此,如果您想给我发送完整代码的要点或其他信息,也许我可以找到了解Keras的人来帮助我调试。)


编辑:你可以在这里找到一个端到端的例子:

以下方法行得通吗?通过
keras_model.layers
,对每一个集合
layer.trainable=False
。然后,在一个单独的
h5
文件中将这些层编译成一个新的keras模型,并以此作为转换为Tensorflow的基础。只要尝试一下……就可以了似乎根本不起作用。尝试从代码中复制错误,但到目前为止没有成功。即使没有行
tf.trainable_variables()
,我看到
tf\u prediction
keras\u prediction
匹配。另外,最后的
densite
层在保存的
.pb
文件中。你能提供显示如何构建和保存模型的代码,以及你使用的keras和tf版本吗?你能试着运行:
saver=tf.train.saver吗(tf.global_variables())saver.save(sess,/tmp/model/my_model')
保存您的模型这是完美的,它解决了我遇到的问题!非常感谢!