Tensorflow tf.contrib.data.DataSet批量大小只能设置为1

Tensorflow tf.contrib.data.DataSet批量大小只能设置为1,tensorflow,tensorflow-datasets,Tensorflow,Tensorflow Datasets,我通过代码将pascal voc数据集转换为tfrecords。我使用tf.contrib.data.Dataset读取数据。我使用的代码如下: import tensorflow as tf from tensorflow.contrib.data import Iterator slim_example_decoder = tf.contrib.slim.tfexample_decoder flags = tf.app.flags flags.DEFINE_string('data_di

我通过代码将pascal voc数据集转换为tfrecords。我使用
tf.contrib.data.Dataset
读取数据。我使用的代码如下:

import tensorflow as tf
from tensorflow.contrib.data import Iterator

slim_example_decoder = tf.contrib.slim.tfexample_decoder

flags = tf.app.flags
flags.DEFINE_string('data_dir', '/home/aurora/workspaces/data/tfrecords_data/voc_dataset/trainval.tfrecords',
                'tfrecords file output path')
flags.DEFINE_integer('batch_size', 1, 'training batch size')
flags.DEFINE_integer('capacity', 10000, 'training batch size')
FLAGS = flags.FLAGS

features = {"image/height": tf.FixedLenFeature((), tf.int64, default_value=1),
        "image/width": tf.FixedLenFeature((), tf.int64, default_value=1),
        "image/filename": tf.FixedLenFeature((), tf.string, default_value=""),
        "image/source_id": tf.FixedLenFeature((), tf.string, default_value=""),
        "image/key/sha256": tf.FixedLenFeature((), tf.string, default_value=""),
        "image/encoded": tf.FixedLenFeature((), tf.string, default_value=""),
        "image/format": tf.FixedLenFeature((), tf.string, default_value="jpeg"),
        "image/object/object_number": tf.FixedLenFeature((), tf.int64, default_value=1),
        "image/object/bbox/xmin": tf.VarLenFeature(tf.float32),
        "image/object/bbox/xmax": tf.VarLenFeature(tf.float32),
        "image/object/bbox/ymin": tf.VarLenFeature(tf.float32),
        "image/object/bbox/ymax": tf.VarLenFeature(tf.float32),
        "image/object/class/text": tf.VarLenFeature(tf.string),
        "image/object/class/label": tf.VarLenFeature(tf.int64),
        "image/object/difficult": tf.VarLenFeature(tf.int64),
        "image/object/truncated": tf.VarLenFeature(tf.int64),
        "image/object/view": tf.VarLenFeature(tf.string),
      }

items_to_handlers = {
    'image': slim_example_decoder.Image(
        image_key='image/encoded', format_key='image/format', channels=3),
    'height': (
        slim_example_decoder.Tensor('image/height')),
    'width': (
        slim_example_decoder.Tensor('image/width')),
    'source_id': (
        slim_example_decoder.Tensor('image/source_id')),
    'key': (
        slim_example_decoder.Tensor('image/key/sha256')),
    'filename': (
        slim_example_decoder.Tensor('image/filename')),
    # Object boxes and classes.
    'groundtruth_boxes': (
        slim_example_decoder.BoundingBox(
            ['ymin', 'xmin', 'ymax', 'xmax'], 'image/object/bbox/')),
    'groundtruth_classes': (
        slim_example_decoder.Tensor('image/object/class/label')),
    'groundtruth_difficult': (
        slim_example_decoder.Tensor('image/object/difficult')),
    'image/object/truncated': (
        slim_example_decoder.Tensor('image/object/truncated')),
    }

decoder = slim_example_decoder.TFExampleDecoder(features, items_to_handlers)
keys = decoder.list_items()


def _parse_function_train(example):
    serialized_example = tf.reshape(example, shape=[])
    decoder = slim_example_decoder.TFExampleDecoder(features, items_to_handlers)
    keys = decoder.list_items()
    tensors = decoder.decode(serialized_example, items=keys)
    tensor_dict = dict(zip(keys, tensors))
    tensor_dict['image'].set_shape([None, None, 3])
    # tensor_dict['image'] = tf.expand_dims(tensor_dict['image'], 0)
    images = tensor_dict['image']
    float_images = tf.cast(images, tf.uint8)
    tensor_dict['image'] = float_images
    return tensor_dict


def build_pipleline(train_data_dir, test_data_dir, batch_size, capacity):
    train_dataset = tf.contrib.data.TFRecordDataset(train_data_dir)
    train_dataset = train_dataset.map(_parse_function_train)
    train_dataset = train_dataset.repeat(1)
    train_dataset = train_dataset.batch(batch_size)
    train_dataset = train_dataset.shuffle(buffer_size=capacity)

    iterator = Iterator.from_structure(train_dataset.output_types,
                                   train_dataset.output_shapes)
    next_element = iterator.get_next()
    training_init_op = iterator.make_initializer(train_dataset)

    return training_init_op, next_element 


if __name__ == '__main__':
    # TODO: only support batch size 1
    training_init_op, next_element = build_pipleline(FLAGS.data_dir, None, FLAGS.batch_size, FLAGS.capacity)
    sess = tf.Session()
    sess.run(training_init_op)
    counter = 0
    while True:
        try:
            next_element_val = sess.run(next_element)
            print(next_element_val['image'].shape, next_element_val['filename'])
            print(next_element_val['groundtruth_boxes'])
            print('-'*30)
            counter += 1
        except tf.errors.OutOfRangeError:
            print('End of training data in step %d' %counter)
            break
当批大小设置为1时,代码可以正确运行,当我将批大小更改为大于1时,代码将出现错误。 错误如下所示:

/usr/software/anaconda3/bin/python3.6 /home/aurora/workspaces/PycharmProjects/object_detection_models/builder/voc_input_pipline_dataset_builder.py
2017-10-11 15:55:05.886856: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE4.1 instructions, but these are available on your machine and could speed up CPU computations.
2017-10-11 15:55:05.886869: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE4.2 instructions, but these are available on your machine and could speed up CPU computations.
2017-10-11 15:55:05.886872: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX instructions, but these are available on your machine and could speed up CPU computations.
2017-10-11 15:55:05.886874: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX2 instructions, but these are available on your machine and could speed up CPU computations.
2017-10-11 15:55:05.886876: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use FMA instructions, but these are available on your machine and could speed up CPU computations.
2017-10-11 15:55:05.974850: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:893] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2017-10-11 15:55:05.975103: I tensorflow/core/common_runtime/gpu/gpu_device.cc:955] Found device 0 with properties: 
name: GeForce GTX 1080 Ti
major: 6 minor: 1 memoryClockRate (GHz) 1.683
pciBusID 0000:01:00.0
Total memory: 10.90GiB
Free memory: 10.46GiB
2017-10-11 15:55:05.975112: I tensorflow/core/common_runtime/gpu/gpu_device.cc:976] DMA: 0 
2017-10-11 15:55:05.975114: I tensorflow/core/common_runtime/gpu/gpu_device.cc:986] 0:   Y 
2017-10-11 15:55:05.975118: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1045] Creating TensorFlow device (/gpu:0) -> (device: 0,       name: GeForce GTX 1080 Ti, pci bus id: 0000:01:00.0)
2017-10-11 15:55:06.027798: W tensorflow/core/framework/op_kernel.cc:1192] Internal: HandleElementToSlice Cannot copy slice: number of elements does not match.  Shapes are: [element]: [1,4], [parent slice]: [5,4]
Traceback (most recent call last):
  File "/usr/software/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1327, in _do_call
