TensorFlow将数据加载到tf.Dataset中花费的时间太长
我正在使用TensorFlow 1.9来训练一个图像数据集,它太大了,无法从我的硬盘加载到RAM中。因此,我在硬盘上将数据集分成两半。我想知道在整个数据集上进行训练最有效的方法是什么 我的GPU有3 GB内存,我的RAM有32 GB内存。每一半数据集的大小为20 GB。我的硬盘有足够的可用空间(超过1 TB) 我的尝试如下。我创建了一个可初始化的TensorFlow将数据加载到tf.Dataset中花费的时间太长,tensorflow,Tensorflow,我正在使用TensorFlow 1.9来训练一个图像数据集,它太大了,无法从我的硬盘加载到RAM中。因此,我在硬盘上将数据集分成两半。我想知道在整个数据集上进行训练最有效的方法是什么 我的GPU有3 GB内存,我的RAM有32 GB内存。每一半数据集的大小为20 GB。我的硬盘有足够的可用空间(超过1 TB) 我的尝试如下。我创建了一个可初始化的tf.Dataset,然后在每个历元上初始化它两次:对数据集的每一半初始化一次。通过这种方式,每个历元都可以看到整个数据集,但每次只需将其中的一半加载到
tf.Dataset
,然后在每个历元上初始化它两次:对数据集的每一半初始化一次。通过这种方式,每个历元都可以看到整个数据集,但每次只需将其中的一半加载到RAM中
但是,这是非常缓慢的,因为从硬盘加载数据需要很长时间,而且每次使用此数据初始化数据集也需要相当长的时间
有没有更有效的方法
在加载数据集的另一半之前,我尝试对数据集的每一半进行多个历元的训练,这要快得多,但这会使验证数据的性能差得多。据推测,这是因为模型在每一半上都过拟合,然后没有推广到另一半的数据
在下面的代码中,我创建并保存一些测试数据,然后按照上面的描述加载这些数据。加载每半个数据集的时间约为5秒,使用此数据初始化数据集的时间约为1秒。这看起来可能只是一个小数目,但它在多个时代累积起来。事实上,我的电脑加载数据的时间几乎和实际训练数据的时间一样多
import tensorflow as tf
import numpy as np
import time
# Create and save 2 datasets of test NumPy data
dataset_num_elements = 100000
element_dim = 10000
batch_size = 50
test_data = np.zeros([2, int(dataset_num_elements * 0.5), element_dim], dtype=np.float32)
np.savez('test_data_1.npz', x=test_data[0])
np.savez('test_data_2.npz', x=test_data[1])
# Create the TensorFlow dataset
data_placeholder = tf.placeholder(tf.float32, [int(dataset_num_elements * 0.5), element_dim])
dataset = tf.data.Dataset.from_tensor_slices(data_placeholder)
dataset = dataset.shuffle(buffer_size=dataset_num_elements)
dataset = dataset.repeat()
dataset = dataset.batch(batch_size=batch_size)
dataset = dataset.prefetch(1)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
init_op = iterator.initializer
num_batches = int(dataset_num_elements / batch_size)
with tf.Session() as sess:
while True:
for dataset_section in range(2):
# Load the data from the hard drive
t1 = time.time()
print('Loading')
loaded_data = np.load('test_data_' + str(dataset_section + 1) + '.npz')
x = loaded_data['x']
print('Loaded')
t2 = time.time()
loading_time = t2 - t1
print('Loading time = ' + str(loading_time))
# Initialize the dataset with this loaded data
t1 = time.time()
sess.run(init_op, feed_dict={data_placeholder: x})
t2 = time.time()
initialization_time = t2 - t1
print('Initialization time = ' + str(initialization_time))
# Read the data in batches
for i in range(num_batches):
x = sess.run(next_element)
提要不是输入数据的有效方式。您可以这样输入数据:
- 尽可能使用轻量级提要
- 使用多线程进行读取和预处理
- 为训练预取数据