Warning: file_get_contents(/data/phpspider/zhask/data//catemap/2/tensorflow/5.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
tensorflow数据集中图像元组的展平_Tensorflow_Tensorflow Datasets - Fatal编程技术网

tensorflow数据集中图像元组的展平

tensorflow数据集中图像元组的展平,tensorflow,tensorflow-datasets,Tensorflow,Tensorflow Datasets,我有一个从tfrecords读取的三重图像数据集,我使用以下代码将其转换为数据集 def parse_dataset(record): def convert_raw_to_image_tensor(raw): raw = tf.io.decode_base64(raw) image_shape = tf.stack([299, 299, 3]) decoded = tf.io.decode_imag

我有一个从tfrecords读取的三重图像数据集,我使用以下代码将其转换为数据集

    def parse_dataset(record):
        def convert_raw_to_image_tensor(raw):
            raw = tf.io.decode_base64(raw)
            image_shape = tf.stack([299, 299, 3])
            decoded = tf.io.decode_image(raw, channels=3, 
                                dtype=tf.uint8, expand_animations=False)
            decoded = tf.cast(decoded, tf.float32)
            decoded = tf.reshape(decoded, image_shape)
            decoded = tf.math.divide(decoded, 255.)
            return decoded

        features = {
            'n': tf.io.FixedLenFeature([], tf.string),
            'p': tf.io.FixedLenFeature([], tf.string),
            'q': tf.io.FixedLenFeature([], tf.string)
        }
        sample = tf.io.parse_single_example(record, features)
        neg_image = sample['n']
        pos_image = sample['p']
        query_image = sample['q']

        neg_decoded = convert_raw_to_image_tensor(neg_image)
        pos_decoded = convert_raw_to_image_tensor(pos_image)
        query_decoded = convert_raw_to_image_tensor(query_image)
        return (neg_decoded, pos_decoded, query_decoded)

    record_dataset = tf.data.TFRecordDataset(filenames=path_dataset, num_parallel_reads=4)
    record_dataset = record_dataset.map(parse_dataset)
此结果数据集的形状为

<MapDataset shapes: ((299, 299, 3), (299, 299, 3), (299, 299, 3)), types: (tf.float32, tf.float32, tf.float32)>
结果数据集具有这种形状

<FlatMapDataset shapes: ((299, 3), (299, 3), (299, 3)), types: (tf.float32, tf.float32, tf.float32)>

我认为您只是错误地解包了元组

这应该做到:

def展平(*x):
从_张量_切片([i代表x中的i])返回tf.data.Dataset
展平=记录数据集。展平地图(展平)
以便:

展平中的i的
:
印刷品(i.形状)
给出:

(299, 299, 3)
(299, 299, 3)
(299, 299, 3)
(299, 299, 3)
...
果然

(299, 299, 3)
(299, 299, 3)
(299, 299, 3)
(299, 299, 3)
...