将 keras.utils.Sequence 与生成器一起使用时,TensorFlow Keras 出现 GPU 内存溢出(OOM)

将 keras.utils.Sequence 与生成器一起使用时,TensorFlow Keras 出现 GPU 内存溢出(OOM)

Dataset.py


import os
import random
from skimage import io
import cv2
from skimage.transform import resize
import numpy as np
import tensorflow as tf

import keras
import Augmentor

def iter_sequence_infinite(seq):
    """Iterate indefinitely over a Sequence.

    # Arguments
        seq: Sequence object (any object that can be iterated repeatedly).

    # Returns
        Generator yielding batches. Once a full pass over `seq` is finished,
        iteration restarts from the first batch. If a pass yields nothing
        (empty sequence) the generator stops instead of looping forever.
    """
    while True:
        produced = False
        for item in seq:
            produced = True
            yield item
        if not produced:
            # An empty sequence would otherwise busy-loop without ever
            # yielding, hanging the first `next()` call.
            return

# data generator class
class DataGenerator(keras.utils.Sequence):
    """Keras Sequence producing batches of (image, mask) pairs.

    Every sample is read from `imgs_dir` and `masks_dir` using the same file
    name, resized to `img_size` x `img_size`, randomly augmented with
    Augmentor and divided by 255.

    # Arguments
        ids: list of file names; each name is looked up in both `imgs_dir`
            and `masks_dir`.
        imgs_dir: directory containing the input images.
        masks_dir: directory containing the mask images.
        batch_size: number of samples per batch.
        img_size: width and height (in pixels) every sample is resized to.
        n_classes: stored on the instance; not used by the generator itself.
        n_channels: stored on the instance; not used by the generator itself.
        shuffle: when true, `on_epoch_end` reshuffles the sample order.
    """

    def __init__(self, ids, imgs_dir, masks_dir, batch_size=10, img_size=128, n_classes=1, n_channels=3, shuffle=True):
        super().__init__()
        self.id_names = ids
        self.indexes = np.arange(len(self.id_names))
        self.imgs_dir = imgs_dir
        self.masks_dir = masks_dir
        self.batch_size = batch_size
        self.img_size = img_size
        self.n_classes = n_classes
        self.n_channels = n_channels
        self.shuffle = shuffle

    def on_epoch_end(self):
        """Reset the sample order and reshuffle it when `shuffle` is set.

        NOTE: Keras calls this hook from `fit`; code that iterates the
        Sequence by hand has to call it explicitly to get a new order.
        """
        self.indexes = np.arange(len(self.id_names))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __data_generation__(self, id_name):
        """Load, augment and normalise a single (image, mask) pair.

        # Arguments
            id_name: file name shared by the image and its mask.

        # Returns
            Tuple `(image, mask)`: the augmented image and a mask with a
            trailing channel axis of size 1, both divided by 255.

        # Raises
            FileNotFoundError: if OpenCV cannot read the mask file.
        """
        img_path = os.path.join(self.imgs_dir, id_name)
        mask_path = os.path.join(self.masks_dir, id_name)

        img = io.imread(img_path)
        mask = cv2.imread(mask_path)
        if mask is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError('Cannot read mask: {}'.format(mask_path))

        # Geometric augmentation is applied to image and mask together so
        # that both receive the same transformation.
        p = Augmentor.DataPipeline([[img, mask]])
        p.resize(probability=1.0, width=self.img_size, height=self.img_size)
        p.rotate_without_crop(probability=0.3, max_left_rotation=10, max_right_rotation=10)
        p.shear(probability=0.3, max_shear_left=1, max_shear_right=1)

        sample_p = np.array(p.sample(1)).squeeze()
        p_img = sample_p[0]
        p_mask = sample_p[1]
        # Snap every mask value to 0 or 255 (values below 255 become 0).
        augmented_mask = (p_mask // 255) * 255

        # Photometric augmentation only makes sense for the image.
        q = Augmentor.DataPipeline([[p_img]])
        q.random_contrast(probability=0.3, min_factor=0.2, max_factor=1.0)
        q.random_brightness(probability=0.3, min_factor=0.2, max_factor=1.0)

        image = np.array(q.sample(1)).squeeze()

        # Keep one channel of the mask and restore an explicit channel axis.
        mask = np.expand_dims(augmented_mask[:, :, 0], axis=-1)

        # The augmented samples are returned directly. Re-reading the files
        # here (as an earlier revision did) discarded the augmentation and
        # divided already-rescaled `skimage.transform.resize` output by 255
        # a second time.
        image = image / 255.0
        mask = mask / 255.0

        return image, mask

    def __len__(self):
        """Return the number of complete batches per epoch."""
        return len(self.id_names) // self.batch_size

    def __getitem__(self, index):
        """Build batch number `index`.

        # Arguments
            index: batch number, in `range(len(self))`.

        # Returns
            Tuple `(imgs, masks)` of arrays stacked along a new first axis.
        """
        start = index * self.batch_size
        batch_indexes = self.indexes[start:start + self.batch_size]

        imgs = []
        masks = []
        for k in batch_indexes:
            img, mask = self.__data_generation__(self.id_names[k])
            imgs.append(img)
            masks.append(mask)

        return np.array(imgs), np.array(masks)
import argparse
import logging
import os
import sys
from tqdm import tqdm # progress bar
import numpy as np
import matplotlib.pyplot as plt

from keras import optimizers
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
import segmentation_models as sm
from segmentation_models.utils import set_trainable
from dataset import DataGenerator, iter_sequence_infinite

def train_model(model, train_gen, valid_gen, epochs, save_cp=True, checkpoint_dir='./checkpoints'):
    """Train `model` batch by batch and validate after every epoch.

    # Arguments
        model: compiled Keras model. The result of `train_on_batch` is
            unpacked into `(loss, iou)`, so the model is expected to be
            compiled with exactly one metric.
        train_gen: training Sequence exposing `id_names` and `len()`.
        valid_gen: validation Sequence exposing `id_names` and `len()`.
        epochs: number of passes over `train_gen`.
        save_cp: when true, save the weights after every epoch.
        checkpoint_dir: directory the per-epoch weight files are written to;
            created on demand.
    """
    train_img_num = len(train_gen.id_names)
    train_batch_num = len(train_gen)
    train_gen_out = iter_sequence_infinite(train_gen)

    valid_batch_num = len(valid_gen)
    valid_img_num = len(valid_gen.id_names)
    valid_gen_out = iter_sequence_infinite(valid_gen)

    for epoch in range(epochs):
        epoch_loss = 0  # sum of the batch losses of this epoch
        epoch_iou = 0   # sum of the batch IoU values of this epoch
        count = 0       # number of batches processed in this epoch

        with tqdm(total=train_img_num, desc=f'Epoch {epoch + 1}/{epochs}',  position=0, leave=True, unit='img') as pbar:
            for _ in range(train_batch_num):
                imgs, true_masks = next(train_gen_out)[:2]
                loss, iou = model.train_on_batch(imgs, true_masks)
                epoch_loss += loss
                epoch_iou += iou

                # show the metrics of the latest batch next to the bar
                pbar.set_postfix(**{'Batch loss': loss, 'Batch IoU': iou})

                pbar.update(imgs.shape[0])
                count += 1

        if count:
            print( "Epoch : loss: {}, IoU : {}".format(epoch_loss/count, epoch_iou/count))
        else:
            # Avoid a ZeroDivisionError when the generator has no full batch.
            logging.warning('No training batches were produced in this epoch')

        # Do validation
        validation_model(model, valid_gen_out, valid_batch_num, valid_img_num)

        if save_cp:
            if not os.path.isdir(checkpoint_dir):
                # exist_ok guards against the directory appearing between
                # the check and the creation.
                os.makedirs(checkpoint_dir, exist_ok=True)
                logging.info('Created checkpoint directory')
            model.save_weights(os.path.join(checkpoint_dir , f'CP_epoch{epoch + 1}.h5'))
            logging.info(f'Checkpoint {epoch + 1} saved !')

def validation_model(model, valid_gen_out, valid_batch_num, valid_img_num):
    """Evaluate `model` on one pass of validation batches and plot a sample.

    # Arguments
        model: compiled Keras model. The result of `test_on_batch` is
            unpacked into `(loss, iou)`, so exactly one metric is expected.
        valid_gen_out: iterator yielding `(imgs, true_masks)` batches.
        valid_batch_num: number of batches to draw from `valid_gen_out`.
        valid_img_num: number of validation images (progress bar total).

    # Returns
        None. Mean loss and IoU are printed; the true and predicted mask of
        the first sample of the last batch are drawn side by side on the
        current matplotlib figure.
    """
    if valid_batch_num <= 0:
        # Without batches there is nothing to average or to plot; continuing
        # would divide by zero and reference an unbound `imgs`.
        print("Validation skipped: no validation batches")
        return

    epoch_loss = 0  # sum of the batch losses
    epoch_iou = 0   # sum of the batch IoU values
    count = 0       # number of batches evaluated

    with tqdm(total=valid_img_num, desc='Validation round',  position=0, leave=True, unit='img') as pbar:
        for _ in range(valid_batch_num):
            batch = next(valid_gen_out)
            imgs = batch[0]
            true_masks = batch[1]
            loss, iou = model.test_on_batch(imgs, true_masks)
            epoch_loss += loss
            epoch_iou += iou

            # show the metrics of the latest batch next to the bar
            pbar.set_postfix(**{'Batch, loss': loss, 'Batch IoU': iou})

            pbar.update(imgs.shape[0])
            count += 1

    print("Validation loss: {}, IoU: {}".format(epoch_loss / count, epoch_iou / count))

    pred_mask = model.predict(np.expand_dims(imgs[0],0))
    # Clear the figure first so images do not accumulate across epochs, and
    # use two panels so the prediction does not hide the ground truth.
    plt.clf()
    plt.subplot(1, 2, 1)
    plt.imshow(true_masks[0].squeeze(), cmap="gray")
    plt.subplot(1, 2, 2)
    plt.imshow(pred_mask.squeeze(), cmap="gray")

def get_args(argv=None):
    """Parse the command-line options of the training script.

    # Arguments
        argv: optional list of argument strings. `None` (the default) makes
            argparse read the options from `sys.argv[1:]`.

    # Returns
        `argparse.Namespace` with the attributes `epochs`, `batch_size`,
        `lr`, `backbone`, `load`, `resizing` and `val`.
    """
    parser = argparse.ArgumentParser(description='Train the UNet on images and target masks',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-e', '--epochs', metavar='E', type=int, default=50,
                        help='Number of epochs', dest='epochs')
    parser.add_argument('-b', '--batch_size', metavar='B', type=int, nargs='?', default=2,
                        help='Batch size', dest='batch_size')
    parser.add_argument('-l', '--learning-rate', metavar='LR', type=float, nargs='?', default=1e-5,
                        help='Learning rate', dest='lr')
    parser.add_argument('-bb', '--backbone', default='resnet50', metavar='NAME',
                        help='Backbone name')
    parser.add_argument('-w', '--weight', dest='load', type=str, default=False,
                        help='Load model from a .h5 file')
    # The value is handed to the data generators as the target image size.
    parser.add_argument('-s', '--resizing', dest='resizing', type=int, default=384,
                        help='Width and height (in pixels) the images are resized to')
    parser.add_argument('-v', '--validation', dest='val', type=float, default=20.0,
                        help='Percent of the data that is used as validation (0-100)')

    return parser.parse_args(argv)

if __name__ == '__main__':
    # Script entry point: split the image list, build the generators,
    # compile a U-Net and run the manual training loop.
    img_dir = './data/train/imgs/'  # ./data/train/imgs/CVC_Original/'
    mask_dir = './data/train/masks/'  # ./data/train/masks/CVC_Ground Truth/'
    checkpoint_dir = './checkpoints'
    args = get_args()

    # os.listdir returns entries in arbitrary order; sorting keeps the
    # train/validation split identical from run to run.
    train_ids = sorted(os.listdir(img_dir))
    # Validation Data Size
    n_val = int(len(train_ids) * args.val/100)  # size of validation set

    valid_ids = train_ids[:n_val]  # first n_val file names are held out for validation
    train_ids = train_ids[n_val:]  # remaining file names are used for training
    # print(valid_ids, "\n\n")
    print("training_size: ", len(train_ids), "validation_size: ", len(valid_ids))

    train_gen = DataGenerator(train_ids, img_dir, mask_dir, img_size=args.resizing, batch_size=args.batch_size)
    valid_gen = DataGenerator(valid_ids, img_dir, mask_dir, img_size=args.resizing, batch_size=args.batch_size)

    print("total training batches: ", len(train_gen))
    print("total validaton batches: ", len(valid_gen))
    train_steps = len(train_ids) // args.batch_size
    valid_steps = len(valid_ids) // args.batch_size

    # define model
    model = sm.Unet(args.backbone, encoder_weights='imagenet')

    optimizer = optimizers.Adam(lr=args.lr, decay=1e-4)
    # The training loop unpacks train_on_batch/test_on_batch into
    # (loss, iou), so exactly one metric is compiled in.
    model.compile(
        optimizer=optimizer,
        #        "Adam",
        loss=sm.losses.bce_dice_loss,  # sm.losses.bce_jaccard_loss, # sm.losses.binary_crossentropy,
        metrics=[sm.metrics.iou_score],
    )

    if args.load:
        # -w/--weight: start from previously saved weights
        model.load_weights(args.load)

    # NOTE(review): these callbacks are never passed to anything; the manual
    # loop in train_model does not run Keras callbacks. 'val_accuracy' is
    # also not a metric compiled above -- confirm before reusing with fit().
    callbacks = [
        EarlyStopping(patience=6, verbose=1),
        ReduceLROnPlateau(factor=0.1, patience=3, min_lr=1e-7, verbose=1),
        ModelCheckpoint('./weights.Epoch{epoch:02d}-Loss{loss:.3f}-VIou{val_iou_score:.3f}.h5', verbose=1,
                        monitor='val_accuracy', save_best_only=True, save_weights_only=True),
    ]

    train_model(model=model, train_gen=train_gen, valid_gen=valid_gen, epochs=args.epochs)

(0) Resource exhausted: OOM when allocating tensor with shape[2,64,96,96] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
     [[{{node decoder_stage2b_bn/FusedBatchNorm}}]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info.


  • 在 train.py 中,实例化 DataGenerator 类;该类在 Dataset.py 中实现,继承自 keras.utils.Sequence

  • 在函数 train_model 的开头,使用函数 iter_sequence_infinite 将数据生成器(Sequence 对象)转换为 Python 生成器

  • 使用内置函数 next() 从生成器中取出一个批次(batch)


  • 我原以为不会出现内存问题,但它还是发生了。问题出在哪里?该如何解决?

    您使用什么硬件来运行此代码?我已尝试使用 RTX 2070 Super 和 RTX Titan。随着 epoch 数的增加,这些错误仍会出现。