Tensorflow tf.data或tf.keras.utils.Sequence。提高tf.data的效率?
我正在尝试开发一个使用自动编码器的图像着色器。有13000张训练图像。如果使用tf.data,每个历元大约需要45分钟;如果使用tf.utils.keras.Sequence,每个历元大约需要25分钟。然而,使用顺序存在死锁的风险。如何改进tf.data?我尝试了一些方法,但似乎没有任何改善 tf.data 1Tensorflow tf.data或tf.keras.utils.Sequence。提高tf.data的效率?,tensorflow,keras,deep-learning,tensorflow-datasets,tf.keras,Tensorflow,Keras,Deep Learning,Tensorflow Datasets,Tf.keras,我正在尝试开发一个使用自动编码器的图像着色器。有13000张训练图像。如果使用tf.data,每个历元大约需要45分钟;如果使用tf.utils.keras.Sequence,每个历元大约需要25分钟。然而,使用顺序存在死锁的风险。如何改进tf.data?我尝试了一些方法,但似乎没有任何改善 tf.data 1 image_path_list = glob.glob('datasets/imagenette/*') data = tf.data.Dataset.list_files(image_
image_path_list = glob.glob('datasets/imagenette/*')
data = tf.data.Dataset.list_files(image_path_list)
def tf_rgb2lab(image):
im_shape = image.shape
[image,] = tf.py_function(color.rgb2lab, [image], [tf.float32])
image.set_shape(im_shape)
return image
def preprocess(path):
image = tf.io.read_file(path)
image = tf.image.decode_jpeg(image, channels=3)
image = tf.image.convert_image_dtype(image, tf.float32)
image = tf.image.resize(image, [224, 224])
image = tf_rgb2lab(image)
L = image[:,:,0]/100.
ab = image[:,:,1:]/128.
input = tf.stack([L,L,L], axis=2)
return input, ab
train_ds = data.repeat().map(preprocess, AUTOTUNE).batch(32).prefetch(AUTOTUNE)
tf.data 2
AUTOTUNE = tf.data.experimental.AUTOTUNE
def tf_rgb2lab(image):
im_shape = image.shape
[image,] = tf.py_function(color.rgb2lab, [image], [tf.float32])
image.set_shape(im_shape)
return image
def split_for_feed(image):
L = image[:,:,:,0]/100.
ab = image[:,:,:,1:]/128.
input = tf.stack([L,L,L], axis=-1)
return input, ab
def read_images(path):
image = tf.io.read_file(path)
image = tf.image.decode_jpeg(image, channels=3)
image = tf.image.convert_image_dtype(image, tf.float32)
image = tf.image.resize(image, [224, 224])
image = tf_rgb2lab(image)
return image
data2 = data.repeat().map(read_images, AUTOTUNE).batch(32)
train_ds = data2.map(split_for_feed, AUTOTUNE).prefetch(AUTOTUNE)
序列
class ImageGenerator(tf.keras.utils.Sequence):
def __init__(self, image_filenames, batch_size):
self.image_filenames = image_filenames
self.batch_size = batch_size
def __len__(self):
return math.ceil(len(self.image_filenames) / self.batch_size)
def __getitem__(self, idx):
batch = self.image_filenames[idx * self.batch_size : (idx + 1) * self.batch_size]
X_batch = []
y_batch = []
for file_name in batch:
file_name = 'datasets/imagenette/' + file_name
try:
color_image = transform.resize(io.imread(file_name),(224,224))
lab_image = color.rgb2lab(color_image)
L = lab_image[:,:,0]/100.
ab = lab_image[:,:,1:]/128.
X_batch.append(np.stack((L,L,L), axis=2))
y_batch.append(ab)
except:
pass
return np.array(X_batch), np.array(y_batch)
如果数据适合内存,请尝试缓存预处理。而不是
train_ds = data.repeat().map(preprocess, AUTOTUNE).batch(32).prefetch(AUTOTUNE)
做
这样,您只需解析每个文件一次,而不是重复解析
如果你想进一步优化流水线,考虑使用它,它可以精确地告诉你在数据集的每一个部分中花费了多少时间,这样你就可以找到瓶颈并解决它。
train_ds = data.map(preprocess, AUTOTUNE).batch(32).cache().repeat().prefetch(AUTOTUNE)