Python 使用tf.data.Dataset时如何解决OOM错误?

Python 使用tf.data.Dataset时如何解决OOM错误?,python,tensorflow,Python,Tensorflow,我正在使用tf.data.Dataset API构建一个数据管道,但出现OOM错误。假设我手头已经有功能和标签,它们是按[N,H,W,C]顺序排列的4D numpy数组。以下是如何创建我的数据集对象: batch_size = 100 num_samples = features.shape[0] # number of training samples features_placeholder = tf.placeholder(tf.float32, [None, feature_size]

我正在使用tf.data.Dataset API构建一个数据管道,但出现OOM错误。假设我手头已经有
功能
标签
,它们是按[N,H,W,C]顺序排列的4D numpy数组。以下是如何创建我的
数据集
对象:

batch_size = 100
num_samples = features.shape[0] # number of training samples

features_placeholder = tf.placeholder(tf.float32, [None, feature_size], name='features_placeholder')
labels_placeholder = tf.placeholder(tf.float32, [None, label_count], name='labels_placeholder')

dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))
dataset = dataset.batch(batch_size)
dataset = dataset.shuffle(num_samples)
dataset = dataset.prefetch(buffer_size=1)
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
我使用
tf.placeholder
的原因可以参考,这基本上建议使用
tf.placeholder
定义
dataset
以节省内存,如果数据是大的numpy数组(我的训练数据集中有54368个样本)。培训部分如下所示:

for i in range(epoch):
    sess.run([init_op, optimizer], 
             feed_dict={features_placeholder:features, labels_placeholder:labels]}
但我有一个错误说:

OOM通过分配器GPU\U 0\U bfc分配形状为[54368,40,3,64]且类型为float on/job:localhost/replica:0/task:0/device:GPU:0的张量时


正如我追溯的那样,它发生在我的模型中定义的
tf.layers.conv2d
层上。如何解决此OOM问题?

shuffle
的文档中,有人写道
。。。用缓冲区大小的元素填充缓冲区…
因此,在您的情况下,数据集将至少占用54368*40*3*64*32*2位,约为3.4GB。只是为了洗牌操作。你在使用4GB的gpu吗


另一件事是预取缓冲区大小应该大于1。为什么要预取1个元素,可能是一批或两批?

shuffle
的文档中写到,
。。。用缓冲区大小的元素填充缓冲区…
因此,在您的情况下,数据集将至少占用54368*40*3*64*32*2位,约为3.4GB。只是为了洗牌操作。你在使用4GB的gpu吗


另一件事是预取缓冲区大小应该大于1。为什么要预取1个元素,可能是一个或两个批次?

是否尝试过使用较小的批次大小?是否将dataset.batch和shuffle放错了位置。张量的大小也有问题。您是从单个numpy数组构建数据集吗?@borarak通过保持相同的批大小,但将较小的数组送入
功能\u占位符
标签\u占位符
中,不会再有错误,因此批大小应该无关紧要。然而,这意味着我只能将部分训练数据读入
特征
标签
,而不是整个数据集。@Sharky你能指定张量大小有什么问题吗?我正在尝试使用
tf.占位符
构建一个数据集,其形状与
功能
标签
相同,但秩0(数据样本数)除外。第一维度应该是批量大小,而不是数据的大小。您是否尝试使用较小的批量大小?您放错了dataset.batch和shuffle。张量的大小也有问题。您是从单个numpy数组构建数据集吗?@borarak通过保持相同的批大小,但将较小的数组送入
功能\u占位符
标签\u占位符
中,不会再有错误,因此批大小应该无关紧要。然而,这意味着我只能将部分训练数据读入
特征
标签
,而不是整个数据集。@Sharky你能指定张量大小有什么问题吗?我正在尝试使用
tf构建一个数据集。占位符
的形状与
特征
标签
相同,除了秩0(数据样本的数量)之外。第一个维度应该是批量大小,而不是数据集的大小