使用TensorFlow对象检测API确定最大批量

使用TensorFlow对象检测API确定最大批量,tensorflow,object-detection-api,batchsize,Tensorflow,Object Detection Api,Batchsize,TF Object Detection API默认会获取所有GPU内存,因此很难判断我可以进一步增加多少批处理大小。通常我只是继续增加它,直到我得到一个CUDA OOM错误 另一方面,Pytork在默认情况下不会占用所有GPU内存,因此很容易看出我还有多少百分比需要处理,而无需所有的尝试和错误 有没有更好的方法来确定我缺少的TF对象检测API的批量大小?类似于model_main.py的allow growth标志?我一直在查找源代码,没有找到与此相关的标志 但是,在的文件model_main.

TF Object Detection API默认会获取所有GPU内存,因此很难判断我可以进一步增加多少批处理大小。通常我只是继续增加它,直到我得到一个CUDA OOM错误

另一方面,Pytork在默认情况下不会占用所有GPU内存,因此很容易看出我还有多少百分比需要处理,而无需所有的尝试和错误


有没有更好的方法来确定我缺少的TF对象检测API的批量大小?类似于
model_main.py的
allow growth
标志?

我一直在查找源代码,没有找到与此相关的标志

但是,在的文件
model_main.py
中 您可以找到以下主要函数定义:

def main(unused_argv):
  flags.mark_flag_as_required('model_dir')
  flags.mark_flag_as_required('pipeline_config_path')
  config = tf.estimator.RunConfig(model_dir=FLAGS.model_dir)

  train_and_eval_dict = model_lib.create_estimator_and_inputs(
      run_config=config,
...
想法是以类似的方式对其进行修改,如以下方式:

config_proto = tf.ConfigProto()
config_proto.gpu_options.allow_growth = True

config = tf.estimator.RunConfig(model_dir=FLAGS.model_dir, session_config=config_proto)
因此,添加
config_proto
并更改
config
,但保持所有其他条件相同

另外,
allow_growth
使程序可以根据需要使用尽可能多的GPU内存。所以,取决于你的GPU,你可能最终会耗尽所有的内存。在这种情况下,您可能需要使用

config_proto.gpu_options.per_process_gpu_memory_fraction = 0.9
它定义了要使用的内存部分

希望这有所帮助

如果您不想修改该文件,似乎应该打开一个问题,因为我没有看到任何标志。除非国旗

flags.DEFINE_string('pipeline_config_path', None, 'Path to pipeline config '
                    'file.')

是指与此相关的东西。但我不这么认为,因为它在
model_lib.py
中看起来与训练、评估和推断配置相关,而不是与GPU使用配置相关。

谢谢!我认为这可能是一个很好的解决办法。我明天早上试试,然后回来汇报。