Python 使用tensorflow数据集洗牌输入文件

Python 使用tensorflow数据集洗牌输入文件,python,tensorflow,dataset,Python,Tensorflow,Dataset,使用旧的输入管道API,我可以做到: filename_queue = tf.train.string_input_producer(filenames, shuffle=True) 然后将文件名传递给其他队列,例如: reader = tf.TFRecordReader() _, serialized_example = reader.read_up_to(filename_queue, n) 如何使用Dataset-API实现类似的行为 tf.data.TFRecordDataset()

使用旧的输入管道API,我可以做到:

filename_queue = tf.train.string_input_producer(filenames, shuffle=True)
然后将文件名传递给其他队列,例如:

reader = tf.TFRecordReader()
_, serialized_example = reader.read_up_to(filename_queue, n)
如何使用Dataset-API实现类似的行为


tf.data.TFRecordDataset()
要求文件名的张量按固定顺序排列。

开始按顺序读取它们,就在以下位置之后:

BUFFER_SIZE = 1000 # arbitrary number
# define filenames somewhere, e.g. via glob
dataset = tf.data.TFRecordDataset(filenames).shuffle(BUFFER_SIZE)
编辑: 的输入管道让我了解了如何使用Dataset API实现文件名洗牌:

dataset = tf.data.Dataset.from_tensor_slices(filenames)
dataset = dataset.shuffle(BUFFER_SIZE) # doesn't need to be big
dataset = dataset.flat_map(tf.data.TFRecordDataset)
dataset = dataset.map(decode_example, num_parallel_calls=5) # add your decoding logic here
# further processing of the dataset
这将把一个文件的所有数据放在下一个文件的数据之前,依此类推。文件被洗牌,但其中的数据将以相同的顺序生成。 您也可以将
dataset.flat\u map
替换为,以同时处理多个文件并从每个文件返回样本:

dataset = dataset.interleave(tf.data.TFRecordDataset, cycle_length=4)

注意:
交织
实际上并不在多个线程中运行,它是一种循环操作。有关真正的并行处理,请参见当前的Tensorflow版本(2018年2月的1.5版)似乎不支持Dataset API中的文件名本机洗牌。下面是一个使用numpy的简单方法:

import numpy as np
import tensorflow as tf

myShuffledFileList = np.random.choice(myInputFileList, size=len(myInputFileList), replace=False).tolist()

dataset = tf.data.TFRecordDataset(myShuffledFileList)

这会洗牌我的文件或文件中的数据吗?好的,但是当您有一长串TFRecord文件(总共有50000多个示例)包含同一标签(用于深度学习),然后是另一系列包含另一标签示例的文件时,您会怎么做。为了使洗牌工作,您需要一个大于50000的缓冲区,因此需要大量的RAM。这不是一个解决办法。洗牌文件名是一个简单得多的解决方案。我不是建议你把所有东西都打包在一个大文件中,你的用例对我来说似乎非常合理。我要指出的问题是,如果只洗牌文件名,那么每个文件中的数据仍将以相同的顺序读取。我同意洗牌也没什么坏处,但在解码样本后,你仍然需要一个带缓冲区的
shuffle()
,除非你同意让它们总是以相同的顺序排列。@Pekka我认为编辑可能是你想要的,这很有帮助,请记住这段代码是为TF 1.4编写的(我认为,或者接近于此),数据集API从那以后有了巨大的发展,所以今天有些事情可以以更有效的方式实现:)动态加载文件列表:
tf.data.dataset.list\u文件('pattern-here').shuffle(BUFFER\u SIZE)
。硬编码:
tf.data.Dataset.from\u tensor\u slices([filename]).shuffle(BUFFER\u SIZE)
。两者之后必须有一个适当的
.map
,该映射具有打开并读取文件中记录的解码功能。再说一次,在当前的API中,这怎么不可能呢?另外,如果您真的想使用
numpy
np.random.shuffle(myInputFileList)
。也可以从
tf.Data
的开发人员那里了解一下。