Warning: file_get_contents(/data/phpspider/zhask/data//catemap/2/tensorflow/5.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
Tensorflow 使用[image,label]=dataset.take(2)返回两个元组,而不是一个元组_Tensorflow_Tensorflow2.0_Tensorflow Datasets - Fatal编程技术网

Tensorflow 使用[image,label]=dataset.take(2)返回两个元组,而不是一个元组

Tensorflow 使用[image,label]=dataset.take(2)返回两个元组,而不是一个元组,tensorflow,tensorflow2.0,tensorflow-datasets,Tensorflow,Tensorflow2.0,Tensorflow Datasets,我有一个TFRecord文件,其中存储了图像,将字节包装为字符串,标签为ints64。我正在使用下面的代码来操作图像和标签: #从TFRecord文件创建数据集 记录\u路径=数据\u目录+'TFRecords/train\u 0.TFRecords' dataset=tf.data.TFRecordDataset(文件名=记录\路径) #从解析函数映射数据集 parsed_dataset=dataset.map(parsed_fn) 打印(已解析的_数据集) #取一个测试样本 图像,标签=已解

我有一个TFRecord文件,其中存储了图像,将字节包装为字符串标签为ints64。我正在使用下面的代码来操作图像和标签:

#从TFRecord文件创建数据集
记录\u路径=数据\u目录+'TFRecords/train\u 0.TFRecords'
dataset=tf.data.TFRecordDataset(文件名=记录\路径)
#从解析函数映射数据集
parsed_dataset=dataset.map(parsed_fn)
打印(已解析的_数据集)
#取一个测试样本
图像,标签=已解析的_数据集。获取(2)
打印(图像、标签)
哪些产出:

Tensor("ParseSingleExample/ParseSingleExample:1", shape=(), dtype=int64)
<MapDataset shapes: ((None,), ()), types: (tf.float32, tf.int64)>

((<tf.Tensor: id=635, shape=(185256,), dtype=float32, numpy=array([162., 162., 170., ...,  17.,  17., 255.], dtype=float32)>,
  <tf.Tensor: id=636, shape=(), dtype=int64, numpy=183350>),
 (<tf.Tensor: id=637, shape=(153120,), dtype=float32, numpy=array([208., 207., 202., ..., 240., 240., 242.], dtype=float32)>,
  <tf.Tensor: id=638, shape=(), dtype=int64, numpy=183350>))

该方法用于创建数据集。它不会从数据集中提取元素

如果要从数据中提取元素,可以使用:
tf.compat.v1.data.make\u one\u shot\u iterator()
我没有找到一种更干净的方法从数据集中提取元素

例如:

iterator = tf.compat.v1.data.make_one_shot_iterator(parsed_dataset)
image, label = iterator.get_next()