return fn(*args)
  File "/usr/software/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1306, in _run_fn
status, run_metadata)
  File "/usr/software/anaconda3/lib/python3.6/contextlib.py", line 89, in __exit__
next(self.gen)
  File "/usr/software/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/errors_impl.py", line 466, in raise_exception_on_not_ok_status
pywrap_tensorflow.TF_GetCode(status))
 tensorflow.python.framework.errors_impl.InternalError: HandleElementToSlice Cannot copy slice: number of elements does not match.  Shapes are: [element]: [1,4], [parent slice]: [5,4]
 [[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[?], [?,?,4], [?,?], [?,?], [?], [?,?,?,3], [?,?], [?], [?], [?]], output_types=[DT_STRING, DT_FLOAT, DT_INT64, DT_INT64, DT_INT64, DT_UINT8, DT_INT64, DT_STRING, DT_STRING, DT_INT64], _device="/job:localhost/replica:0/task:0/cpu:0"](Iterator)]]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/aurora/workspaces/PycharmProjects/object_detection_models/builder/voc_input_pipline_dataset_builder.py", line 98, in <module>
next_element_val = sess.run(next_element)
  File "/usr/software/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 895, in run
run_metadata_ptr)
  File "/usr/software/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1124, in _run
feed_dict_tensor, options, run_metadata)
  File "/usr/software/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1321, in _do_run
