由于eval()函数,Tensorflow推断变得越来越慢

由于eval()函数,Tensorflow推断变得越来越慢,tensorflow,inference,Tensorflow,Inference,所以我有一个冻结的tensorflow模型,可以用来对图像进行分类。当我尝试使用这个模型一个接一个地推断图像时,模型运行得越来越慢。我搜索并发现eval()函数可能导致的问题,该函数将不断向图形中添加新节点,从而减慢过程 以下是我的代码的关键部分: with open('/tmp/frozen_resnet_v1_50.pb', 'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read())

所以我有一个冻结的tensorflow模型,可以用来对图像进行分类。当我尝试使用这个模型一个接一个地推断图像时,模型运行得越来越慢。我搜索并发现eval()函数可能导致的问题,该函数将不断向图形中添加新节点,从而减慢过程

以下是我的代码的关键部分:

with open('/tmp/frozen_resnet_v1_50.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')

    sess1 = tf.Session()
    sess = tf.Session()

    for root, dirs, files in os.walk(file_path):
        for f in files:
            # Read image one by one and preprocess
            img = cv2.imread(os.path.join(root, f))
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # BGR 2 RGB

            img = image_preprocessing_fn(img, _IMAGE_HEIGHT, _IMAGE_WIDTH)  # This function contains tf functions
            img = img.eval(session=sess1)

            img = np.reshape(img, [-1, _IMAGE_HEIGHT, _IMAGE_WIDTH, _IMAGE_CHANNEL])    # the input shape is 4 dimension

            # Feed image to model
            data = sess.graph.get_tensor_by_name('input:0')
            predict = sess.graph.get_tensor_by_name('resnet_v1_50/predictions/Softmax:0')

            out = sess.run(predict, feed_dict={data: img})
            indices = np.argmax(out, 1)

            print('Current image name: %s, predict result: %s' % (f, indices))

    sess1.close()
    sess.close()

PS:我用“sess1”做预处理,我想这可能不合适。希望有人能告诉我正确的方法,提前谢谢。

没有人回答……这是我的解决方案,很有效

with open('/tmp/frozen_resnet_v1_50.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')

    x = tf.placeholder(tf.uint8, shape=[None, None, 3])
    y = image_preprocessing_fn(x, _IMAGE_HEIGHT, _IMAGE_WIDTH)

    sess = tf.Session()
    data = sess.graph.get_tensor_by_name('input:0')
    predict = sess.graph.get_tensor_by_name('resnet_v1_50/predictions/Softmax:0')

    for root, dirs, files in os.walk(file_path):
        for f in files:
            img = cv2.imread(os.path.join(root, f))
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # BGR 2 RGB
            img = sess.run(y, feed_dict={x: img})

            img = np.reshape(img, [-1, _IMAGE_HEIGHT, _IMAGE_WIDTH, _IMAGE_CHANNEL])

            out = sess.run(predict, feed_dict={data: img})
            indices = np.argmax(out, 1)

            print('Current image name: %s, predict result: %s' % (f, out))

    sess.close()