Tensorflow tf.data.Dataset中的队列容量

Tensorflow tf.data.Dataset中的队列容量,tensorflow,Tensorflow,我对Tensorflow的新输入管道机制有问题。当我使用tf.data.Dataset创建数据管道时,它会解码jpeg图像,然后将它们加载到队列中,它会尝试将尽可能多的图像加载到队列中。如果加载图像的吞吐量大于我的模型处理的图像的吞吐量,那么内存使用将无限增加 下面是使用tf.data.Dataset构建管道的代码片段 def _imread(file_name, label): _raw = tf.read_file(file_name) _decoded = tf.image.dec

我对Tensorflow的新输入管道机制有问题。当我使用tf.data.Dataset创建数据管道时,它会解码jpeg图像,然后将它们加载到队列中,它会尝试将尽可能多的图像加载到队列中。如果加载图像的吞吐量大于我的模型处理的图像的吞吐量,那么内存使用将无限增加

下面是使用tf.data.Dataset构建管道的代码片段

def _imread(file_name, label):
  _raw = tf.read_file(file_name)
  _decoded = tf.image.decode_jpeg(_raw, channels=hps.im_ch)
  _resized = tf.image.resize_images(_decoded, [hps.im_width, hps.im_height])
  _scaled = (_resized / 127.5) - 1.0
  return _scaled, label

n_samples = image_files.shape.as_list()[0]
dset = tf.data.Dataset.from_tensor_slices((image_files, labels))
dset = dset.shuffle(n_samples, None)
dset = dset.repeat(hps.n_epochs)
dset = dset.map(_imread, hps.batch_size * 32)
dset = dset.batch(hps.batch_size)
dset = dset.prefetch(hps.batch_size * 2)
这里的
image\u files
是一个常数张量,包含30k个图像的文件名。图像大小调整为256x256x3英寸

如果使用以下代码段生成管道:

# refer to "https://www.tensorflow.org/programmers_guide/datasets"
def _imread(file_name, hps):
  _raw = tf.read_file(file_name)
  _decoded = tf.image.decode_jpeg(_raw, channels=hps.im_ch)
  _resized = tf.image.resize_images(_decoded, [hps.im_width, hps.im_height])
  _scaled = (_resized / 127.5) - 1.0
  return _scaled

n_samples = image_files.shape.as_list()[0]

image_file, label = tf.train.slice_input_producer(
  [image_files, labels],
  num_epochs=hps.n_epochs,
  shuffle=True,
  seed=None,
  capacity=n_samples,
)

# Decode image.
image = _imread(image_file, 

images, labels = tf.train.shuffle_batch(
  tensors=[image, label],
  batch_size=hps.batch_size,
  capacity=hps.batch_size * 64,
  min_after_dequeue=hps.batch_size * 8,
  num_threads=32,
  seed=None,
  enqueue_many=False,
  allow_smaller_final_batch=True
)

在整个训练过程中,记忆的使用几乎是恒定的。如何使tf.data.Dataset加载固定数量的样本?我用tf.data.Dataset创建的管道是否正确?我认为tf.data.Dataset.shuffle中的
buffer\u size
参数用于
image\u文件
标签
。所以存储30k字符串应该不会有问题,对吧?即使要加载30k图像,也需要
30000*256*256*3*8/(1024*1024*1024)=
43GB内存。但是,它使用59GB的61GB系统内存。

这将缓冲n_样本,这看起来是您的整个数据集。您可能需要减少这里的缓冲

dset = dset.shuffle(n_samples, None)
你最好永远重复,重复不会缓冲()

您正在批处理,然后预取hps。批处理的批大小。哎哟

dset = dset.batch(hps.batch_size)
dset = dset.prefetch(hps.batch_size * 2)
让我们假设hps.batch_size=1000来做一个具体的例子。上面的第一行创建了一批1000个图像。上面的第二行每1000个图像创建2000批,缓冲总计2000000个图像。哎呀

你的意思是:

dset = dset.batch(hps.batch_size)
dset = dset.prefetch(2)

那么预回迁是我失败的地方?它在文档中表示“表示预取时缓冲的最大元素数”。我认为元素是数据集流中的单个实例,就像我的例子中的(图像、标签)对一样。谢谢你给我澄清!我用我的这种方式,预取2批,它的工作如预期。所以我很有信心我是对的。批量将单个样本更改为单个样本组是有意义的。
dset = dset.batch(hps.batch_size)
dset = dset.prefetch(2)