options, run_metadata)
  File "/usr/software/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1340, in _do_call
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InternalError: HandleElementToSlice Cannot copy slice: number of elements does not match.  Shapes are: [element]: [1,4], [parent slice]: [5,4]
 [[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[?], [?,?,4], [?,?], [?,?], [?], [?,?,?,3], [?,?], [?], [?], [?]], output_types=[DT_STRING, DT_FLOAT, DT_INT64, DT_INT64, DT_INT64, DT_UINT8, DT_INT64, DT_STRING, DT_STRING, DT_INT64], _device="/job:localhost/replica:0/task:0/cpu:0"](Iterator)]]

Caused by op 'IteratorGetNext', defined at:
  File "/home/aurora/workspaces/PycharmProjects/object_detection_models/builder/voc_input_pipline_dataset_builder.py", line 92, in <module>
training_init_op, next_element = build_pipleline(FLAGS.data_dir, None, FLAGS.batch_size, FLAGS.capacity)
  File "/home/aurora/workspaces/PycharmProjects/object_detection_models/builder/voc_input_pipline_dataset_builder.py", line 84, in build_pipleline
next_element = iterator.get_next()
  File "/usr/software/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/data/python/ops/dataset_ops.py", line 304, in get_next
name=name))
  File "/usr/software/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/gen_dataset_ops.py", line 379, in iterator_get_next
output_shapes=output_shapes, name=name)
  File "/usr/software/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 767, in apply_op
op_def=op_def)
  File "/usr/software/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 2630, in create_op
original_op=self._default_original_op, op_def=op_def)
  File "/usr/software/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1204, in __init__
self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

InternalError (see above for traceback): HandleElementToSlice Cannot copy slice: number of elements does not match.  Shapes are: [element]: [1,4], [parent slice]: [5,4]
 [[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[?], [?,?,4], [?,?], [?,?], [?], [?,?,?,3], [?,?], [?], [?], [?]], output_types=[DT_STRING, DT_FLOAT, DT_INT64, DT_INT64, DT_INT64, DT_UINT8, DT_INT64, DT_STRING, DT_STRING, DT_INT64], _device="/job:localhost/replica:0/task:0/cpu:0"](Iterator)]]
