由于eval()函数,Tensorflow推断变得越来越慢
所以我有一个冻结的tensorflow模型,可以用来对图像进行分类。当我尝试使用这个模型一个接一个地推断图像时,模型运行得越来越慢。我搜索并发现eval()函数可能导致的问题,该函数将不断向图形中添加新节点,从而减慢过程 以下是我的代码的关键部分:由于eval()函数,Tensorflow推断变得越来越慢,tensorflow,inference,Tensorflow,Inference,所以我有一个冻结的tensorflow模型,可以用来对图像进行分类。当我尝试使用这个模型一个接一个地推断图像时,模型运行得越来越慢。我搜索并发现eval()函数可能导致的问题,该函数将不断向图形中添加新节点,从而减慢过程 以下是我的代码的关键部分: with open('/tmp/frozen_resnet_v1_50.pb', 'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read())
with open('/tmp/frozen_resnet_v1_50.pb', 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
sess1 = tf.Session()
sess = tf.Session()
for root, dirs, files in os.walk(file_path):
for f in files:
# Read image one by one and preprocess
img = cv2.imread(os.path.join(root, f))
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # BGR 2 RGB
img = image_preprocessing_fn(img, _IMAGE_HEIGHT, _IMAGE_WIDTH) # This function contains tf functions
img = img.eval(session=sess1)
img = np.reshape(img, [-1, _IMAGE_HEIGHT, _IMAGE_WIDTH, _IMAGE_CHANNEL]) # the input shape is 4 dimension
# Feed image to model
data = sess.graph.get_tensor_by_name('input:0')
predict = sess.graph.get_tensor_by_name('resnet_v1_50/predictions/Softmax:0')
out = sess.run(predict, feed_dict={data: img})
indices = np.argmax(out, 1)
print('Current image name: %s, predict result: %s' % (f, indices))
sess1.close()
sess.close()
PS:我用“sess1”做预处理,我想这可能不合适。希望有人能告诉我正确的方法,提前谢谢。没有人回答……这是我的解决方案,很有效
with open('/tmp/frozen_resnet_v1_50.pb', 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
x = tf.placeholder(tf.uint8, shape=[None, None, 3])
y = image_preprocessing_fn(x, _IMAGE_HEIGHT, _IMAGE_WIDTH)
sess = tf.Session()
data = sess.graph.get_tensor_by_name('input:0')
predict = sess.graph.get_tensor_by_name('resnet_v1_50/predictions/Softmax:0')
for root, dirs, files in os.walk(file_path):
for f in files:
img = cv2.imread(os.path.join(root, f))
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # BGR 2 RGB
img = sess.run(y, feed_dict={x: img})
img = np.reshape(img, [-1, _IMAGE_HEIGHT, _IMAGE_WIDTH, _IMAGE_CHANNEL])
out = sess.run(predict, feed_dict={data: img})
indices = np.argmax(out, 1)
print('Current image name: %s, predict result: %s' % (f, out))
sess.close()