Tensorflow Keras GPU内存溢出与Keras.utils.sequence和generator一起使用

Tensorflow Keras GPU内存溢出与Keras.utils.sequence和generator一起使用,tensorflow,keras,deep-learning,nlp,computer-vision,Tensorflow,Keras,Deep Learning,Nlp,Computer Vision,Dataset.py import os import random from skimage import io import cv2 from skimage.transform import resize import numpy as np import tensorflow as tf import keras import Augmentor def iter_sequence_infinite(seq): """Iterate indefinitely over a S

Dataset.py

import os
import random
from skimage import io
import cv2
from skimage.transform import resize
import numpy as np
import tensorflow as tf

import keras
import Augmentor

def iter_sequence_infinite(seq):
    """Iterate indefinitely over a Sequence.
    # Arguments
        seq: Sequence object
    # Returns
        Generator yielding batches.
    """
    while True:
        for item in seq:
            yield item

# data generator class
class DataGenerator(keras.utils.Sequence):
    def __init__(self, ids, imgs_dir, masks_dir, batch_size=10, img_size=128, n_classes=1, n_channels=3, shuffle=True):
        self.id_names = ids
        self.indexes = np.arange(len(self.id_names))
        self.imgs_dir = imgs_dir
        self.masks_dir = masks_dir
        self.batch_size = batch_size
        self.img_size = img_size
        self.n_classes = n_classes
        self.n_channels = n_channels
        self.shuffle = shuffle
        self.on_epoch_end()

    # for printing the statistics of the function
    def on_epoch_end(self):
        'Updates indexes after each epoch'
        self.indexes = np.arange(len(self.id_names))
        if self.shuffle == True:
            np.random.shuffle(self.indexes)

    def __data_generation__(self, id_name):
        'Generates data containing batch_size samples' # X : (n_samples, *dim, n_channels)
        # Initialization
        img_path = os.path.join(self.imgs_dir, id_name)  # polyp segmentation/images/id_name.jpg
        mask_path = os.path.join(self.masks_dir, id_name) # polyp segmenatation/masks/id_name.jpg

        img = io.imread(img_path)
        mask = cv2.imread(mask_path)

        p = Augmentor.DataPipeline([[img, mask]])
        p.resize(probability=1.0, width=self.img_size, height=self.img_size)
        p.rotate_without_crop(probability=0.3, max_left_rotation=10, max_right_rotation=10)
        #p.random_distortion(probability=0.3, grid_height=10, grid_width=10, magnitude=1)
        p.shear(probability=0.3, max_shear_left=1, max_shear_right=1)
        #p.skew_tilt(probability=0.3, magnitude=0.1)
        p.flip_random(probability=0.3)

        sample_p = p.sample(1)
        sample_p = np.array(sample_p).squeeze()

        p_img = sample_p[0]
        p_mask = sample_p[1]
        augmented_mask = (p_mask // 255) * 255  # denoising

        q = Augmentor.DataPipeline([[p_img]])
        q.random_contrast(probability=0.3, min_factor=0.2, max_factor=1.0)  # low to High
        q.random_brightness(probability=0.3, min_factor=0.2, max_factor=1.0)  # dark to bright

        sample_q = q.sample(1)
        sample_q = np.array(sample_q).squeeze()

        image = sample_q
        mask = augmented_mask[::, ::, 0]

        """
        # reading the image from dataset
        ## Reading Image
        image = io.imread(img_path)  # reading image to image vaiable
        image = resize(image, (self.img_size, self.img_size), anti_aliasing=True)  # resizing input image to 128 * 128

        mask = io.imread(mask_path, as_gray=True)  # mask image of same size with all zeros
        mask = resize(mask, (self.img_size, self.img_size), anti_aliasing=True)  # resizing mask to fit the 128 * 128 image
        mask = np.expand_dims(mask, axis=-1)
        """

        # image normalization
        image = image / 255.0
        mask = mask / 255.0

        return image, mask

    def __len__(self):
        "Denotes the number of batches per epoch"
        return int(np.floor(len(self.id_names) / self.batch_size))

    def __getitem__(self, index):  # index : batch no.
        # Generate indexes of the batch
        # Generate indexes of the batch
        indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
        batch_ids = [self.id_names[k] for k in indexes]

        imgs = list()
        masks = list()

        for id_name in batch_ids:
            img, mask = self.__data_generation__(id_name)
            imgs.append(img)
            masks.append(np.expand_dims(mask,-1))

        imgs = np.array(imgs)
        masks = np.array(masks)

        return imgs, masks  # return batch
import argparse
import logging
import os
import sys
from tqdm import tqdm # progress bar
import numpy as np
import matplotlib.pyplot as plt

from keras import optimizers
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
import segmentation_models as sm
from segmentation_models.utils import set_trainable
from dataset import DataGenerator, iter_sequence_infinite



def train_model(model, train_gen, valid_gen, epochs, save_cp=True):
    total_batch_count = 0
    train_img_num = len(train_gen.id_names)
    train_batch_num = len(train_gen)
    train_gen_out = iter_sequence_infinite(train_gen)

    valid_batch_num = len(valid_gen)
    valid_img_num = len(valid_gen.id_names)
    valid_gen_out = iter_sequence_infinite(valid_gen)

    for epoch in range(epochs): # interation as many epochs
        set_trainable(model)

        epoch_loss = 0 # loss in this epoch
        epoch_iou = 0
        count = 0

        with tqdm(total=train_img_num, desc=f'Epoch {epoch + 1}/{epochs}',  position=0, leave=True, unit='img') as pbar:  # make progress bar
            for _ in range(train_batch_num):
                batch = next(train_gen_out)
                imgs = batch[0]
                true_masks = batch[1]
                loss, iou = model.train_on_batch(imgs, true_masks)  # value of loss of this batch
                epoch_loss += loss
                epoch_iou += iou

                pbar.set_postfix(**{'Batch loss': loss, 'Batch IoU': iou})  # floating the loss at the post in the pbar

                pbar.update(imgs.shape[0])  # update progress
                count += 1
                total_batch_count += 1

        train_gen.on_epoch_end()
        print( "Epoch : loss: {}, IoU : {}".format(epoch_loss/count, epoch_iou/count))

        # Do validation
        validation_model(model, valid_gen_out, valid_batch_num, valid_img_num)
        valid_gen.on_epoch_end()

        if save_cp:
            try:
                if not os.path.isdir(checkpoint_dir):
                    os.mkdir(checkpoint_dir)
                    logging.info('Created checkpoint directory')
                else:
                    pass
            except OSError:
                pass
            model.save_weights(os.path.join(checkpoint_dir , f'CP_epoch{epoch + 1}.h5'))
            logging.info(f'Checkpoint {epoch + 1} saved !')

def validation_model(model, valid_gen_out, valid_batch_num, valid_img_num):
    epoch_loss = 0  # loss in this epoch
    epoch_iou = 0
    count = 0

    with tqdm(total=valid_img_num, desc='Validation round',  position=0, leave=True, unit='img') as pbar:  # make progress bar
        for _ in range(valid_batch_num):
            batch = next(valid_gen_out)
            imgs = batch[0]
            true_masks = batch[1]
            loss, iou = model.test_on_batch(imgs, true_masks)  # value of loss of this batch
            epoch_loss += loss
            epoch_iou += iou

            pbar.set_postfix(**{'Batch, loss': loss, 'Batch IoU': iou})  # floating the loss at the post in the pbar

            pbar.update(imgs.shape[0])  # update progress
            count += 1

    print("Validation loss: {}, IoU: {}".format(epoch_loss / count, epoch_iou / count))
    pred_mask = model.predict(np.expand_dims(imgs[0],0))
    plt.subplot(131)
    plt.imshow(imgs[0])
    plt.subplot(132)
    plt.imshow(true_masks[0].squeeze(), cmap="gray")
    plt.subplot(133)
    plt.imshow(pred_mask.squeeze(), cmap="gray")
    plt.show()
    print()


def get_args():
    parser = argparse.ArgumentParser(description='Train the UNet on images and target masks',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-e', '--epochs', metavar='E', type=int, default=50,
                        help='Number of epochs', dest='epochs')
    parser.add_argument('-b', '--batch_size', metavar='B', type=int, nargs='?', default=2,
                        help='Batch size', dest='batch_size')
    parser.add_argument('-l', '--learning-rate', metavar='LR', type=float, nargs='?', default=1e-5,
                        help='Learning rate', dest='lr')
    parser.add_argument('-bb', '--backbone', default='resnet50', metavar='FILE',
                        help="backcone name")
    parser.add_argument('-w', '--weight', dest='load', type=str, default=False,
                        help='Load model from a .h5 file')
    parser.add_argument('-s', '--resizing', dest='resizing', type=int, default=384,
                        help='Downscaling factor of the images')
    parser.add_argument('-v', '--validation', dest='val', type=float, default=20.0,
                        help='Percent of the data that is used as validation (0-100)')

    return parser.parse_args()


if __name__ == '__main__':
    img_dir = './data/train/imgs/'  # ./data/train/imgs/CVC_Original/'
    mask_dir = './data/train/masks/'  # ./data/train/masks/CVC_Ground Truth/'
    checkpoint_dir = './checkpoints'
    args = get_args()

    # train path
    train_ids = os.listdir(img_dir)
    # Validation Data Size
    n_val = int(len(train_ids) * args.val/100)  # size of validation set


    valid_ids = train_ids[:n_val]  # list of image ids used for validation of result 0 to 9
    train_ids = train_ids[n_val:]  # list of image ids used for training dataset
    # print(valid_ids, "\n\n")
    print("training_size: ", len(train_ids), "validation_size: ", len(valid_ids))

    train_gen = DataGenerator(train_ids, img_dir, mask_dir, img_size=args.resizing, batch_size=args.batch_size)
    valid_gen = DataGenerator(valid_ids, img_dir, mask_dir, img_size=args.resizing, batch_size=args.batch_size)

    print("total training batches: ", len(train_gen))
    print("total validaton batches: ", len(valid_gen))
    train_steps = len(train_ids) // args.batch_size
    valid_steps = len(valid_ids) // args.batch_size

    # define model
    model = sm.Unet(args.backbone, encoder_weights='imagenet')

    optimizer = optimizers.Adam(lr=args.lr, decay=1e-4)
    model.compile(
        optimizer=optimizer,
        #        "Adam",
        loss=sm.losses.bce_dice_loss,  # sm.losses.bce_jaccard_loss, # sm.losses.binary_crossentropy,
        metrics=[sm.metrics.iou_score],
    )
    #model.summary()

    callbacks = [
        EarlyStopping(patience=6, verbose=1),
        ReduceLROnPlateau(factor=0.1, patience=3, min_lr=1e-7, verbose=1),
        ModelCheckpoint('./weights.Epoch{epoch:02d}-Loss{loss:.3f}-VIou{val_iou_score:.3f}.h5', verbose=1,
                        monitor='val_accuracy', save_best_only=True, save_weights_only=True)
                ]


    train_model(model=model, train_gen=train_gen, valid_gen=valid_gen, epochs=args.epochs)
train.py

import os
import random
from skimage import io
import cv2
from skimage.transform import resize
import numpy as np
import tensorflow as tf

import keras
import Augmentor

def iter_sequence_infinite(seq):
    """Iterate indefinitely over a Sequence.
    # Arguments
        seq: Sequence object
    # Returns
        Generator yielding batches.
    """
    while True:
        for item in seq:
            yield item

# data generator class
class DataGenerator(keras.utils.Sequence):
    def __init__(self, ids, imgs_dir, masks_dir, batch_size=10, img_size=128, n_classes=1, n_channels=3, shuffle=True):
        self.id_names = ids
        self.indexes = np.arange(len(self.id_names))
        self.imgs_dir = imgs_dir
        self.masks_dir = masks_dir
        self.batch_size = batch_size
        self.img_size = img_size
        self.n_classes = n_classes
        self.n_channels = n_channels
        self.shuffle = shuffle
        self.on_epoch_end()

    # for printing the statistics of the function
    def on_epoch_end(self):
        'Updates indexes after each epoch'
        self.indexes = np.arange(len(self.id_names))
        if self.shuffle == True:
            np.random.shuffle(self.indexes)

    def __data_generation__(self, id_name):
        'Generates data containing batch_size samples' # X : (n_samples, *dim, n_channels)
        # Initialization
        img_path = os.path.join(self.imgs_dir, id_name)  # polyp segmentation/images/id_name.jpg
        mask_path = os.path.join(self.masks_dir, id_name) # polyp segmenatation/masks/id_name.jpg

        img = io.imread(img_path)
        mask = cv2.imread(mask_path)

        p = Augmentor.DataPipeline([[img, mask]])
        p.resize(probability=1.0, width=self.img_size, height=self.img_size)
        p.rotate_without_crop(probability=0.3, max_left_rotation=10, max_right_rotation=10)
        #p.random_distortion(probability=0.3, grid_height=10, grid_width=10, magnitude=1)
        p.shear(probability=0.3, max_shear_left=1, max_shear_right=1)
        #p.skew_tilt(probability=0.3, magnitude=0.1)
        p.flip_random(probability=0.3)

        sample_p = p.sample(1)
        sample_p = np.array(sample_p).squeeze()

        p_img = sample_p[0]
        p_mask = sample_p[1]
        augmented_mask = (p_mask // 255) * 255  # denoising

        q = Augmentor.DataPipeline([[p_img]])
        q.random_contrast(probability=0.3, min_factor=0.2, max_factor=1.0)  # low to High
        q.random_brightness(probability=0.3, min_factor=0.2, max_factor=1.0)  # dark to bright

        sample_q = q.sample(1)
        sample_q = np.array(sample_q).squeeze()

        image = sample_q
        mask = augmented_mask[::, ::, 0]

        """
        # reading the image from dataset
        ## Reading Image
        image = io.imread(img_path)  # reading image to image vaiable
        image = resize(image, (self.img_size, self.img_size), anti_aliasing=True)  # resizing input image to 128 * 128

        mask = io.imread(mask_path, as_gray=True)  # mask image of same size with all zeros
        mask = resize(mask, (self.img_size, self.img_size), anti_aliasing=True)  # resizing mask to fit the 128 * 128 image
        mask = np.expand_dims(mask, axis=-1)
        """

        # image normalization
        image = image / 255.0
        mask = mask / 255.0

        return image, mask

    def __len__(self):
        "Denotes the number of batches per epoch"
        return int(np.floor(len(self.id_names) / self.batch_size))

    def __getitem__(self, index):  # index : batch no.
        # Generate indexes of the batch
        # Generate indexes of the batch
        indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
        batch_ids = [self.id_names[k] for k in indexes]

        imgs = list()
        masks = list()

        for id_name in batch_ids:
            img, mask = self.__data_generation__(id_name)
            imgs.append(img)
            masks.append(np.expand_dims(mask,-1))

        imgs = np.array(imgs)
        masks = np.array(masks)

        return imgs, masks  # return batch
import argparse
import logging
import os
import sys
from tqdm import tqdm # progress bar
import numpy as np
import matplotlib.pyplot as plt

from keras import optimizers
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
import segmentation_models as sm
from segmentation_models.utils import set_trainable
from dataset import DataGenerator, iter_sequence_infinite



def train_model(model, train_gen, valid_gen, epochs, save_cp=True):
    total_batch_count = 0
    train_img_num = len(train_gen.id_names)
    train_batch_num = len(train_gen)
    train_gen_out = iter_sequence_infinite(train_gen)

    valid_batch_num = len(valid_gen)
    valid_img_num = len(valid_gen.id_names)
    valid_gen_out = iter_sequence_infinite(valid_gen)

    for epoch in range(epochs): # interation as many epochs
        set_trainable(model)

        epoch_loss = 0 # loss in this epoch
        epoch_iou = 0
        count = 0

        with tqdm(total=train_img_num, desc=f'Epoch {epoch + 1}/{epochs}',  position=0, leave=True, unit='img') as pbar:  # make progress bar
            for _ in range(train_batch_num):
                batch = next(train_gen_out)
                imgs = batch[0]
                true_masks = batch[1]
                loss, iou = model.train_on_batch(imgs, true_masks)  # value of loss of this batch
                epoch_loss += loss
                epoch_iou += iou

                pbar.set_postfix(**{'Batch loss': loss, 'Batch IoU': iou})  # floating the loss at the post in the pbar

                pbar.update(imgs.shape[0])  # update progress
                count += 1
                total_batch_count += 1

        train_gen.on_epoch_end()
        print( "Epoch : loss: {}, IoU : {}".format(epoch_loss/count, epoch_iou/count))

        # Do validation
        validation_model(model, valid_gen_out, valid_batch_num, valid_img_num)
        valid_gen.on_epoch_end()

        if save_cp:
            try:
                if not os.path.isdir(checkpoint_dir):
                    os.mkdir(checkpoint_dir)
                    logging.info('Created checkpoint directory')
                else:
                    pass
            except OSError:
                pass
            model.save_weights(os.path.join(checkpoint_dir , f'CP_epoch{epoch + 1}.h5'))
            logging.info(f'Checkpoint {epoch + 1} saved !')

def validation_model(model, valid_gen_out, valid_batch_num, valid_img_num):
    epoch_loss = 0  # loss in this epoch
    epoch_iou = 0
    count = 0

    with tqdm(total=valid_img_num, desc='Validation round',  position=0, leave=True, unit='img') as pbar:  # make progress bar
        for _ in range(valid_batch_num):
            batch = next(valid_gen_out)
            imgs = batch[0]
            true_masks = batch[1]
            loss, iou = model.test_on_batch(imgs, true_masks)  # value of loss of this batch
            epoch_loss += loss
            epoch_iou += iou

            pbar.set_postfix(**{'Batch, loss': loss, 'Batch IoU': iou})  # floating the loss at the post in the pbar

            pbar.update(imgs.shape[0])  # update progress
            count += 1

    print("Validation loss: {}, IoU: {}".format(epoch_loss / count, epoch_iou / count))
    pred_mask = model.predict(np.expand_dims(imgs[0],0))
    plt.subplot(131)
    plt.imshow(imgs[0])
    plt.subplot(132)
    plt.imshow(true_masks[0].squeeze(), cmap="gray")
    plt.subplot(133)
    plt.imshow(pred_mask.squeeze(), cmap="gray")
    plt.show()
    print()


def get_args():
    parser = argparse.ArgumentParser(description='Train the UNet on images and target masks',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-e', '--epochs', metavar='E', type=int, default=50,
                        help='Number of epochs', dest='epochs')
    parser.add_argument('-b', '--batch_size', metavar='B', type=int, nargs='?', default=2,
                        help='Batch size', dest='batch_size')
    parser.add_argument('-l', '--learning-rate', metavar='LR', type=float, nargs='?', default=1e-5,
                        help='Learning rate', dest='lr')
    parser.add_argument('-bb', '--backbone', default='resnet50', metavar='FILE',
                        help="backcone name")
    parser.add_argument('-w', '--weight', dest='load', type=str, default=False,
                        help='Load model from a .h5 file')
    parser.add_argument('-s', '--resizing', dest='resizing', type=int, default=384,
                        help='Downscaling factor of the images')
    parser.add_argument('-v', '--validation', dest='val', type=float, default=20.0,
                        help='Percent of the data that is used as validation (0-100)')

    return parser.parse_args()


if __name__ == '__main__':
    img_dir = './data/train/imgs/'  # ./data/train/imgs/CVC_Original/'
    mask_dir = './data/train/masks/'  # ./data/train/masks/CVC_Ground Truth/'
    checkpoint_dir = './checkpoints'
    args = get_args()

    # train path
    train_ids = os.listdir(img_dir)
    # Validation Data Size
    n_val = int(len(train_ids) * args.val/100)  # size of validation set


    valid_ids = train_ids[:n_val]  # list of image ids used for validation of result 0 to 9
    train_ids = train_ids[n_val:]  # list of image ids used for training dataset
    # print(valid_ids, "\n\n")
    print("training_size: ", len(train_ids), "validation_size: ", len(valid_ids))

    train_gen = DataGenerator(train_ids, img_dir, mask_dir, img_size=args.resizing, batch_size=args.batch_size)
    valid_gen = DataGenerator(valid_ids, img_dir, mask_dir, img_size=args.resizing, batch_size=args.batch_size)

    print("total training batches: ", len(train_gen))
    print("total validaton batches: ", len(valid_gen))
    train_steps = len(train_ids) // args.batch_size
    valid_steps = len(valid_ids) // args.batch_size

    # define model
    model = sm.Unet(args.backbone, encoder_weights='imagenet')

    optimizer = optimizers.Adam(lr=args.lr, decay=1e-4)
    model.compile(
        optimizer=optimizer,
        #        "Adam",
        loss=sm.losses.bce_dice_loss,  # sm.losses.bce_jaccard_loss, # sm.losses.binary_crossentropy,
        metrics=[sm.metrics.iou_score],
    )
    #model.summary()

    callbacks = [
        EarlyStopping(patience=6, verbose=1),
        ReduceLROnPlateau(factor=0.1, patience=3, min_lr=1e-7, verbose=1),
        ModelCheckpoint('./weights.Epoch{epoch:02d}-Loss{loss:.3f}-VIou{val_iou_score:.3f}.h5', verbose=1,
                        monitor='val_accuracy', save_best_only=True, save_weights_only=True)
                ]


    train_model(model=model, train_gen=train_gen, valid_gen=valid_gen, epochs=args.epochs)
当我试着运行这段代码时,有些阶段进展顺利,但在20个阶段中,出现了gpu内存溢出错误,如下所示

(0) Resource exhausted: OOM when allocating tensor with shape[2,64,96,96] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
     [[{{node decoder_stage2b_bn/FusedBatchNorm}}]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info.
因此,我认为这是因为数据生成

此代码按此顺序生成批

  • 在train.py中,初始化DataGenerator类,该类是在Dataset.py中实现的序列模型

    import os
    import random
    from skimage import io
    import cv2
    from skimage.transform import resize
    import numpy as np
    import tensorflow as tf
    
    import keras
    import Augmentor
    
    def iter_sequence_infinite(seq):
        """Iterate indefinitely over a Sequence.
        # Arguments
            seq: Sequence object
        # Returns
            Generator yielding batches.
        """
        while True:
            for item in seq:
                yield item
    
    # data generator class
    class DataGenerator(keras.utils.Sequence):
        def __init__(self, ids, imgs_dir, masks_dir, batch_size=10, img_size=128, n_classes=1, n_channels=3, shuffle=True):
            self.id_names = ids
            self.indexes = np.arange(len(self.id_names))
            self.imgs_dir = imgs_dir
            self.masks_dir = masks_dir
            self.batch_size = batch_size
            self.img_size = img_size
            self.n_classes = n_classes
            self.n_channels = n_channels
            self.shuffle = shuffle
            self.on_epoch_end()
    
        # for printing the statistics of the function
        def on_epoch_end(self):
            'Updates indexes after each epoch'
            self.indexes = np.arange(len(self.id_names))
            if self.shuffle == True:
                np.random.shuffle(self.indexes)
    
        def __data_generation__(self, id_name):
            'Generates data containing batch_size samples' # X : (n_samples, *dim, n_channels)
            # Initialization
            img_path = os.path.join(self.imgs_dir, id_name)  # polyp segmentation/images/id_name.jpg
            mask_path = os.path.join(self.masks_dir, id_name) # polyp segmenatation/masks/id_name.jpg
    
            img = io.imread(img_path)
            mask = cv2.imread(mask_path)
    
            p = Augmentor.DataPipeline([[img, mask]])
            p.resize(probability=1.0, width=self.img_size, height=self.img_size)
            p.rotate_without_crop(probability=0.3, max_left_rotation=10, max_right_rotation=10)
            #p.random_distortion(probability=0.3, grid_height=10, grid_width=10, magnitude=1)
            p.shear(probability=0.3, max_shear_left=1, max_shear_right=1)
            #p.skew_tilt(probability=0.3, magnitude=0.1)
            p.flip_random(probability=0.3)
    
            sample_p = p.sample(1)
            sample_p = np.array(sample_p).squeeze()
    
            p_img = sample_p[0]
            p_mask = sample_p[1]
            augmented_mask = (p_mask // 255) * 255  # denoising
    
            q = Augmentor.DataPipeline([[p_img]])
            q.random_contrast(probability=0.3, min_factor=0.2, max_factor=1.0)  # low to High
            q.random_brightness(probability=0.3, min_factor=0.2, max_factor=1.0)  # dark to bright
    
            sample_q = q.sample(1)
            sample_q = np.array(sample_q).squeeze()
    
            image = sample_q
            mask = augmented_mask[::, ::, 0]
    
            """
            # reading the image from dataset
            ## Reading Image
            image = io.imread(img_path)  # reading image to image vaiable
            image = resize(image, (self.img_size, self.img_size), anti_aliasing=True)  # resizing input image to 128 * 128
    
            mask = io.imread(mask_path, as_gray=True)  # mask image of same size with all zeros
            mask = resize(mask, (self.img_size, self.img_size), anti_aliasing=True)  # resizing mask to fit the 128 * 128 image
            mask = np.expand_dims(mask, axis=-1)
            """
    
            # image normalization
            image = image / 255.0
            mask = mask / 255.0
    
            return image, mask
    
        def __len__(self):
            "Denotes the number of batches per epoch"
            return int(np.floor(len(self.id_names) / self.batch_size))
    
        def __getitem__(self, index):  # index : batch no.
            # Generate indexes of the batch
            # Generate indexes of the batch
            indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
            batch_ids = [self.id_names[k] for k in indexes]
    
            imgs = list()
            masks = list()
    
            for id_name in batch_ids:
                img, mask = self.__data_generation__(id_name)
                imgs.append(img)
                masks.append(np.expand_dims(mask,-1))
    
            imgs = np.array(imgs)
            masks = np.array(masks)
    
            return imgs, masks  # return batch
    
    import argparse
    import logging
    import os
    import sys
    from tqdm import tqdm # progress bar
    import numpy as np
    import matplotlib.pyplot as plt
    
    from keras import optimizers
    from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
    import segmentation_models as sm
    from segmentation_models.utils import set_trainable
    from dataset import DataGenerator, iter_sequence_infinite
    
    
    
    def train_model(model, train_gen, valid_gen, epochs, save_cp=True):
        total_batch_count = 0
        train_img_num = len(train_gen.id_names)
        train_batch_num = len(train_gen)
        train_gen_out = iter_sequence_infinite(train_gen)
    
        valid_batch_num = len(valid_gen)
        valid_img_num = len(valid_gen.id_names)
        valid_gen_out = iter_sequence_infinite(valid_gen)
    
        for epoch in range(epochs): # interation as many epochs
            set_trainable(model)
    
            epoch_loss = 0 # loss in this epoch
            epoch_iou = 0
            count = 0
    
            with tqdm(total=train_img_num, desc=f'Epoch {epoch + 1}/{epochs}',  position=0, leave=True, unit='img') as pbar:  # make progress bar
                for _ in range(train_batch_num):
                    batch = next(train_gen_out)
                    imgs = batch[0]
                    true_masks = batch[1]
                    loss, iou = model.train_on_batch(imgs, true_masks)  # value of loss of this batch
                    epoch_loss += loss
                    epoch_iou += iou
    
                    pbar.set_postfix(**{'Batch loss': loss, 'Batch IoU': iou})  # floating the loss at the post in the pbar
    
                    pbar.update(imgs.shape[0])  # update progress
                    count += 1
                    total_batch_count += 1
    
            train_gen.on_epoch_end()
            print( "Epoch : loss: {}, IoU : {}".format(epoch_loss/count, epoch_iou/count))
    
            # Do validation
            validation_model(model, valid_gen_out, valid_batch_num, valid_img_num)
            valid_gen.on_epoch_end()
    
            if save_cp:
                try:
                    if not os.path.isdir(checkpoint_dir):
                        os.mkdir(checkpoint_dir)
                        logging.info('Created checkpoint directory')
                    else:
                        pass
                except OSError:
                    pass
                model.save_weights(os.path.join(checkpoint_dir , f'CP_epoch{epoch + 1}.h5'))
                logging.info(f'Checkpoint {epoch + 1} saved !')
    
    def validation_model(model, valid_gen_out, valid_batch_num, valid_img_num):
        epoch_loss = 0  # loss in this epoch
        epoch_iou = 0
        count = 0
    
        with tqdm(total=valid_img_num, desc='Validation round',  position=0, leave=True, unit='img') as pbar:  # make progress bar
            for _ in range(valid_batch_num):
                batch = next(valid_gen_out)
                imgs = batch[0]
                true_masks = batch[1]
                loss, iou = model.test_on_batch(imgs, true_masks)  # value of loss of this batch
                epoch_loss += loss
                epoch_iou += iou
    
                pbar.set_postfix(**{'Batch, loss': loss, 'Batch IoU': iou})  # floating the loss at the post in the pbar
    
                pbar.update(imgs.shape[0])  # update progress
                count += 1
    
        print("Validation loss: {}, IoU: {}".format(epoch_loss / count, epoch_iou / count))
        pred_mask = model.predict(np.expand_dims(imgs[0],0))
        plt.subplot(131)
        plt.imshow(imgs[0])
        plt.subplot(132)
        plt.imshow(true_masks[0].squeeze(), cmap="gray")
        plt.subplot(133)
        plt.imshow(pred_mask.squeeze(), cmap="gray")
        plt.show()
        print()
    
    
    def get_args():
        parser = argparse.ArgumentParser(description='Train the UNet on images and target masks',
                                         formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        parser.add_argument('-e', '--epochs', metavar='E', type=int, default=50,
                            help='Number of epochs', dest='epochs')
        parser.add_argument('-b', '--batch_size', metavar='B', type=int, nargs='?', default=2,
                            help='Batch size', dest='batch_size')
        parser.add_argument('-l', '--learning-rate', metavar='LR', type=float, nargs='?', default=1e-5,
                            help='Learning rate', dest='lr')
        parser.add_argument('-bb', '--backbone', default='resnet50', metavar='FILE',
                            help="backcone name")
        parser.add_argument('-w', '--weight', dest='load', type=str, default=False,
                            help='Load model from a .h5 file')
        parser.add_argument('-s', '--resizing', dest='resizing', type=int, default=384,
                            help='Downscaling factor of the images')
        parser.add_argument('-v', '--validation', dest='val', type=float, default=20.0,
                            help='Percent of the data that is used as validation (0-100)')
    
        return parser.parse_args()
    
    
    if __name__ == '__main__':
        img_dir = './data/train/imgs/'  # ./data/train/imgs/CVC_Original/'
        mask_dir = './data/train/masks/'  # ./data/train/masks/CVC_Ground Truth/'
        checkpoint_dir = './checkpoints'
        args = get_args()
    
        # train path
        train_ids = os.listdir(img_dir)
        # Validation Data Size
        n_val = int(len(train_ids) * args.val/100)  # size of validation set
    
    
        valid_ids = train_ids[:n_val]  # list of image ids used for validation of result 0 to 9
        train_ids = train_ids[n_val:]  # list of image ids used for training dataset
        # print(valid_ids, "\n\n")
        print("training_size: ", len(train_ids), "validation_size: ", len(valid_ids))
    
        train_gen = DataGenerator(train_ids, img_dir, mask_dir, img_size=args.resizing, batch_size=args.batch_size)
        valid_gen = DataGenerator(valid_ids, img_dir, mask_dir, img_size=args.resizing, batch_size=args.batch_size)
    
        print("total training batches: ", len(train_gen))
        print("total validaton batches: ", len(valid_gen))
        train_steps = len(train_ids) // args.batch_size
        valid_steps = len(valid_ids) // args.batch_size
    
        # define model
        model = sm.Unet(args.backbone, encoder_weights='imagenet')
    
        optimizer = optimizers.Adam(lr=args.lr, decay=1e-4)
        model.compile(
            optimizer=optimizer,
            #        "Adam",
            loss=sm.losses.bce_dice_loss,  # sm.losses.bce_jaccard_loss, # sm.losses.binary_crossentropy,
            metrics=[sm.metrics.iou_score],
        )
        #model.summary()
    
        callbacks = [
            EarlyStopping(patience=6, verbose=1),
            ReduceLROnPlateau(factor=0.1, patience=3, min_lr=1e-7, verbose=1),
            ModelCheckpoint('./weights.Epoch{epoch:02d}-Loss{loss:.3f}-VIou{val_iou_score:.3f}.h5', verbose=1,
                            monitor='val_accuracy', save_best_only=True, save_weights_only=True)
                    ]
    
    
        train_model(model=model, train_gen=train_gen, valid_gen=valid_gen, epochs=args.epochs)
    
    train\u gen=DataGenerator(train\u id、img\u dir、mask\u dir、img\u size=args.resizing、batch\u size=args.batch\u size)

    valid\u gen=DataGenerator(valid\u id、img\u dir、mask\u dir、img\u size=args.resizing、batch\u size=args.batch\u size)

  • 在函数“train\u model”中的第一个函数中,使用函数“iter\u sequence\u infinite”将数据生成器(序列模型)转换为生成器

    train\u gen\u out=iter\u sequence\u infinite(train\u gen)

    valid\u gen\u out=iter\u sequence\u infinite(valid\u gen)

  • 使用神奇的函数“下一步”,获取批处理

    batch=next(训练发电机输出)

  • 我认为不会有记忆问题,但它已经发生了。 问题是什么?如何解决?
    谢谢。

    您使用什么硬件来运行此代码?我已尝试使用RTX2070super和RTX titan。随着年代的增加,这些错误也会发生