Python 如何使Tensorflow教程中从Imagenet(classify_image.py)预先训练的inception-v3模型作为一个模块导入?
我想知道如何修改classify_image.py(从中,以便我可以从另一个python脚本导入它。我基本上希望它具有与现有相同的功能,但不是提供图像路径并在终端中打印响应,而是希望给函数一个图像路径,并让函数返回前5个结果及其概率 我还没有找到这个问题的直接解决方案,但我意识到我解决问题和搜索以前的答案是有限的,因为不幸的是我还没有学会Tensorflow的基础知识 当然,如果有另一个预先训练好的Tensorflow模型同样好,并且满足我的要求,我会很乐意使用它 问候,, 桥 更新也许我应该澄清一下: 我不想训练一个模型,只需要使用一个经过预训练的模型进行图像识别,在本例中,我可以将一个图像识别脚本作为模块导入另一个python应用程序中 我也尝试过使用来自的代码,但我也被卡住了,在这种情况下,它包括了很多手动安装,我可能在某些步骤中失败了。的好处是,我让它按照教程中的预期工作,所以我认为从那到将其用作可插拔模块的步骤不应该太大 我尝试(使用classify\u image.py)将Python 如何使Tensorflow教程中从Imagenet(classify_image.py)预先训练的inception-v3模型作为一个模块导入?,python,tensorflow,Python,Tensorflow,我想知道如何修改classify_image.py(从中,以便我可以从另一个python脚本导入它。我基本上希望它具有与现有相同的功能,但不是提供图像路径并在终端中打印响应,而是希望给函数一个图像路径,并让函数返回前5个结果及其概率 我还没有找到这个问题的直接解决方案,但我意识到我解决问题和搜索以前的答案是有限的,因为不幸的是我还没有学会Tensorflow的基础知识 当然,如果有另一个预先训练好的Tensorflow模型同样好,并且满足我的要求,我会很乐意使用它 问候,, 桥 更新也许我应该澄
if\uu name\uu='\uu main\uuu'
下的行移动到main(u)
,以便在我从另一个脚本调用它们时执行它们,但我仍然有问题。我主要是在main(u)上有问题
函数,该函数希望我向其传递一个参数,通过四处搜索,我发现从cli获取输入时使用了某种占位符。所有的标志内容似乎也与cli相关,这正是我想要避免的。我还不确定模型权重等是否正确保存到b我可以从另一个脚本中使用它。再一次,在这一点上,我只想玩一下图像分类器,并希望进一步了解它背后的机器学习。很抱歉,我缺乏这方面的基础知识
分类_image.py:
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Simple image classification with Inception.
Run image classification with Inception trained on ImageNet 2012 Challenge data
set.
This program creates a graph from a saved GraphDef protocol buffer,
and runs inference on an input JPEG image. It outputs human readable
strings of the top 5 predictions along with their probabilities.
Change the --image_file argument to any jpg image to compute a
classification of that image.
Please see the tutorial and website for a detailed description of how
to use this script to perform image recognition.
https://tensorflow.org/tutorials/image_recognition/
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os.path
import re
import sys
import tarfile
import numpy as np
from six.moves import urllib
import tensorflow as tf
FLAGS = None
# pylint: disable=line-too-long
DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
# pylint: enable=line-too-long
class NodeLookup(object):
"""Converts integer node ID's to human readable labels."""
def __init__(self,
label_lookup_path=None,
uid_lookup_path=None):
if not label_lookup_path:
label_lookup_path = os.path.join(
FLAGS.model_dir, 'imagenet_2012_challenge_label_map_proto.pbtxt')
if not uid_lookup_path:
uid_lookup_path = os.path.join(
FLAGS.model_dir, 'imagenet_synset_to_human_label_map.txt')
self.node_lookup = self.load(label_lookup_path, uid_lookup_path)
def load(self, label_lookup_path, uid_lookup_path):
"""Loads a human readable English name for each softmax node.
Args:
label_lookup_path: string UID to integer node ID.
uid_lookup_path: string UID to human-readable string.
Returns:
dict from integer node ID to human-readable string.
"""
if not tf.gfile.Exists(uid_lookup_path):
tf.logging.fatal('File does not exist %s', uid_lookup_path)
if not tf.gfile.Exists(label_lookup_path):
tf.logging.fatal('File does not exist %s', label_lookup_path)
# Loads mapping from string UID to human-readable string
proto_as_ascii_lines = tf.gfile.GFile(uid_lookup_path).readlines()
uid_to_human = {}
p = re.compile(r'[n\d]*[ \S,]*')
for line in proto_as_ascii_lines:
parsed_items = p.findall(line)
uid = parsed_items[0]
human_string = parsed_items[2]
uid_to_human[uid] = human_string
# Loads mapping from string UID to integer node ID.
node_id_to_uid = {}
proto_as_ascii = tf.gfile.GFile(label_lookup_path).readlines()
for line in proto_as_ascii:
if line.startswith(' target_class:'):
target_class = int(line.split(': ')[1])
if line.startswith(' target_class_string:'):
target_class_string = line.split(': ')[1]
node_id_to_uid[target_class] = target_class_string[1:-2]
# Loads the final mapping of integer node ID to human-readable string
node_id_to_name = {}
for key, val in node_id_to_uid.items():
if val not in uid_to_human:
tf.logging.fatal('Failed to locate: %s', val)
name = uid_to_human[val]
node_id_to_name[key] = name
return node_id_to_name
def id_to_string(self, node_id):
if node_id not in self.node_lookup:
return ''
return self.node_lookup[node_id]
def create_graph():
"""Creates a graph from saved GraphDef file and returns a saver."""
# Creates graph from saved graph_def.pb.
with tf.gfile.FastGFile(os.path.join(
FLAGS.model_dir, 'classify_image_graph_def.pb'), 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
_ = tf.import_graph_def(graph_def, name='')
def run_inference_on_image(image):
"""Runs inference on an image.
Args:
image: Image file name.
Returns:
Nothing
"""
if not tf.gfile.Exists(image):
tf.logging.fatal('File does not exist %s', image)
image_data = tf.gfile.FastGFile(image, 'rb').read()
# Creates graph from saved GraphDef.
create_graph()
with tf.Session() as sess:
# Some useful tensors:
# 'softmax:0': A tensor containing the normalized prediction across
# 1000 labels.
# 'pool_3:0': A tensor containing the next-to-last layer containing 2048
# float description of the image.
# 'DecodeJpeg/contents:0': A tensor containing a string providing JPEG
# encoding of the image.
# Runs the softmax tensor by feeding the image_data as input to the graph.
softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')
predictions = sess.run(softmax_tensor,
{'DecodeJpeg/contents:0': image_data})
predictions = np.squeeze(predictions)
# Creates node ID --> English string lookup.
node_lookup = NodeLookup()
top_k = predictions.argsort()[-FLAGS.num_top_predictions:][::-1]
for node_id in top_k:
human_string = node_lookup.id_to_string(node_id)
score = predictions[node_id]
print('%s (score = %.5f)' % (human_string, score))
def maybe_download_and_extract():
"""Download and extract model tar file."""
dest_directory = FLAGS.model_dir
if not os.path.exists(dest_directory):
os.makedirs(dest_directory)
filename = DATA_URL.split('/')[-1]
filepath = os.path.join(dest_directory, filename)
if not os.path.exists(filepath):
def _progress(count, block_size, total_size):
sys.stdout.write('\r>> Downloading %s %.1f%%' % (
filename, float(count * block_size) / float(total_size) * 100.0))
sys.stdout.flush()
filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
print()
statinfo = os.stat(filepath)
print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
tarfile.open(filepath, 'r:gz').extractall(dest_directory)
def main(_):
maybe_download_and_extract()
image = (FLAGS.image_file if FLAGS.image_file else
os.path.join(FLAGS.model_dir, 'cropped_panda.jpg'))
run_inference_on_image(image)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
# classify_image_graph_def.pb:
# Binary representation of the GraphDef protocol buffer.
# imagenet_synset_to_human_label_map.txt:
# Map from synset ID to a human readable string.
# imagenet_2012_challenge_label_map_proto.pbtxt:
# Text representation of a protocol buffer mapping a label to synset ID.
parser.add_argument(
'--model_dir',
type=str,
default='/tmp/imagenet',
help="""\
Path to classify_image_graph_def.pb,
imagenet_synset_to_human_label_map.txt, and
imagenet_2012_challenge_label_map_proto.pbtxt.\
"""
)
parser.add_argument(
'--image_file',
type=str,
default='',
help='Absolute path to image file.'
)
parser.add_argument(
'--num_top_predictions',
type=int,
default=5,
help='Display this many predictions.'
)
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
1) 第一个问题是关于如何返回预测值。
以下代码片段对给定图像进行了预测:
top_k = predictions.argsort()[-FLAGS.num_top_predictions:][::-1]
for node_id in top_k:
human_string = node_lookup.id_to_string(node_id)
score = predictions[node_id]
print('%s (score = %.5f)' % (human_string, score))
您可以将结果保存在某些数据结构中并返回,而不是打印。默认情况下,如果您想将此行为更改为“设置适当的值为”--num\u top\u predictions
,则将返回5个顶部预测
2) 有关型号的信息:
它分为两部分:
但是,如果您仍然希望使用自己的数据集来训练系统,我会说,首先使用imagenet进行训练,然后使用自己的数据集训练最后一层(张量名称为“final_result”)。请找到这个 最后,我设法使用了原始问题更新中引用的SO文章中的代码。我修改了代码,添加了上述问题的答案中的
im=2*(im/255.0)-1.0
,这是在我的计算机上修复PIL的一行,外加一个将类转换为人类可读标签的函数(在github上找到),链接到下面的文件。我把它做成了一个可调用的函数,它将图像列表作为输入,并输出标签列表和预测值。如果您想使用它,您必须:
git克隆https://github.com/tensorflow/models/
您想要模型的位置import tensorflow as tf
slim = tf.contrib.slim
import PIL as pillow
from PIL import Image
#import Image
from inception_resnet_v2 import *
import numpy as np
with open('imagenet1000_clsid_to_human.txt','r') as inf:
imagenet_classes = eval(inf.read())
def get_human_readable(id):
id = id - 1
label = imagenet_classes[id]
return label
checkpoint_file = './inception_resnet_v2_2016_08_30.ckpt'
#Load the model
sess = tf.Session()
arg_scope = inception_resnet_v2_arg_scope()
input_tensor = tf.placeholder(tf.float32, [None, 299, 299, 3])
with slim.arg_scope(arg_scope):
logits, end_points = inception_resnet_v2(input_tensor, is_training=False)
saver = tf.train.Saver()
saver.restore(sess, checkpoint_file)
def classify_image(sample_images):
classifications = []
for image in sample_images:
im = Image.open(image).resize((299,299))
im = np.array(im)
im = im.reshape(-1,299,299,3)
im = 2*(im/255.0)-1.0
predict_values, logit_values = sess.run([end_points['Predictions'], logits], feed_dict={input_tensor: im})
#print (np.max(predict_values), np.max(logit_values))
#print (np.argmax(predict_values), np.argmax(logit_values))
label = get_human_readable(np.argmax(predict_values))
predict_value = np.max(predict_values)
classifications.append({"label":label, "predict_value":predict_value})
return classifications
在我的例子中,只需将
[-FLAGS.num\u top\u predictions://code>替换为[-5://code>
然后用目录替换其他标志,并将图像保存在其上。谢谢米林!但我在用返回替换打印语句之前很久就遇到了问题。我现在更新我的问题。