在tf.keras.utils.Sequence中调整批次大小时出错

在tf.keras.utils.Sequence中调整批次大小时出错,keras,sequence,tensorflow2.0,data-generation,Keras,Sequence,Tensorflow2.0,Data Generation,我使用Tensorflow 2.4.1上的tf.keras.utils.Sequence进行研究。我在API文档()中按顺序使用了示例代码,并通过添加on_epoch\u end函数对其进行了微调,以自适应地更改每个epoch上的批大小值 from skimage.io import imread from skimage.transform import resize import numpy as np import random import math # Here, `x_set` i

我使用Tensorflow 2.4.1上的tf.keras.utils.Sequence进行研究。我在API文档()中按顺序使用了示例代码,并通过添加
on_epoch\u end
函数对其进行了微调,以自适应地更改每个epoch上的
批大小

from skimage.io import imread
from skimage.transform import resize
import numpy as np
import random
import math

# Here, `x_set` is list of path to the images
# and `y_set` are the associated classes.

class CIFAR10Sequence(tensorflow.keras.utils.Sequence):

    def __init__(self, x_set, y_set, batch_size):
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size

    def __len__(self):
        return math.ceil(len(self.x) / self.batch_size)

    def on_epoch_end(self):
        print(self.batch_size)
        self.batch_size = int(random.randint(10, 100))

    def __getitem__(self, idx):
        batch_x = self.x[idx * self.batch_size:(idx + 1) *
        self.batch_size]
        batch_y = self.y[idx * self.batch_size:(idx + 1) *
        self.batch_size]

        return np.array([
            resize(imread(file_name), (200, 200))
               for file_name in batch_x]), np.array(batch_y)
然而,在实践中,每个历元的步骤数保持不变,预计将根据批次数而变化。事实上,Tensorflow返回一个警告,通知他们数据不足,并立即停止训练。当初始化批处理大小小于当前批处理大小时,会发生此问题

WARNING:tensorflow:Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least `steps_per_epoch * epochs` batches
这是我的猜测,Tensorflow确实在每个历元后调整了批量大小,但不知何故,模型仍然保持初始值。这个问题在Keras版本1中从未发生过。到目前为止,我还没有解决这个问题的线索。 编辑1:培训数据的数量远大于批次数量