Python 如何在第一个历元中直接使用map函数生成的转换,而不是在每个历元中执行map函数?

Python 如何在第一个历元中直接使用map函数生成的转换,而不是在每个历元中执行map函数?,python,tensorflow,keras,tensorflow2.0,tensorflow-datasets,Python,Tensorflow,Keras,Tensorflow2.0,Tensorflow Datasets,我想使用map函数对数据集进行一些转换。然而,我发现map函数在每个历元中都被执行。 map函数是否可能仅在第一个历元中执行,而后续历元直接使用在第一个历元中生成的变换?如果使用Tensorflow设置种子并使用tf.image应用变换,则历元之间的随机变换将是一致的 import tensorflow as tf from skimage import data tf.random.set_seed(42) import matplotlib.pyplot as plt def transf

我想使用map函数对数据集进行一些转换。然而,我发现map函数在每个历元中都被执行。
map函数是否可能仅在第一个历元中执行,而后续历元直接使用在第一个历元中生成的变换?

如果使用Tensorflow设置种子并使用tf.image应用变换,则历元之间的随机变换将是一致的

import tensorflow as tf
from skimage import data
tf.random.set_seed(42)
import matplotlib.pyplot as plt

def transform(image):
    image = tf.image.random_hue(image, 0.5, 1.)
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_flip_up_down(image)
    return image

X = tf.stack([data.chelsea() for i in range(4)])
ds = tf.data.Dataset.from_tensor_slices(X).map(transform)

inputs = [[], []]

for epoch in range(2):
    for sample in ds:
        inputs[epoch].append(sample)

inputs_paired = [i for s in inputs for i in s]

fig = plt.figure(figsize=(16, 8))
plt.subplots_adjust(wspace=.1, hspace=.1)
for i in range(8):
    ax = plt.subplot(2, 4, i + 1)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.imshow(inputs_paired[i])
plt.show()
顶部是第一个纪元,底部是第二个纪元