/usr/software/anaconda3/bin/python3.6/home/aurora/workspace/pycharm项目/object\u detection\u models/builder/voc\u input\u pipline\u dataset\u builder.py
2017-10-11 15:55:05.886856:W tensorflow/core/platform/cpu_feature_guard.cc:45]tensorflow库的编译不是为了使用SSE4.1指令,但这些指令在您的机器上可用,可以加快cpu计算。
2017-10-11 15:55:05.886869:W tensorflow/core/platform/cpu_feature_guard.cc:45]tensorflow库的编译不是为了使用SSE4.2指令,但这些指令在您的机器上可用,可以加快cpu计算。
2017-10-11 15:55:05.886872:W tensorflow/core/platform/cpu_feature_guard.cc:45]tensorflow库的编译不是为了使用AVX指令,但这些指令在您的机器上可用,可以加快cpu计算。
2017-10-11 15:55:05.886874:W tensorflow/core/platform/cpu_feature_guard.cc:45]tensorflow库的编译不是为了使用AVX2指令,但这些指令在您的机器上可用,可以加快cpu计算。
2017-10-11 15:55:05.886876:W tensorflow/core/platform/cpu_feature_guard.cc:45]tensorflow库的编译不是为了使用FMA指令,但这些指令在您的机器上可用,可以加快cpu计算。
2017-10-11 15:55:05.974850:I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:893]从SysFS读取的成功NUMA节点的值为负值(-1),但必须至少有一个NUMA节点,因此返回NUMA节点零
2017-10-11 15:55:05.975103:I tensorflow/core/common_runtime/gpu/gpu_device.cc:955]找到了具有以下属性的设备0:
名称:GeForce GTX 1080 Ti
大调:6小调:1记忆时钟频率(GHz)1.683
pciBusID 0000:01:00.0
总内存:10.90GiB
可用内存:10.46GiB
2017-10-11 15:55:05.975112:I tensorflow/core/common_runtime/gpu/gpu_device.cc:976]DMA:0
2017-10-11 15:55:05.975114:I tensorflow/core/common_runtime/gpu/gpu_device.cc:986]0:Y
2017-10-11 15:55:05.975118:I tensorflow/core/common_runtime/gpu/gpu_device.cc:1045]创建tensorflow设备(/gpu:0)->(设备:0,名称:GeForce GTX 1080 Ti,pci总线id:0000:01:00.0)
2017-10-11 15:55:06.027798:W tensorflow/core/framework/op_kernel.cc:1192]内部:HandleElementToSlice无法复制切片:元素数不匹配。形状是:[元素]:[1,4],[父切片]:[5,4]
回溯(最近一次呼叫最后一次):
文件“/usr/software/anaconda3/lib/python3.6/site packages/tensorflow/python/client/session.py”,第1327行,在
返回fn(*args)
文件“/usr/software/anaconda3/lib/python3.6/site packages/tensorflow/python/client/session.py”,第1306行,在
状态,运行(元数据)
文件“/usr/software/anaconda3/lib/python3.6/contextlib.py”,第89行,在__
下一个(self.gen)
文件“/usr/software/anaconda3/lib/python3.6/site packages/tensorflow/python/framework/errors\u impl.py”,第466行,处于raise\u exception\u on\u not\u ok\u状态
pywrap_tensorflow.TF_GetCode(状态))
tensorflow.python.framework.errors\u impl.InternalError:HandleElementToSlice无法复制切片:元素数不匹配。形状是:[元素]:[1,4],[父切片]:[5,4]
[[Node:IteratorGetNext=IteratorGetNext[output_shapes=[[?]、[?、?、4]、[?]、[?]、[?]、[?]、[?]、[?]、[?]、[?]、[?]、输出_类型=[DT_字符串、DT_浮点、DT_INT64、DT_INT64、DT_UINT8、DT_INT64、DT_字符串、DT_字符串、DT_INT64]、DT_设备=“/job:localhost/0/task:0:Iterator”]
在处理上述异常期间,发生了另一个异常:
回溯(最近一次呼叫最后一次):
文件“/home/aurora/workspace/PycharmProjects/object_detection_models/builder/voc_input_pipline_dataset_builder.py”,第98行,在
下一个元素\u val=sess.run(下一个元素)
文件“/usr/software/anaconda3/lib/python3.6/site packages/tensorflow/python/client/session.py”,第895行,正在运行
运行_元数据_ptr)
文件“/usr/software/anaconda3/lib/python3.6/site packages/tensorflow/python/client/session.py”,第1124行,正在运行
feed_dict_tensor、options、run_元数据)
文件“/usr/software/anaconda3/lib/python3.6/site packages/tensorflow/python/client/session.py”,第1321行,运行
选项,运行(元数据)
文件“/usr/software/anaconda3/lib/python3.6/site packages/tensorflow/python/client/session.py”,第1340行,在
提升类型(e)(节点定义、操作、消息)
tensorflow.python.framework.errors\u impl.InternalError:HandleElementToSlice无法复制切片:元素数不匹配。形状是:[元素]:[1,4],[父切片]:[5,4]
[[Node:IteratorGetNext=IteratorGetNext[output_shapes=[[?]、[?、?、4]、[?]、[?]、[?]、[?]、[?]、[?]、[?]、[?]、[?]、输出_类型=[DT_字符串、DT_浮点、DT_INT64、DT_INT64、DT_UINT8、DT_INT64、DT_字符串、DT_字符串、DT_INT64]、DT_设备=“/job:localhost/0/task:0:Iterator”]
由op“IteratorGetNext”引起,定义于:
文件“/home/aurora/workspace/PycharmProjects/object_detection_models/builder/voc_input_pipline_dataset_builder.py”,第92行,在
training_init_op,next_element=build_pipleline(FLAGS.data_dir,None,FLAGS.batch_size,FLAGS.capacity)
文件“/home/aurora/workspace/PycharmProjects/object\u detection\u models/builder/voc\u input\u pipline\u dataset\u builder.py”,第84行,in build\u piplline
下一个元素=迭代器。获取下一个元素()
文件“/usr/software/anaconda3/lib/python3.6/site packages/tensorflow/contrib/data/python/ops/dataset_ops.py”,第304行,在get_next中
名称=名称)
迭代器中的文件“/usr/software/anaconda3/lib/python3.6/site packages/tensorflow/python/ops/gen_dataset_ops.py”,第379行
输出形状=输出形状,名称=名称)
文件“/usr/software/anaconda3/lib/python3.6/site-packages/tensorflow/python/frame
dataset = tf.data.Dataset.range(100)
dataset = dataset.map(lambda x: tf.fill([tf.cast(x, tf.int32)], x))
dataset = dataset.padded_batch(4, padded_shapes=[None])

iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

print(sess.run(next_element))  # ==> [[0, 0, 0], [1, 0, 0], [2, 2, 0], [3, 3, 3]]
print(sess.run(next_element))  # ==> [[4, 4, 4, 4, 0, 0, 0],
                               #      [5, 5, 5, 5, 5, 0, 0],
                               #      [6, 6, 6, 6, 6, 6, 0],
                               #      [7, 7, 7, 7, 7, 7, 7]]