Python 如何使Tensorflow教程中从Imagenet(classify_image.py)预先训练的inception-v3模型作为一个模块导入?

Python 如何使Tensorflow教程中从Imagenet(classify_image.py)预先训练的inception-v3模型作为一个模块导入?,python,tensorflow,Python,Tensorflow,我想知道如何修改classify_image.py(从中,以便我可以从另一个python脚本导入它。我基本上希望它具有与现有相同的功能,但不是提供图像路径并在终端中打印响应,而是希望给函数一个图像路径,并让函数返回前5个结果及其概率 我还没有找到这个问题的直接解决方案,但我意识到我解决问题和搜索以前的答案是有限的,因为不幸的是我还没有学会Tensorflow的基础知识 当然,如果有另一个预先训练好的Tensorflow模型同样好,并且满足我的要求,我会很乐意使用它 问候,, 桥 更新也许我应该澄

我想知道如何修改classify_image.py(从中,以便我可以从另一个python脚本导入它。我基本上希望它具有与现有相同的功能,但不是提供图像路径并在终端中打印响应,而是希望给函数一个图像路径,并让函数返回前5个结果及其概率

我还没有找到这个问题的直接解决方案,但我意识到我解决问题和搜索以前的答案是有限的,因为不幸的是我还没有学会Tensorflow的基础知识

当然,如果有另一个预先训练好的Tensorflow模型同样好,并且满足我的要求,我会很乐意使用它

问候,, 桥

更新也许我应该澄清一下:

我不想训练一个模型,只需要使用一个经过预训练的模型进行图像识别,在本例中,我可以将一个图像识别脚本作为模块导入另一个python应用程序中

我也尝试过使用来自的代码,但我也被卡住了,在这种情况下,它包括了很多手动安装,我可能在某些步骤中失败了。的好处是,我让它按照教程中的预期工作,所以我认为从那到将其用作可插拔模块的步骤不应该太大

我尝试(使用classify\u image.py)将
if\uu name\uu='\uu main\uuu'
下的行移动到
main(u)
,以便在我从另一个脚本调用它们时执行它们,但我仍然有问题。我主要是在
main(u)上有问题
函数,该函数希望我向其传递一个参数,通过四处搜索,我发现从cli获取输入时使用了某种占位符。所有的标志内容似乎也与cli相关,这正是我想要避免的。我还不确定模型权重等是否正确保存到b我可以从另一个脚本中使用它。再一次,在这一点上,我只想玩一下图像分类器,并希望进一步了解它背后的机器学习。很抱歉,我缺乏这方面的基础知识

分类_image.py:

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Simple image classification with Inception.
Run image classification with Inception trained on ImageNet 2012 Challenge data
set.
This program creates a graph from a saved GraphDef protocol buffer,
and runs inference on an input JPEG image. It outputs human readable
strings of the top 5 predictions along with their probabilities.
Change the --image_file argument to any jpg image to compute a
classification of that image.
Please see the tutorial and website for a detailed description of how
to use this script to perform image recognition.
https://tensorflow.org/tutorials/image_recognition/
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os.path
import re
import sys
import tarfile

import numpy as np
from six.moves import urllib
import tensorflow as tf

FLAGS = None

# pylint: disable=line-too-long
DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
# pylint: enable=line-too-long


class NodeLookup(object):
  """Converts integer node ID's to human readable labels."""

  def __init__(self,
               label_lookup_path=None,
               uid_lookup_path=None):
    if not label_lookup_path:
      label_lookup_path = os.path.join(
          FLAGS.model_dir, 'imagenet_2012_challenge_label_map_proto.pbtxt')
    if not uid_lookup_path:
      uid_lookup_path = os.path.join(
          FLAGS.model_dir, 'imagenet_synset_to_human_label_map.txt')
    self.node_lookup = self.load(label_lookup_path, uid_lookup_path)

  def load(self, label_lookup_path, uid_lookup_path):
    """Loads a human readable English name for each softmax node.
    Args:
      label_lookup_path: string UID to integer node ID.
      uid_lookup_path: string UID to human-readable string.
    Returns:
      dict from integer node ID to human-readable string.
    """
    if not tf.gfile.Exists(uid_lookup_path):
      tf.logging.fatal('File does not exist %s', uid_lookup_path)
    if not tf.gfile.Exists(label_lookup_path):
      tf.logging.fatal('File does not exist %s', label_lookup_path)

    # Loads mapping from string UID to human-readable string
    proto_as_ascii_lines = tf.gfile.GFile(uid_lookup_path).readlines()
    uid_to_human = {}
    p = re.compile(r'[n\d]*[ \S,]*')
    for line in proto_as_ascii_lines:
      parsed_items = p.findall(line)
      uid = parsed_items[0]
      human_string = parsed_items[2]
      uid_to_human[uid] = human_string

    # Loads mapping from string UID to integer node ID.
    node_id_to_uid = {}
    proto_as_ascii = tf.gfile.GFile(label_lookup_path).readlines()
    for line in proto_as_ascii:
      if line.startswith('  target_class:'):
        target_class = int(line.split(': ')[1])
      if line.startswith('  target_class_string:'):
        target_class_string = line.split(': ')[1]
        node_id_to_uid[target_class] = target_class_string[1:-2]

    # Loads the final mapping of integer node ID to human-readable string
    node_id_to_name = {}
    for key, val in node_id_to_uid.items():
      if val not in uid_to_human:
        tf.logging.fatal('Failed to locate: %s', val)
      name = uid_to_human[val]
      node_id_to_name[key] = name

    return node_id_to_name

  def id_to_string(self, node_id):
    if node_id not in self.node_lookup:
      return ''
    return self.node_lookup[node_id]


def create_graph():
  """Creates a graph from saved GraphDef file and returns a saver."""
  # Creates graph from saved graph_def.pb.
  with tf.gfile.FastGFile(os.path.join(
      FLAGS.model_dir, 'classify_image_graph_def.pb'), 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    _ = tf.import_graph_def(graph_def, name='')


def run_inference_on_image(image):
  """Runs inference on an image.
  Args:
    image: Image file name.
  Returns:
    Nothing
  """
  if not tf.gfile.Exists(image):
    tf.logging.fatal('File does not exist %s', image)
  image_data = tf.gfile.FastGFile(image, 'rb').read()

  # Creates graph from saved GraphDef.
  create_graph()

  with tf.Session() as sess:
    # Some useful tensors:
    # 'softmax:0': A tensor containing the normalized prediction across
    #   1000 labels.
    # 'pool_3:0': A tensor containing the next-to-last layer containing 2048
    #   float description of the image.
    # 'DecodeJpeg/contents:0': A tensor containing a string providing JPEG
    #   encoding of the image.
    # Runs the softmax tensor by feeding the image_data as input to the graph.
    softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')
    predictions = sess.run(softmax_tensor,
                           {'DecodeJpeg/contents:0': image_data})
    predictions = np.squeeze(predictions)

    # Creates node ID --> English string lookup.
    node_lookup = NodeLookup()

    top_k = predictions.argsort()[-FLAGS.num_top_predictions:][::-1]
    for node_id in top_k:
      human_string = node_lookup.id_to_string(node_id)
      score = predictions[node_id]
      print('%s (score = %.5f)' % (human_string, score))


def maybe_download_and_extract():
  """Download and extract model tar file."""
  dest_directory = FLAGS.model_dir
  if not os.path.exists(dest_directory):
    os.makedirs(dest_directory)
  filename = DATA_URL.split('/')[-1]
  filepath = os.path.join(dest_directory, filename)
  if not os.path.exists(filepath):
    def _progress(count, block_size, total_size):
      sys.stdout.write('\r>> Downloading %s %.1f%%' % (
          filename, float(count * block_size) / float(total_size) * 100.0))
      sys.stdout.flush()
    filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
    print()
    statinfo = os.stat(filepath)
    print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
  tarfile.open(filepath, 'r:gz').extractall(dest_directory)


def main(_):
  maybe_download_and_extract()
  image = (FLAGS.image_file if FLAGS.image_file else
           os.path.join(FLAGS.model_dir, 'cropped_panda.jpg'))
  run_inference_on_image(image)


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  # classify_image_graph_def.pb:
  #   Binary representation of the GraphDef protocol buffer.
  # imagenet_synset_to_human_label_map.txt:
  #   Map from synset ID to a human readable string.
  # imagenet_2012_challenge_label_map_proto.pbtxt:
  #   Text representation of a protocol buffer mapping a label to synset ID.
  parser.add_argument(
      '--model_dir',
      type=str,
      default='/tmp/imagenet',
      help="""\
      Path to classify_image_graph_def.pb,
      imagenet_synset_to_human_label_map.txt, and
      imagenet_2012_challenge_label_map_proto.pbtxt.\
      """
  )
  parser.add_argument(
      '--image_file',
      type=str,
      default='',
      help='Absolute path to image file.'
  )
  parser.add_argument(
      '--num_top_predictions',
      type=int,
      default=5,
      help='Display this many predictions.'
  )
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
1) 第一个问题是关于如何返回预测值。 以下代码片段对给定图像进行了预测:

    top_k = predictions.argsort()[-FLAGS.num_top_predictions:][::-1]
    for node_id in top_k:
      human_string = node_lookup.id_to_string(node_id)
      score = predictions[node_id]
      print('%s (score = %.5f)' % (human_string, score))
您可以将结果保存在某些数据结构中并返回,而不是打印。默认情况下,如果您想将此行为更改为“设置适当的值为”
--num\u top\u predictions
,则将返回5个顶部预测

2) 有关型号的信息: 它分为两部分:

  • 您需要像Imagenet一样拥有高质量的数据集
  • 假设您有这样高质量的数据集,那么培训inception的基础设施将需要非常强大的GPU。也有很多时间

  • 但是,如果您仍然希望使用自己的数据集来训练系统,我会说,首先使用imagenet进行训练,然后使用自己的数据集训练最后一层(张量名称为“final_result”)。请找到这个

    最后,我设法使用了原始问题更新中引用的SO文章中的代码。我修改了代码,添加了上述问题的答案中的
    im=2*(im/255.0)-1.0
    ,这是在我的计算机上修复PIL的一行,外加一个将类转换为人类可读标签的函数(在github上找到),链接到下面的文件。我把它做成了一个可调用的函数,它将图像列表作为输入,并输出标签列表和预测值。如果您想使用它,您必须:

  • 安装最新的Tensorflow版本(目前需要1.0)
  • git克隆https://github.com/tensorflow/models/
    您想要模型的位置
  • 将我前面提到的SO问题(当然需要提取)放到项目目录中
  • 将(人类可读的标签)放在项目的目录中
  • 使用SO问题中的代码,并对其进行一些修改,将其放入项目中的.py文件中:

    import tensorflow as tf
    slim = tf.contrib.slim
    import PIL as pillow
    from PIL import Image
    #import Image
    from inception_resnet_v2 import *
    import numpy as np
    
    with open('imagenet1000_clsid_to_human.txt','r') as inf:
        imagenet_classes = eval(inf.read())
    
    def get_human_readable(id):
        id = id - 1
        label = imagenet_classes[id]
    
        return label
    
    checkpoint_file = './inception_resnet_v2_2016_08_30.ckpt'
    
    #Load the model
    sess = tf.Session()
    arg_scope = inception_resnet_v2_arg_scope()
    input_tensor = tf.placeholder(tf.float32, [None, 299, 299, 3])  
    with slim.arg_scope(arg_scope):
        logits, end_points = inception_resnet_v2(input_tensor, is_training=False)
    saver = tf.train.Saver()
    saver.restore(sess, checkpoint_file)
    
    def classify_image(sample_images):
        classifications = []
        for image in sample_images:
            im = Image.open(image).resize((299,299))
            im = np.array(im)
            im = im.reshape(-1,299,299,3)
            im = 2*(im/255.0)-1.0
            predict_values, logit_values = sess.run([end_points['Predictions'], logits], feed_dict={input_tensor: im})
            #print (np.max(predict_values), np.max(logit_values))
            #print (np.argmax(predict_values), np.argmax(logit_values))
            label = get_human_readable(np.argmax(predict_values))
            predict_value = np.max(predict_values)
            classifications.append({"label":label, "predict_value":predict_value})
    
        return classifications
    

  • 在我的例子中,只需将
    [-FLAGS.num\u top\u predictions://code>替换为
    [-5://code>


    然后用目录替换其他标志,并将图像保存在其上。

    谢谢米林!但我在用返回替换打印语句之前很久就遇到了问题。我现在更新我的问题。