Python TensorFlow:为什么tf.Dataset.map()只处理我的数据集中的第一个示例?

Python TensorFlow:为什么tf.Dataset.map()只处理我的数据集中的第一个示例?,python,tensorflow,dataset,map-function,Python,Tensorflow,Dataset,Map Function,我在TensorFlow 1.12中使用了tf.Dataset.map(): dataset_train = dataset_train.map(lambda x: parse_example(x, width, height, NUM_CLASSES)) dataset_train包含592个示例,但这一行只处理了其中一个示例——我在parse_example()中递增的全局计数器证明了这一点。为什么它不处理数据集中的所有示例?我启用了急切执行(eager execution),但.map()中的代码并没有以急切模式执行。任何想法都非常感谢。 ---

我在TensorFlow 1.12中使用了
tf.Dataset.map()

dataset_train = dataset_train.map(lambda x: parse_example(x, width, height, NUM_CLASSES))

dataset_train
包含592个示例,但这一行只处理其中一个示例,全局计数器证明了这一点,我在
parse_example()
中递增了全局计数器。为什么它不处理数据集中的所有示例?我启用了急切执行(eager execution),但是
.map()
中的代码并没有被急切地执行。任何想法都非常感谢

---------------------------------------------------------------------------------------

作为参考,我的主函数如下所示:

tf.enable_eager_execution()

# Global counter incremented inside parse_example(). It must be called `v`:
# parse_example() declares `global v` and the final print reads `v`, so
# initialising a counter named `i` (as before) left `v` undefined (NameError).
v = 0

tfrecord_train = "/media/nfs/7_raid/ebos/dataset/material_segmentation_train.record"
dataset_train = tf.data.TFRecordDataset(tfrecord_train)

# Read image width/height from the first record of the TFRecord file.
# In eager mode next_element.numpy() is already the serialized bytes, so it
# can be parsed directly without the deprecated np.fromstring round-trip.
# NOTE(review): assumes every record has the same width/height — confirm.
iterator = dataset_train.make_one_shot_iterator()
next_element = iterator.get_next()
example = tf.train.Example.FromString(next_element.numpy())
height = example.features.feature['image/height'].int64_list.value[0]
width = example.features.feature['image/width'].int64_list.value[0]

# Dataset.map() traces the lambda (and therefore parse_example) once, right
# here, to build a graph; that graph is what later runs for every element.
# Python side effects inside parse_example (the `v` increment, print) thus
# happen once at trace time, which is why `v` is 1 regardless of the number
# of records. map() is also lazy: no record is parsed until the dataset is
# iterated. parse_example must already be defined when this line executes.
dataset_train = dataset_train.map(lambda x: parse_example(x, width, height, NUM_CLASSES))
print(v)
def parse_example(example_proto, width, height, num_classes):
    """Parse one serialized tf.train.Example into a dict of decoded tensors.

    Intended for use inside ``tf.data.Dataset.map``. ``map`` traces this
    Python function once into a graph and then runs that graph for every
    element, so the plain-Python statements here (the ``global v`` increment
    and the ``print``) execute only once, at trace time, not once per example.
    Per-element side effects need a TensorFlow op (e.g. ``tf.py_func``).

    Args:
        example_proto: scalar string tensor holding one serialized Example.
        width: image width used to reshape the RGBD image and the masks.
        height: image height used to reshape the RGBD image and the masks.
        num_classes: one-hot depth passed to masks_to_onehots_tf.

    Returns:
        The dict from ``tf.parse_single_example`` with 'image/encoded'
        replaced by the decoded JPEG, 'image/depth' replaced by the decoded
        depth, and two added keys: 'image/rgbd' (reshaped to
        [height, width, 4], float32) and 'image/labels'.
    """
    features = {
        'image/encoded': tf.FixedLenFeature((), tf.string),
        'image/height': tf.FixedLenFeature((), tf.int64),
        'image/width': tf.FixedLenFeature((), tf.int64),
        'image/filename': tf.FixedLenFeature((), tf.string),
        'image/object/bbox/xmin': tf.VarLenFeature(tf.float32),
        'image/object/bbox/xmax': tf.VarLenFeature(tf.float32),
        'image/object/bbox/ymin': tf.VarLenFeature(tf.float32),
        'image/object/bbox/ymax': tf.VarLenFeature(tf.float32),
        'image/object/class/label': tf.VarLenFeature(tf.int64),
        'image/object/class/text': tf.VarLenFeature(tf.string),
        'image/object/mask': tf.VarLenFeature(tf.string),
        'image/depth': tf.FixedLenFeature((), tf.string)
    }

    # Trace-time only: under Dataset.map() this runs once, when the graph is
    # built, so it counts traces, not processed examples.
    global v
    v += 1

    parsed_example = tf.parse_single_example(example_proto, features)
    # Tensors here are symbolic graph tensors (map() does not run eagerly), so
    # .numpy() is not available on e.g. 'image/filename'.

    # Decode image
    image = tf.image.decode_jpeg(parsed_example['image/encoded'])
    parsed_example['image/encoded'] = image

    # Depth + RGBD: RGB converted to float32 stacked with the depth channel.
    depth = utilities.decode_depth(parsed_example['image/depth'])
    parsed_example['image/depth'] = depth
    rgbd = tf.concat([tf.image.convert_image_dtype(image, tf.float32), depth], axis=2)
    rgbd = tf.reshape(rgbd, shape=tf.stack([height, width, 4]))
    parsed_example['image/rgbd'] = rgbd

    tag_masks = tf.sparse.to_dense(parsed_example['image/object/mask'], default_value="")
    tag_masks = tf.map_fn(utilities.decode_png_mask, tag_masks, dtype=tf.uint8)
    tag_masks = tf.reshape(tag_masks, shape=tf.stack([-1, height, width]), name='tag_masks')

    # Every segment now has its mask in tag_masks and its label-map index in
    # tag_class_indices (same order).
    tag_class_indices = tf.sparse.to_dense(parsed_example['image/object/class/label'])
    onehots = masks_to_onehots_tf(tag_masks, tag_class_indices, num_classes)
    parsed_example['image/labels'] = onehots
    # Trace-time only: prints the static shape once, not per example.
    print(parsed_example['image/labels'].shape)

    return parsed_example
.map()
中调用的函数如下所示:

tf.enable_eager_execution()

# Global counter incremented inside parse_example(). It must be called `v`:
# parse_example() declares `global v` and the final print reads `v`, so
# initialising a counter named `i` (as before) left `v` undefined (NameError).
v = 0

tfrecord_train = "/media/nfs/7_raid/ebos/dataset/material_segmentation_train.record"
dataset_train = tf.data.TFRecordDataset(tfrecord_train)

# Read image width/height from the first record of the TFRecord file.
# In eager mode next_element.numpy() is already the serialized bytes, so it
# can be parsed directly without the deprecated np.fromstring round-trip.
# NOTE(review): assumes every record has the same width/height — confirm.
iterator = dataset_train.make_one_shot_iterator()
next_element = iterator.get_next()
example = tf.train.Example.FromString(next_element.numpy())
height = example.features.feature['image/height'].int64_list.value[0]
width = example.features.feature['image/width'].int64_list.value[0]

# Dataset.map() traces the lambda (and therefore parse_example) once, right
# here, to build a graph; that graph is what later runs for every element.
# Python side effects inside parse_example (the `v` increment, print) thus
# happen once at trace time, which is why `v` is 1 regardless of the number
# of records. map() is also lazy: no record is parsed until the dataset is
# iterated. parse_example must already be defined when this line executes.
dataset_train = dataset_train.map(lambda x: parse_example(x, width, height, NUM_CLASSES))
print(v)
def parse_example(example_proto, width, height, num_classes):
    """Parse one serialized tf.train.Example into a dict of decoded tensors.

    Intended for use inside ``tf.data.Dataset.map``. ``map`` traces this
    Python function once into a graph and then runs that graph for every
    element, so the plain-Python statements here (the ``global v`` increment
    and the ``print``) execute only once, at trace time, not once per example.
    Per-element side effects need a TensorFlow op (e.g. ``tf.py_func``).

    Args:
        example_proto: scalar string tensor holding one serialized Example.
        width: image width used to reshape the RGBD image and the masks.
        height: image height used to reshape the RGBD image and the masks.
        num_classes: one-hot depth passed to masks_to_onehots_tf.

    Returns:
        The dict from ``tf.parse_single_example`` with 'image/encoded'
        replaced by the decoded JPEG, 'image/depth' replaced by the decoded
        depth, and two added keys: 'image/rgbd' (reshaped to
        [height, width, 4], float32) and 'image/labels'.
    """
    features = {
        'image/encoded': tf.FixedLenFeature((), tf.string),
        'image/height': tf.FixedLenFeature((), tf.int64),
        'image/width': tf.FixedLenFeature((), tf.int64),
        'image/filename': tf.FixedLenFeature((), tf.string),
        'image/object/bbox/xmin': tf.VarLenFeature(tf.float32),
        'image/object/bbox/xmax': tf.VarLenFeature(tf.float32),
        'image/object/bbox/ymin': tf.VarLenFeature(tf.float32),
        'image/object/bbox/ymax': tf.VarLenFeature(tf.float32),
        'image/object/class/label': tf.VarLenFeature(tf.int64),
        'image/object/class/text': tf.VarLenFeature(tf.string),
        'image/object/mask': tf.VarLenFeature(tf.string),
        'image/depth': tf.FixedLenFeature((), tf.string)
    }

    # Trace-time only: under Dataset.map() this runs once, when the graph is
    # built, so it counts traces, not processed examples.
    global v
    v += 1

    parsed_example = tf.parse_single_example(example_proto, features)
    # Tensors here are symbolic graph tensors (map() does not run eagerly), so
    # .numpy() is not available on e.g. 'image/filename'.

    # Decode image
    image = tf.image.decode_jpeg(parsed_example['image/encoded'])
    parsed_example['image/encoded'] = image

    # Depth + RGBD: RGB converted to float32 stacked with the depth channel.
    depth = utilities.decode_depth(parsed_example['image/depth'])
    parsed_example['image/depth'] = depth
    rgbd = tf.concat([tf.image.convert_image_dtype(image, tf.float32), depth], axis=2)
    rgbd = tf.reshape(rgbd, shape=tf.stack([height, width, 4]))
    parsed_example['image/rgbd'] = rgbd

    tag_masks = tf.sparse.to_dense(parsed_example['image/object/mask'], default_value="")
    tag_masks = tf.map_fn(utilities.decode_png_mask, tag_masks, dtype=tf.uint8)
    tag_masks = tf.reshape(tag_masks, shape=tf.stack([-1, height, width]), name='tag_masks')

    # Every segment now has its mask in tag_masks and its label-map index in
    # tag_class_indices (same order).
    tag_class_indices = tf.sparse.to_dense(parsed_example['image/object/class/label'])
    onehots = masks_to_onehots_tf(tag_masks, tag_class_indices, num_classes)
    parsed_example['image/labels'] = onehots
    # Trace-time only: prints the static shape once, not per example.
    print(parsed_example['image/labels'].shape)

    return parsed_example
最后,
masks_to_onehots_tf()
如下所示:

def masks_to_onehots_tf(tag_masks, tag_class_indices, num_classes):
    """Build a per-pixel one-hot label image from a stack of instance masks.

    For every pixel, the non-empty masks that are non-zero at that pixel are
    collected and the smallest of them (by summed mask value) is selected; the
    pixel gets ``tf.one_hot`` of that mask's index. Pixels covered by no
    non-empty mask get an all-zero vector (one_hot of -1).
    NOTE(review): the all-zero "background" choice is an assumption — confirm.

    Args:
        tag_masks: mask stack indexed [mask, row, col]; presumably uint8 with
            non-negative values (the caller decodes it with dtype=tf.uint8).
        tag_class_indices: per-mask class indices. NOTE(review): unused, as in
            the original, so the one-hot encodes the *mask* index rather than
            the class; presumably it should be
            ``tf.gather(tag_class_indices, index)`` — confirm whether the label
            map is 0- or 1-based before changing.
        num_classes: depth of the one-hot vectors.

    Returns:
        uint8 tensor of shape [rows, cols, num_classes].
    """
    # Sum in int64: reduce_sum keeps the input dtype, so summing uint8 masks
    # wraps around mod 256 and corrupts both the size ordering and the
    # "empty mask" test (a 256-pixel mask would sum to 0).
    tag_mask_sizes = tf.reduce_sum(tf.cast(tag_masks, tf.int64), axis=[1, 2])

    def onehotify(pixel_tag_masks):
        # Candidate masks: non-empty masks that actually cover this pixel.
        covering = tf.logical_and(tf.not_equal(pixel_tag_masks, 0),
                                  tf.not_equal(tag_mask_sizes, 0))
        candidates = tf.where(covering)[:, 0]
        candidate_sizes = tf.gather(tag_mask_sizes, candidates)
        # argmin over an empty tensor fails, so uncovered pixels map to -1,
        # which tf.one_hot turns into an all-zero vector.
        smallest_mask_index = tf.cond(
            tf.size(candidates) > 0,
            lambda: candidates[tf.argmin(candidate_sizes)],
            lambda: tf.constant(-1, dtype=tf.int64))
        return tf.one_hot(smallest_mask_index, depth=num_classes, dtype=tf.uint8)

    # [mask, row, col] -> [row, col, mask] so each map_fn step sees one pixel's
    # vector of mask values.
    image_masks = tf.transpose(tag_masks, perm=[1, 2, 0])
    return tf.map_fn(
        lambda row: tf.map_fn(onehotify, row, dtype=tf.uint8),
        image_masks,
        dtype=tf.uint8)

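也许问题在于你期望映射函数像普通 Python 代码那样对每个元素执行一次——但它是用来操作张量的:`Dataset.map()` 只调用一次你的 Python 函数来构建计算图(即"跟踪"),之后对每个元素运行的是这个图。因此,像递增全局计数器或 `print` 这样的 Python 副作用只会在跟踪时发生一次。此外,`map()` 是惰性的,在遍历数据集之前不会处理任何元素。如果需要逐元素的副作用,请使用 TensorFlow 操作(例如 `tf.py_func`)。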

也许你不能这样使用映射函数——它是用来操作张量的。