TensorFlow: how to restore a model and run predictions with multiprocessing


I have saved the trained model.

The prediction code is as follows:

def predict(predict_filename):
    with tf.Session() as sess:

        # normalization image
        ...
        im_norm = tf.image.per_image_standardization(rgb_image)

        sess = tf.InteractiveSession()
        with sess.as_default():

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord, sess=sess)

            def get_ckpt():
                ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
                return ckpt

            ckpt = get_ckpt()

            def get_saver():
                new_saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta')
                return new_saver

            new_saver = get_saver()

            if ckpt and ckpt.model_checkpoint_path:
                print(ckpt.model_checkpoint_path)
                new_saver.restore(sess, ckpt.model_checkpoint_path)
                try:
                    while not coord.should_stop():

                        result = sess.run("Linear_regression/Add:0",
                                          feed_dict={'conv1/Reshape:0': image_new.eval(),
                                                     'dropout2/dropout/keep_prob:0': 0.2})
                        ...
                except tf.errors.OutOfRangeError:
                    print('Done training after reading all data')
                finally:  # When done, ask the threads to stop.
                    coord.request_stop()
                    # Wait for threads to finish.
                    # tf.reset_default_graph()
                    coord.join(threads)
                    sess.close()
                    # os._exit(0)
            else:
                print('no checkpoint found')
I use starmap from the multiprocessing module to get the predicted values:

import numpy as np
from multiprocessing.pool import ThreadPool

def main(argv=None):
    filename = []
    filename.append('./1.jpg')
    filename.append('./2.jpg')
    filename.append('./3.jpg')
    filename = np.array(filename).astype(str)
    print(filename)
    pool = ThreadPool(1)
    pool.starmap(predict, zip(filename))
    pool.close()  # close() must be called before join()
    pool.join()
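For reference, here is a stdlib-only sketch of how starmap, zip and the pool shutdown order interact. It is not the asker's code: describe() is a hypothetical stand-in for predict() so the example runs without TensorFlow or a checkpoint, and the file names are only strings.

```python
from multiprocessing.pool import ThreadPool


def describe(filename):
    """Hypothetical stand-in for predict(): takes exactly one positional argument."""
    return "predicted " + filename


def main():
    filenames = ["./1.jpg", "./2.jpg", "./3.jpg"]
    # zip() over a single iterable yields 1-tuples: ("./1.jpg",), ("./2.jpg",), ...
    args = list(zip(filenames))
    pool = ThreadPool(1)
    try:
        # starmap() unpacks each tuple, so this calls describe("./1.jpg"), ...
        via_starmap = pool.starmap(describe, args)
        # For a one-argument function, map() gives the same result without zip().
        via_map = pool.map(describe, filenames)
    finally:
        # close() first, then join(); join() on a still-running pool raises an exception.
        pool.close()
        pool.join()
    return args, via_starmap, via_map


if __name__ == "__main__":
    args, via_starmap, via_map = main()
    print(args)
    print(via_starmap == via_map)  # True
```

Since predict() takes a single argument, pool.map(predict, filename) would be an equivalent, slightly simpler call than starmap over zip().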
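Here I use only a single worker in the pool. When I run the project, it successfully gets the predicted value for "1.jpg", and then I get this error: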

tensorflow.python.framework.errors_impl.FailedPreconditionError: Attempting to use uninitialized value file_queue/limit_epochs/epochs
     [[Node: file_queue/limit_epochs/CountUpTo = CountUpTo[T=DT_INT64, _class=["loc:@file_queue/limit_epochs/epochs"], limit=40, _device="/job:localhost/replica:0/task:0/cpu:0"](file_queue/limit_epochs/epochs)]]
I really don't understand why the prediction for the second image fails when the first image gets the correct result.

When I add os._exit(0) in the finally block, everything works and all three images get correct results. Why?

If I add sess.run(tf.global_variables_initializer()), then after the first image is predicted I get this log:

Traceback (most recent call last):
  File "C:\Program Files\Anaconda3\lib\multiprocessing\pool.py", line 119, in worker
    result = (True, func(*args, **kwds))
  File "C:\Program Files\Anaconda3\lib\multiprocessing\pool.py", line 47, in starmapstar
    return list(itertools.starmap(args[0], args[1]))
  File "D:\PycharmProjects\IQA\paralle_test.py", line 109, in predict
    getvalue(file)
  File "D:\PycharmProjects\IQA\paralle_test.py", line 104, in getvalue
    print('no checkpoint found')
  File "C:\Program Files\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 1208, in __exit__
    self._default_graph_context_manager.__exit__(exec_type, exec_value, exec_tb)
  File "C:\Program Files\Anaconda3\lib\contextlib.py", line 66, in __exit__
    next(self.gen)
  File "C:\Program Files\Anaconda3\lib\site-packages\tensorflow\python\framework\ops.py", line 3520, in get_controller
    if self.stack[-1] is not default:
IndexError: list index out of range

Comments:

Could you double-check the output and see whether an out-of-memory error occurs?

@drpng I'm sure there is enough memory, and I don't get any log related to it.

After creating the session, you need a sess.run(tf.global_variables_initializer()). Also, I'm not sure you can run nested sessions, so I would try it.

If I add sess.run(tf.global_variables_initializer()) after predicting the first image, I get the error "no checkpoint found". I have updated the question with the whole log. @drpng I think something is wrong with the graph initialization, but I haven't found the right fix. Could you help?