Tensorflow对象检测中止,InvalidArgumentError:索引[0]=2不在[0,1]中
我试图在自己的数据集上训练tensorflow对象检测 我做了什么?Tensorflow对象检测中止,InvalidArgumentError:索引[0]=2不在[0,1]中,tensorflow,object-detection,Tensorflow,Object Detection,我试图在自己的数据集上训练tensorflow对象检测 我做了什么? 使用ssd\u mobilenet\u v1\u pets.config作为基础,创建我自己的管道配置。调整num\u类和所有其他路径特定部分,以匹配我的环境 从as检查点使用ssd_mobilenet_v1_coco 已创建包含所有标签的标签映射文件(第一个索引从1开始) 从我的数据集创建了一个TFRecord文件(脚本基于) 出了什么问题? 开始培训时,请: python tensorflow\u models/res
- 使用
作为基础,创建我自己的管道配置。调整ssd\u mobilenet\u v1\u pets.config
和所有其他路径特定部分,以匹配我的环境num\u类
- 从as检查点使用ssd_mobilenet_v1_coco
- 已创建包含所有标签的标签映射文件(第一个索引从1开始)
- 从我的数据集创建了一个
文件(脚本基于)TFRecord
python tensorflow\u models/research/object\u detection/train.py--pipeline\u config\u path=/home/playway/ssd\u mobilenet\u v1.config--train\u dir=/tmp/bla/
Traceback (most recent call last):
File "tensorflow_models/research/object_detection/train.py", line 198, in <module>
tf.app.run()
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/platform/app.py", line 48, in run
_sys.exit(main(_sys.argv[:1] + flags_passthrough))
File "tensorflow_models/research/object_detection/train.py", line 194, in main
worker_job_name, is_chief, FLAGS.train_dir)
File "/home/playground/tensorflow_models/research/object_detection/trainer.py", line 296, in train
saver=saver)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/slim/python/slim/learning.py", line 767, in train
sv.stop(threads, close_summary_writer=True)
File "/usr/lib/python2.7/contextlib.py", line 35, in __exit__
self.gen.throw(type, value, traceback)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/supervisor.py", line 964, in managed_session
self.stop(close_summary_writer=close_summary_writer)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/supervisor.py", line 792, in stop
stop_grace_period_secs=self._stop_grace_secs)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/coordinator.py", line 389, in join
six.reraise(*self._exc_info_to_raise)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/queue_runner_impl.py", line 238, in _run
enqueue_callable()
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 1235, in _single_operation_run
target_list_as_strings, status, None)
File "/usr/lib/python2.7/contextlib.py", line 24, in __exit__
self.gen.next()
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/errors_impl.py", line 466, in raise_exception_on_not_ok_status
pywrap_tensorflow.TF_GetCode(status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: indices[0] = 2 is not in [0, 1)
[[Node: cond/RandomCropImage/PruneCompleteleyOutsideWindow/Gather/Gather_1 = Gather[Tindices=DT_INT64, Tparams=DT_INT64, validate_indices=true, _device="/job:localhost/replica:0/task:0/cpu:0"](cond/RandomCropImage/PruneCompleteleyOutsideWindow/Gather/Gather_1/Switch:1, cond/RandomCropImage/PruneCompleteleyOutsideWindow/Reshape)]]
edit添加了一些代码,显示如何生成TFRecord
文件。整个脚本稍微长一点,但我尝试将其缩减为仅显示相关部分。如果遗漏了您感兴趣的内容,请告诉我
CATEGORIES_TO_TRAIN = ["apple", "dog", "cat"]
def createTFExample(img):
imageFormat = ""
if img.format == 'JPEG':
imageFormat = b'jpeg'
elif img.format == 'PNG':
imageFormat = b'png'
else:
print 'Unknown Image format %s' %(img.format,)
return None
width, height = img.size
filename = str(img.filename)
encodedImageData = img.bytesIO
xmins = []
xmaxs = []
ymins = []
ymaxs = []
for annotation in img.annotations:
xmins.append((annotation.left / width))
xmaxs.append((annotation.left + annotation.width) / width)
ymins.append((annotation.top / height))
ymaxs.append((annotation.top + annotation.height) / height)
#we might have some images in our dataset, which don't have a annotation, skip those
if((len(xmins) == 0) or (len(xmaxs) == 0) or (len(ymins) == 0) or (len(ymaxs) == 0)):
return None
label = [img.label.encode('utf8')]
classes = [(CATEGORIES_TO_TRAIN.index(img.label) + 1)] #class indexes start with 1
tf_example = tf.train.Example(features=tf.train.Features(feature={
'image/height': dataset_util.int64_feature(height),
'image/width': dataset_util.int64_feature(width),
'image/filename': dataset_util.bytes_feature(filename),
'image/source_id': dataset_util.bytes_feature(filename),
'image/encoded': dataset_util.bytes_feature(encodedImageData),
'image/format': dataset_util.bytes_feature(imageFormat),
'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
'image/object/class/text': dataset_util.bytes_list_feature(label),
'image/object/class/label': dataset_util.int64_list_feature(classes),
}))
return tf_example
def createTfRecordFile(images):
writer = tf.python_io.TFRecordWriter(TFRECORD_OUTPUT_PATH)
for img in images:
t = createTFExample(img)
if t is not None:
writer.write(t.SerializeToString())
writer.close()
非常感谢为我指明正确方向的任何帮助!我也遇到了类似的问题,但是让
标签
列表和类
列表与边界框元素具有相同的长度为我解决了这个问题
具体而言,在createTFExample()
中,label=[img.label.encode('utf8')]
和classes=[(CATEGORIES\u TO\u TRAIN.index(img.label)+1)]
中的元素应与边界框注释列表的元素相对应:
xmins = []
xmaxs = []
ymins = []
ymaxs = []
for annotation in img.annotations:
xmins.append((annotation.left / width))
xmaxs.append((annotation.left + annotation.width) / width)
ymins.append((annotation.top / height))
ymaxs.append((annotation.top + annotation.height) / height)
根据您的代码结构,我假设每个img
对象有一个对象类型,但在这种情况下,编写
label = [img.label.encode('utf8')] * len(xmins)
classes = [(CATEGORIES_TO_TRAIN.index(img.label) + 1)] * len(xmins)
或者使用图像中提供对象数量的任何元素,以便标签和类以及边界框列表具有相同的长度。如果一个
img
对象中存在多种类型的对象,则创建一个对象名称和类别ID列表,其中内部元素的索引与注释列表的索引匹配
结果列表应如下所示:
xmins = [a_xmin, b_xmin, c_xmin]
ymins = [a_ymin, b_ymin, c_ymin]
xmaxs = [a_xmax, b_xmax, c_xmax]
ymaxs = [a_ymax, b_ymax, c_ymax]
labels = [a_label, b_label, c_label]
classes = [a_classid, b_classid, c_classid]
这解决了我的问题,希望这能有所帮助!我可能完全错了,但我认为pets.config中建议的设置与您使用的coco模型是分开的,因此它会对形状产生抱怨。请尝试改用coco模型设置。让我知道它是否有效谢谢@eshirima的建议-非常感谢!我刚刚试过,但不幸的是,它现在失败了,错误消息略有不同(我用回溯更新了原始帖子)在标签映射中,是否还有索引为0的类作为背景/无对象?@gdelab感谢您的帮助!我的标签映射看起来类似于,但不是37个项目,而是只有3个项目。我的标签映射也从1开始,因此那里没有索引为0的类。我需要一个吗?我想是的,但显然我错了。对不起,回复太晚了。你完全正确,这完全解决了我的问题。非常感谢你的帮助,非常感谢!
xmins = [a_xmin, b_xmin, c_xmin]
ymins = [a_ymin, b_ymin, c_ymin]
xmaxs = [a_xmax, b_xmax, c_xmax]
ymaxs = [a_ymax, b_ymax, c_ymax]
labels = [a_label, b_label, c_label]
classes = [a_classid, b_classid, c_classid]