Python Tensorflow:从图形文件(.pb文件)中获取预测

Python Tensorflow:从图形文件(.pb文件)中获取预测,python,tensorflow,tensorboard,Python,Tensorflow,Tensorboard,我正在使用一个图形文件(pb文件),这个Tensorflow模型的目的是提供对特定图像的预测 我已经开发了一个加载图形文件的代码,但我不能统计会话。 可用文件包括:- 培训\u模型\u保存的\u模型.pb 变数 培训\模型\变量\变量。数据-00000-of-00001 培训\模型\变量\变量.index 输出是错误的,包含一个模型层的大列表。在这种情况下,我能做什么,非常感谢任何帮助 这是我用来加载/运行模型的代码 import tensorflow as tf import sys

我正在使用一个图形文件(pb文件),这个Tensorflow模型的目的是提供对特定图像的预测

我已经开发了一个加载图形文件的代码,但我不能统计会话。 可用文件包括:-

  • 培训\u模型\u保存的\u模型.pb
  • 变数
    • 培训\模型\变量\变量。数据-00000-of-00001
    • 培训\模型\变量\变量.index
输出是错误的,包含一个模型层的大列表。在这种情况下,我能做什么,非常感谢任何帮助

这是我用来加载/运行模型的代码

import tensorflow as tf
import sys
import os



import matplotlib.image as mpimg
import matplotlib.pyplot as plt


from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.python.util import compat
from tensorflow.python.platform import gfile

export_dir = os.path.join("./", "variables/")
filename = "imgpsh_fullsize.jpeg"
raw_image_data = mpimg.imread(filename)

g = tf.Graph()
with tf.Session(graph=g) as sess:
   model_filename ='training_model_saved_model.pb'
   with gfile.FastGFile(model_filename, 'rb') as f:

        data = compat.as_bytes(f.read())
        sm = saved_model_pb2.SavedModel()
        sm.ParseFromString(data)
        #print(sm)
        if 1 != len(sm.meta_graphs):
                print('More than one graph found. Not sure which to write')
                sys.exit(1)

        image_input= tf.import_graph_def(sm.meta_graphs[0].graph_def,name='',return_elements=["input"])
        #print(image_input)
        #saver =  tf.train.Saver()
        saver = tf.train.import_meta_graph(sm.meta_graphs[0].graph_def)
        '''
        print(image_input)

        x = g.get_tensor_by_name("input:0")

        print(x)
        '''
        saver.restore(sess,model_filename)

        predictions = sess.run(feed_dict={image: raw_image_data})
        print('###################################################')
        print(predictions)
存在的错误是

Traceback (most recent call last):
  File "model_Input-get.py", line 35, in <module>
    saver = tf.train.import_meta_graph(sm.meta_graphs[0].graph_def)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.py", line 1691, in import_meta_graph
    meta_graph_def = meta_graph.read_meta_graph_file(meta_graph_or_file)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/meta_graph.py", line 553, in read_meta_graph_file
    if not file_io.file_exists(filename):
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/lib/io/file_io.py", line 252, in file_exists
    pywrap_tensorflow.FileExists(compat.as_bytes(filename), status)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/util/compat.py", line 65, in as_bytes
    (bytes_or_text,))
TypeError: Expected binary or unicode string, got node {
  name: "input"
  op: "Placeholder"
  attr {
    key: "_output_shapes"
    value {
      list {
        shape {
          dim {
            size: -1
          }
        }
      }
    }
  }
  attr {
    key: "dtype"
    value {
      type: DT_STRING
    }
  }
回溯(最近一次呼叫最后一次):
文件“model_Input-get.py”,第35行,在
saver=tf.train.import_meta_图(sm.meta_图[0].graph_def)
文件“/usr/local/lib/python2.7/dist packages/tensorflow/python/training/saver.py”,第1691行,在导入元图中
meta_graph_def=meta_graph.read_meta_graph_文件(meta_graph_或_文件)
文件“/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/meta_-graph.py”,第553行,在read_-meta_-graph_文件中
如果文件不存在,io.file存在(文件名):
文件“/usr/local/lib/python2.7/dist packages/tensorflow/python/lib/io/File_io.py”,第252行,文件_中存在
pywrap_tensorflow.FileExists(compat.as_字节(文件名),状态)
文件“/usr/local/lib/python2.7/dist packages/tensorflow/python/util/compat.py”,第65行,以字节为单位
(字节或文本)
TypeError:应为二进制或unicode字符串,已获取节点{
名称:“输入”
op:“占位符”
属性{
键:“\u输出\u形状”
价值观{
名单{
形状{
暗淡的{
尺寸:-1
}
}
}
}
}
属性{
键:“数据类型”
价值观{
类型:DT_字符串
}
}

您似乎将TensorFlow服务的SavedModel格式与常规TensorFlow导出/还原功能混合在一起

这是TensorFlow代码库中一个特别令人困惑的部分,因为这种格式在首次出现时没有很好的文档记录,并且没有很多示例显示何时使用这种格式与原始格式相比

我的建议是:

  • 切换到TF Serving并继续使用SavedModel格式,或
  • 坚持原始导出/恢复模型格式

  • 我在还原GraphFile时感到迷茫。我使用了SavedModel,因为当我尝试使用'graph_def=tf.GraphDef()graph_def.ParseFromString(f.read())g_in=tf.import_graph_def(graph_def)时“我得到protobuf.message.decodeError你建议在代码中编辑什么,你能提供更多细节吗?你说切换到TF服务是什么意思,我听说bazel在这里如何应用它