Python DCGAN调试。垃圾

Python DCGAN调试。垃圾,python,neural-network,pytorch,generative-adversarial-network,Python,Neural Network,Pytorch,Generative Adversarial Network,简介: import torch import torch.nn as nn import torchvision from torchvision import transforms, datasets import torch.nn.functional as F from torch import optim as optim from torch.utils.tensorboard import SummaryWriter import numpy as np import os i

简介:

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms, datasets
import torch.nn.functional as F
from torch import optim as optim
from torch.utils.tensorboard import SummaryWriter

import numpy as np

import os
import time


class Discriminator(torch.nn.Module):
    def __init__(self, ndf=16, dropout_value=0.5):  # ndf feature map discriminator
        super().__init__()
        self.ndf = ndf
        self.droupout_value = dropout_value

        self.condi = nn.Sequential(
            nn.Linear(in_features=10, out_features=64 * 64)
        )

        self.hidden0 = nn.Sequential(
            nn.Conv2d(in_channels=2, out_channels=self.ndf, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2),
        )
        self.hidden1 = nn.Sequential(
            nn.Conv2d(in_channels=self.ndf, out_channels=self.ndf * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(self.ndf * 2),
            nn.LeakyReLU(0.2),
            nn.Dropout(self.droupout_value)
        )
        self.hidden2 = nn.Sequential(
            nn.Conv2d(in_channels=self.ndf * 2, out_channels=self.ndf * 4, kernel_size=4, stride=2, padding=1, bias=False),
            #nn.BatchNorm2d(self.ndf * 4),
            nn.LeakyReLU(0.2),
            nn.Dropout(self.droupout_value)
        )
        self.hidden3 = nn.Sequential(
            nn.Conv2d(in_channels=self.ndf * 4, out_channels=self.ndf * 8, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(self.ndf * 8),
            nn.LeakyReLU(0.2),
            nn.Dropout(self.droupout_value)
        )
        self.out = nn.Sequential(
            nn.Conv2d(in_channels=self.ndf * 8, out_channels=1, kernel_size=4, stride=1, padding=0, bias=False),
            torch.nn.Sigmoid()
        )

    def forward(self, x, y):
        y = self.condi(y.view(-1, 10))
        y = y.view(-1, 1, 64, 64)

        x = torch.cat((x, y), dim=1)

        x = self.hidden0(x)
        x = self.hidden1(x)
        x = self.hidden2(x)
        x = self.hidden3(x)
        x = self.out(x)

        return x


class Generator(torch.nn.Module):
    def __init__(self, n_features=100, ngf=16, c_channels=1, dropout_value=0.5):  # ngf feature map of generator
        super().__init__()
        self.ngf = ngf
        self.n_features = n_features
        self.c_channels = c_channels
        self.droupout_value = dropout_value

        self.hidden0 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=self.n_features + 10, out_channels=self.ngf * 8,
                               kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(self.ngf * 8),
            nn.LeakyReLU(0.2)
        )

        self.hidden1 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=self.ngf * 8, out_channels=self.ngf * 4,
                               kernel_size=4, stride=2, padding=1, bias=False),
            #nn.BatchNorm2d(self.ngf * 4),
            nn.LeakyReLU(0.2),
            nn.Dropout(self.droupout_value)
        )

        self.hidden2 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=self.ngf * 4, out_channels=self.ngf * 2,
                               kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(self.ngf * 2),
            nn.LeakyReLU(0.2),
            nn.Dropout(self.droupout_value)
        )

        self.hidden3 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=self.ngf * 2, out_channels=self.ngf,
                               kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(self.ngf),
            nn.LeakyReLU(0.2),
            nn.Dropout(self.droupout_value)
        )

        self.out = nn.Sequential(
            # "out_channels=1" because gray scale
            nn.ConvTranspose2d(in_channels=self.ngf, out_channels=1, kernel_size=4,
                               stride=2, padding=1, bias=False),
            nn.Tanh()
        )

    def forward(self, x, y):
        x_cond = torch.cat((x, y), dim=1)  # Combine flatten image with conditional input (class labels)

        x = self.hidden0(x_cond)           # Image goes into a "ConvTranspose2d" layer
        x = self.hidden1(x)
        x = self.hidden2(x)
        x = self.hidden3(x)
        x = self.out(x)

        return x


class Logger:
    def __init__(self, model_name, model1, model2, m1_optimizer, m2_optimizer, model_parameter, train_loader):
        self.out_dir = "data"
        self.model_name = model_name
        self.train_loader = train_loader
        self.model1 = model1
        self.model2 = model2
        self.model_parameter = model_parameter
        self.m1_optimizer = m1_optimizer
        self.m2_optimizer = m2_optimizer

        # Exclude Epochs of the model name. This make sense e.g. when we stop a training progress and continue later on.
        self.experiment_name = '_'.join("{!s}={!r}".format(k, v) for (k, v) in model_parameter.items())\
            .replace("Epochs" + "=" + str(model_parameter["Epochs"]), "")

        self.d_error = 0
        self.g_error = 0

        self.tb = SummaryWriter(log_dir=str(self.out_dir + "/log/" + self.model_name + "/runs/" + self.experiment_name))

        self.path_image = os.path.join(os.getcwd(), f'{self.out_dir}/log/{self.model_name}/images/{self.experiment_name}')
        self.path_model = os.path.join(os.getcwd(), f'{self.out_dir}/log/{self.model_name}/model/{self.experiment_name}')

        try:
            os.makedirs(self.path_image)
        except Exception as e:
            print("WARNING: ", str(e))

        try:
            os.makedirs(self.path_model)
        except Exception as e:
            print("WARNING: ", str(e))

    def log_graph(self, model1_input, model2_input, model1_label, model2_label):
        self.tb.add_graph(self.model1, input_to_model=(model1_input, model1_label))
        self.tb.add_graph(self.model2, input_to_model=(model2_input, model2_label))

    def log(self, num_epoch, d_error, g_error):
        self.d_error = d_error
        self.g_error = g_error

        self.tb.add_scalar("Discriminator Train Error", self.d_error, num_epoch)
        self.tb.add_scalar("Generator Train Error", self.g_error, num_epoch)

    def log_image(self, images, epoch, batch_num):
        grid = torchvision.utils.make_grid(images)
        torchvision.utils.save_image(grid, f'{self.path_image}\\Epoch_{epoch}_batch_{batch_num}.png')

        self.tb.add_image("Generator Image", grid)

    def log_histogramm(self):
        for name, param in self.model2.named_parameters():
            self.tb.add_histogram(name, param, self.model_parameter["Epochs"])
            self.tb.add_histogram(f'gen_{name}.grad', param.grad, self.model_parameter["Epochs"])

        for name, param in self.model1.named_parameters():
            self.tb.add_histogram(name, param, self.model_parameter["Epochs"])
            self.tb.add_histogram(f'dis_{name}.grad', param.grad, self.model_parameter["Epochs"])

    def log_model(self, num_epoch):
        torch.save({
            "epoch": num_epoch,
            "model_generator_state_dict": self.model1.state_dict(),
            "model_discriminator_state_dict": self.model2.state_dict(),
            "optimizer_generator_state_dict":  self.m1_optimizer.state_dict(),
            "optimizer_discriminator_state_dict":  self.m2_optimizer.state_dict(),
        }, str(self.path_model + f'\\{time.time()}_epoch{num_epoch}.pth'))

    def close(self, logger, images, num_epoch,  d_error, g_error):
        logger.log_model(num_epoch)
        logger.log_histogramm()
        logger.log(num_epoch, d_error, g_error)
        self.tb.close()

    def display_stats(self, epoch, batch_num, dis_error, gen_error):
        print(f'Epoch: [{epoch}/{self.model_parameter["Epochs"]}] '
              f'Batch: [{batch_num}/{len(self.train_loader)}] '
              f'Loss_D: {dis_error.data.cpu()}, '
              f'Loss_G: {gen_error.data.cpu()}')


def get_MNIST_dataset(num_workers_loader, model_parameter, out_dir="data"):
    compose = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.CenterCrop((64, 64)),
        transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.5], std=[0.5])
    ])

    dataset = datasets.MNIST(
        root=out_dir,
        train=True,
        download=True,
        transform=compose
    )

    train_loader = torch.utils.data.DataLoader(dataset,
                                               batch_size=model_parameter["batch_size"],
                                               num_workers=num_workers_loader,
                                               shuffle=model_parameter["shuffle"])

    return dataset, train_loader


def train_discriminator(p_optimizer, p_noise, p_images, p_fake_target, p_real_target, p_images_labels, p_fake_labels, device):
    p_optimizer.zero_grad()

    # 1.1 Train on real data
    pred_dis_real = discriminator(p_images, p_images_labels)
    error_real = loss(pred_dis_real, p_real_target)

    error_real.backward()

    # 1.2 Train on fake data
    fake_data = generator(p_noise, p_fake_labels).detach()
    fake_data = add_noise_to_image(fake_data, device)
    pred_dis_fake = discriminator(fake_data, p_fake_labels)
    error_fake = loss(pred_dis_fake, p_fake_target)

    error_fake.backward()

    p_optimizer.step()

    return error_fake + error_real


def train_generator(p_optimizer, p_noise, p_real_target, p_fake_labels, device):
    p_optimizer.zero_grad()

    fake_images = generator(p_noise, p_fake_labels)
    fake_images = add_noise_to_image(fake_images, device)
    pred_dis_fake = discriminator(fake_images, p_fake_labels)
    error_fake = loss(pred_dis_fake, p_real_target)  # because
    """
    We use "p_real_target" instead of "p_fake_target" because we want to 
    maximize that the discriminator is wrong.
    """

    error_fake.backward()

    p_optimizer.step()

    return fake_images, pred_dis_fake, error_fake


# TODO change to a Truncated normal distribution
def get_noise(batch_size, n_features=100):
    return torch.FloatTensor(batch_size, n_features, 1, 1).uniform_(-1, 1)


# We flip label of real and fate data. Better gradient flow I have told
def get_real_data_target(batch_size):
    return torch.FloatTensor(batch_size, 1, 1, 1).uniform_(0.0, 0.2)


def get_fake_data_target(batch_size):
    return torch.FloatTensor(batch_size, 1, 1, 1).uniform_(0.8, 1.1)


def image_to_vector(images):
    return torch.flatten(images, start_dim=1, end_dim=-1)


def vector_to_image(images):
    return images.view(images.size(0), 1, 28, 28)


def get_rand_labels(batch_size):
    return torch.randint(low=0, high=9, size=(batch_size,))


def load_model(model_load_path):
    if model_load_path:
        checkpoint = torch.load(model_load_path)

        discriminator.load_state_dict(checkpoint["model_discriminator_state_dict"])
        generator.load_state_dict(checkpoint["model_generator_state_dict"])

        dis_opti.load_state_dict(checkpoint["optimizer_discriminator_state_dict"])
        gen_opti.load_state_dict(checkpoint["optimizer_generator_state_dict"])

        return checkpoint["epoch"]

    else:
        return 0


def init_model_optimizer(model_parameter, device):
    # Initialize the Models
    discriminator = Discriminator(ndf=model_parameter["ndf"], dropout_value=model_parameter["dropout"]).to(device)
    generator = Generator(ngf=model_parameter["ngf"], dropout_value=model_parameter["dropout"]).to(device)

    # train
    dis_opti = optim.Adam(discriminator.parameters(), lr=model_parameter["learning_rate_dis"], betas=(0.5, 0.999))
    gen_opti = optim.Adam(generator.parameters(), lr=model_parameter["learning_rate_gen"], betas=(0.5, 0.999))

    return discriminator, generator, dis_opti, gen_opti


def get_hot_vector_encode(labels, device):
    return torch.eye(10)[labels].view(-1, 10, 1, 1).to(device)


def add_noise_to_image(images, device, level_of_noise=0.1):
    return images[0].to(device) + (level_of_noise) * torch.randn(images.shape).to(device)


if __name__ == "__main__":
    # Hyperparameter
    model_parameter = {
        "batch_size": 500,
        "learning_rate_dis": 0.0002,
        "learning_rate_gen": 0.0002,
        "shuffle": False,
        "Epochs": 10,
        "ndf": 64,
        "ngf": 64,
        "dropout": 0.5
    }

    # Parameter
    r_frequent = 10        # How many samples we save for replay per batch (batch_size / r_frequent).
    model_name = "CDCGAN"   # The name of you model e.g. "Gan"
    num_workers_loader = 1  # How many workers should load the data
    sample_save_size = 16   # How many numbers your saved imaged should show
    device = "cuda"         # Which device should be used to train the neural network
    model_load_path = ""    # If set load model instead of training from new
    num_epoch_log = 1       # How frequent you want to log/
    torch.manual_seed(43)   # Sets a seed for torch for reproducibility

    dataset_train, train_loader = get_MNIST_dataset(num_workers_loader, model_parameter)  # Get dataset

    # Initialize the Models and optimizer
    discriminator, generator, dis_opti, gen_opti = init_model_optimizer(model_parameter, device)  # Init model/Optimizer

    start_epoch = load_model(model_load_path)  # when we want to load a model

    # Init Logger
    logger = Logger(model_name, generator, discriminator, gen_opti, dis_opti, model_parameter, train_loader)

    loss = nn.BCELoss()

    images, labels = next(iter(train_loader))  # For logging

    # For testing
    # pred = generator(get_noise(model_parameter["batch_size"]).to(device), get_hot_vector_encode(get_rand_labels(model_parameter["batch_size"]), device))
    # dis = discriminator(images.to(device), get_hot_vector_encode(labels, device))

    logger.log_graph(get_noise(model_parameter["batch_size"]).to(device), images.to(device),
                     get_hot_vector_encode(get_rand_labels(model_parameter["batch_size"]), device),
                     get_hot_vector_encode(labels, device))


    # Array to store
    exp_replay = torch.tensor([]).to(device)

    for num_epoch in range(start_epoch, model_parameter["Epochs"]):
        for batch_num, data_loader in enumerate(train_loader):
            images, labels = data_loader
            images = add_noise_to_image(images, device)  # Add noise to the images

            # 1. Train Discriminator
            dis_error = train_discriminator(
                                            dis_opti,
                                            get_noise(model_parameter["batch_size"]).to(device),
                                            images.to(device),
                                            get_fake_data_target(model_parameter["batch_size"]).to(device),
                                            get_real_data_target(model_parameter["batch_size"]).to(device),
                                            get_hot_vector_encode(labels, device),
                                            get_hot_vector_encode(
                                                get_rand_labels(model_parameter["batch_size"]), device),
                                            device
                                            )

            # 2. Train Generator
            fake_image, pred_dis_fake, gen_error = train_generator(
                                                                  gen_opti,
                                                                  get_noise(model_parameter["batch_size"]).to(device),
                                                                  get_real_data_target(model_parameter["batch_size"]).to(device),
                                                                  get_hot_vector_encode(
                                                                      get_rand_labels(model_parameter["batch_size"]),
                                                                      device),
                                                                  device
                                                                  )


            # Store a random point for experience replay
            perm = torch.randperm(fake_image.size(0))
            r_idx = perm[:max(1, int(model_parameter["batch_size"] / r_frequent))]
            r_samples = add_noise_to_image(fake_image[r_idx], device)
            exp_replay = torch.cat((exp_replay, r_samples), 0).detach()

            if exp_replay.size(0) >= model_parameter["batch_size"]:
                # Train on experienced data
                dis_opti.zero_grad()

                r_label = get_hot_vector_encode(torch.zeros(exp_replay.size(0)).numpy(), device)
                pred_dis_real = discriminator(exp_replay, r_label)
                error_real = loss(pred_dis_real,  get_fake_data_target(exp_replay.size(0)).to(device))

                error_real.backward()

                dis_opti.step()

                print(f'Epoch: [{num_epoch}/{model_parameter["Epochs"]}] '
                      f'Batch: Replay/Experience batch '
                      f'Loss_D: {error_real.data.cpu()}, '
                      )

                exp_replay = torch.tensor([]).to(device)

            logger.display_stats(epoch=num_epoch, batch_num=batch_num, dis_error=dis_error, gen_error=gen_error)

            if batch_num % 100 == 0:
                logger.log_image(fake_image[:sample_save_size], num_epoch, batch_num)

        logger.log(num_epoch, dis_error, gen_error)
        if num_epoch % num_epoch_log == 0:
            logger.log_model(num_epoch)
            logger.log_histogramm()
    logger.close(logger, fake_image[:sample_save_size], num_epoch, dis_error, gen_error)
考虑到我使用的库(PyTorch)在其网站上有教程,我正在尝试让CDCGAN(条件深层卷积生成对抗网络)处理MNIST数据集,这应该相当容易。
但我似乎无法让它工作——它只会产生垃圾,或者模型崩溃,或者两者兼而有之

我的尝试:

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms, datasets
import torch.nn.functional as F
from torch import optim as optim
from torch.utils.tensorboard import SummaryWriter

import numpy as np

import os
import time


class Discriminator(torch.nn.Module):
    def __init__(self, ndf=16, dropout_value=0.5):  # ndf feature map discriminator
        super().__init__()
        self.ndf = ndf
        self.droupout_value = dropout_value

        self.condi = nn.Sequential(
            nn.Linear(in_features=10, out_features=64 * 64)
        )

        self.hidden0 = nn.Sequential(
            nn.Conv2d(in_channels=2, out_channels=self.ndf, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2),
        )
        self.hidden1 = nn.Sequential(
            nn.Conv2d(in_channels=self.ndf, out_channels=self.ndf * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(self.ndf * 2),
            nn.LeakyReLU(0.2),
            nn.Dropout(self.droupout_value)
        )
        self.hidden2 = nn.Sequential(
            nn.Conv2d(in_channels=self.ndf * 2, out_channels=self.ndf * 4, kernel_size=4, stride=2, padding=1, bias=False),
            #nn.BatchNorm2d(self.ndf * 4),
            nn.LeakyReLU(0.2),
            nn.Dropout(self.droupout_value)
        )
        self.hidden3 = nn.Sequential(
            nn.Conv2d(in_channels=self.ndf * 4, out_channels=self.ndf * 8, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(self.ndf * 8),
            nn.LeakyReLU(0.2),
            nn.Dropout(self.droupout_value)
        )
        self.out = nn.Sequential(
            nn.Conv2d(in_channels=self.ndf * 8, out_channels=1, kernel_size=4, stride=1, padding=0, bias=False),
            torch.nn.Sigmoid()
        )

    def forward(self, x, y):
        y = self.condi(y.view(-1, 10))
        y = y.view(-1, 1, 64, 64)

        x = torch.cat((x, y), dim=1)

        x = self.hidden0(x)
        x = self.hidden1(x)
        x = self.hidden2(x)
        x = self.hidden3(x)
        x = self.out(x)

        return x


class Generator(torch.nn.Module):
    def __init__(self, n_features=100, ngf=16, c_channels=1, dropout_value=0.5):  # ngf feature map of generator
        super().__init__()
        self.ngf = ngf
        self.n_features = n_features
        self.c_channels = c_channels
        self.droupout_value = dropout_value

        self.hidden0 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=self.n_features + 10, out_channels=self.ngf * 8,
                               kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(self.ngf * 8),
            nn.LeakyReLU(0.2)
        )

        self.hidden1 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=self.ngf * 8, out_channels=self.ngf * 4,
                               kernel_size=4, stride=2, padding=1, bias=False),
            #nn.BatchNorm2d(self.ngf * 4),
            nn.LeakyReLU(0.2),
            nn.Dropout(self.droupout_value)
        )

        self.hidden2 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=self.ngf * 4, out_channels=self.ngf * 2,
                               kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(self.ngf * 2),
            nn.LeakyReLU(0.2),
            nn.Dropout(self.droupout_value)
        )

        self.hidden3 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=self.ngf * 2, out_channels=self.ngf,
                               kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(self.ngf),
            nn.LeakyReLU(0.2),
            nn.Dropout(self.droupout_value)
        )

        self.out = nn.Sequential(
            # "out_channels=1" because gray scale
            nn.ConvTranspose2d(in_channels=self.ngf, out_channels=1, kernel_size=4,
                               stride=2, padding=1, bias=False),
            nn.Tanh()
        )

    def forward(self, x, y):
        x_cond = torch.cat((x, y), dim=1)  # Combine flatten image with conditional input (class labels)

        x = self.hidden0(x_cond)           # Image goes into a "ConvTranspose2d" layer
        x = self.hidden1(x)
        x = self.hidden2(x)
        x = self.hidden3(x)
        x = self.out(x)

        return x


class Logger:
    def __init__(self, model_name, model1, model2, m1_optimizer, m2_optimizer, model_parameter, train_loader):
        self.out_dir = "data"
        self.model_name = model_name
        self.train_loader = train_loader
        self.model1 = model1
        self.model2 = model2
        self.model_parameter = model_parameter
        self.m1_optimizer = m1_optimizer
        self.m2_optimizer = m2_optimizer

        # Exclude Epochs of the model name. This make sense e.g. when we stop a training progress and continue later on.
        self.experiment_name = '_'.join("{!s}={!r}".format(k, v) for (k, v) in model_parameter.items())\
            .replace("Epochs" + "=" + str(model_parameter["Epochs"]), "")

        self.d_error = 0
        self.g_error = 0

        self.tb = SummaryWriter(log_dir=str(self.out_dir + "/log/" + self.model_name + "/runs/" + self.experiment_name))

        self.path_image = os.path.join(os.getcwd(), f'{self.out_dir}/log/{self.model_name}/images/{self.experiment_name}')
        self.path_model = os.path.join(os.getcwd(), f'{self.out_dir}/log/{self.model_name}/model/{self.experiment_name}')

        try:
            os.makedirs(self.path_image)
        except Exception as e:
            print("WARNING: ", str(e))

        try:
            os.makedirs(self.path_model)
        except Exception as e:
            print("WARNING: ", str(e))

    def log_graph(self, model1_input, model2_input, model1_label, model2_label):
        self.tb.add_graph(self.model1, input_to_model=(model1_input, model1_label))
        self.tb.add_graph(self.model2, input_to_model=(model2_input, model2_label))

    def log(self, num_epoch, d_error, g_error):
        self.d_error = d_error
        self.g_error = g_error

        self.tb.add_scalar("Discriminator Train Error", self.d_error, num_epoch)
        self.tb.add_scalar("Generator Train Error", self.g_error, num_epoch)

    def log_image(self, images, epoch, batch_num):
        grid = torchvision.utils.make_grid(images)
        torchvision.utils.save_image(grid, f'{self.path_image}\\Epoch_{epoch}_batch_{batch_num}.png')

        self.tb.add_image("Generator Image", grid)

    def log_histogramm(self):
        for name, param in self.model2.named_parameters():
            self.tb.add_histogram(name, param, self.model_parameter["Epochs"])
            self.tb.add_histogram(f'gen_{name}.grad', param.grad, self.model_parameter["Epochs"])

        for name, param in self.model1.named_parameters():
            self.tb.add_histogram(name, param, self.model_parameter["Epochs"])
            self.tb.add_histogram(f'dis_{name}.grad', param.grad, self.model_parameter["Epochs"])

    def log_model(self, num_epoch):
        torch.save({
            "epoch": num_epoch,
            "model_generator_state_dict": self.model1.state_dict(),
            "model_discriminator_state_dict": self.model2.state_dict(),
            "optimizer_generator_state_dict":  self.m1_optimizer.state_dict(),
            "optimizer_discriminator_state_dict":  self.m2_optimizer.state_dict(),
        }, str(self.path_model + f'\\{time.time()}_epoch{num_epoch}.pth'))

    def close(self, logger, images, num_epoch,  d_error, g_error):
        logger.log_model(num_epoch)
        logger.log_histogramm()
        logger.log(num_epoch, d_error, g_error)
        self.tb.close()

    def display_stats(self, epoch, batch_num, dis_error, gen_error):
        print(f'Epoch: [{epoch}/{self.model_parameter["Epochs"]}] '
              f'Batch: [{batch_num}/{len(self.train_loader)}] '
              f'Loss_D: {dis_error.data.cpu()}, '
              f'Loss_G: {gen_error.data.cpu()}')


def get_MNIST_dataset(num_workers_loader, model_parameter, out_dir="data"):
    compose = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.CenterCrop((64, 64)),
        transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.5], std=[0.5])
    ])

    dataset = datasets.MNIST(
        root=out_dir,
        train=True,
        download=True,
        transform=compose
    )

    train_loader = torch.utils.data.DataLoader(dataset,
                                               batch_size=model_parameter["batch_size"],
                                               num_workers=num_workers_loader,
                                               shuffle=model_parameter["shuffle"])

    return dataset, train_loader


def train_discriminator(p_optimizer, p_noise, p_images, p_fake_target, p_real_target, p_images_labels, p_fake_labels, device):
    p_optimizer.zero_grad()

    # 1.1 Train on real data
    pred_dis_real = discriminator(p_images, p_images_labels)
    error_real = loss(pred_dis_real, p_real_target)

    error_real.backward()

    # 1.2 Train on fake data
    fake_data = generator(p_noise, p_fake_labels).detach()
    fake_data = add_noise_to_image(fake_data, device)
    pred_dis_fake = discriminator(fake_data, p_fake_labels)
    error_fake = loss(pred_dis_fake, p_fake_target)

    error_fake.backward()

    p_optimizer.step()

    return error_fake + error_real


def train_generator(p_optimizer, p_noise, p_real_target, p_fake_labels, device):
    p_optimizer.zero_grad()

    fake_images = generator(p_noise, p_fake_labels)
    fake_images = add_noise_to_image(fake_images, device)
    pred_dis_fake = discriminator(fake_images, p_fake_labels)
    error_fake = loss(pred_dis_fake, p_real_target)  # because
    """
    We use "p_real_target" instead of "p_fake_target" because we want to 
    maximize that the discriminator is wrong.
    """

    error_fake.backward()

    p_optimizer.step()

    return fake_images, pred_dis_fake, error_fake


# TODO change to a Truncated normal distribution
def get_noise(batch_size, n_features=100):
    return torch.FloatTensor(batch_size, n_features, 1, 1).uniform_(-1, 1)


# We flip label of real and fate data. Better gradient flow I have told
def get_real_data_target(batch_size):
    return torch.FloatTensor(batch_size, 1, 1, 1).uniform_(0.0, 0.2)


def get_fake_data_target(batch_size):
    return torch.FloatTensor(batch_size, 1, 1, 1).uniform_(0.8, 1.1)


def image_to_vector(images):
    return torch.flatten(images, start_dim=1, end_dim=-1)


def vector_to_image(images):
    return images.view(images.size(0), 1, 28, 28)


def get_rand_labels(batch_size):
    return torch.randint(low=0, high=9, size=(batch_size,))


def load_model(model_load_path):
    if model_load_path:
        checkpoint = torch.load(model_load_path)

        discriminator.load_state_dict(checkpoint["model_discriminator_state_dict"])
        generator.load_state_dict(checkpoint["model_generator_state_dict"])

        dis_opti.load_state_dict(checkpoint["optimizer_discriminator_state_dict"])
        gen_opti.load_state_dict(checkpoint["optimizer_generator_state_dict"])

        return checkpoint["epoch"]

    else:
        return 0


def init_model_optimizer(model_parameter, device):
    # Initialize the Models
    discriminator = Discriminator(ndf=model_parameter["ndf"], dropout_value=model_parameter["dropout"]).to(device)
    generator = Generator(ngf=model_parameter["ngf"], dropout_value=model_parameter["dropout"]).to(device)

    # train
    dis_opti = optim.Adam(discriminator.parameters(), lr=model_parameter["learning_rate_dis"], betas=(0.5, 0.999))
    gen_opti = optim.Adam(generator.parameters(), lr=model_parameter["learning_rate_gen"], betas=(0.5, 0.999))

    return discriminator, generator, dis_opti, gen_opti


def get_hot_vector_encode(labels, device):
    return torch.eye(10)[labels].view(-1, 10, 1, 1).to(device)


def add_noise_to_image(images, device, level_of_noise=0.1):
    return images[0].to(device) + (level_of_noise) * torch.randn(images.shape).to(device)


if __name__ == "__main__":
    # Hyperparameter
    model_parameter = {
        "batch_size": 500,
        "learning_rate_dis": 0.0002,
        "learning_rate_gen": 0.0002,
        "shuffle": False,
        "Epochs": 10,
        "ndf": 64,
        "ngf": 64,
        "dropout": 0.5
    }

    # Parameter
    r_frequent = 10        # How many samples we save for replay per batch (batch_size / r_frequent).
    model_name = "CDCGAN"   # The name of you model e.g. "Gan"
    num_workers_loader = 1  # How many workers should load the data
    sample_save_size = 16   # How many numbers your saved imaged should show
    device = "cuda"         # Which device should be used to train the neural network
    model_load_path = ""    # If set load model instead of training from new
    num_epoch_log = 1       # How frequent you want to log/
    torch.manual_seed(43)   # Sets a seed for torch for reproducibility

    dataset_train, train_loader = get_MNIST_dataset(num_workers_loader, model_parameter)  # Get dataset

    # Initialize the Models and optimizer
    discriminator, generator, dis_opti, gen_opti = init_model_optimizer(model_parameter, device)  # Init model/Optimizer

    start_epoch = load_model(model_load_path)  # when we want to load a model

    # Init Logger
    logger = Logger(model_name, generator, discriminator, gen_opti, dis_opti, model_parameter, train_loader)

    loss = nn.BCELoss()

    images, labels = next(iter(train_loader))  # For logging

    # For testing
    # pred = generator(get_noise(model_parameter["batch_size"]).to(device), get_hot_vector_encode(get_rand_labels(model_parameter["batch_size"]), device))
    # dis = discriminator(images.to(device), get_hot_vector_encode(labels, device))

    logger.log_graph(get_noise(model_parameter["batch_size"]).to(device), images.to(device),
                     get_hot_vector_encode(get_rand_labels(model_parameter["batch_size"]), device),
                     get_hot_vector_encode(labels, device))


    # Array to store
    exp_replay = torch.tensor([]).to(device)

    for num_epoch in range(start_epoch, model_parameter["Epochs"]):
        for batch_num, data_loader in enumerate(train_loader):
            images, labels = data_loader
            images = add_noise_to_image(images, device)  # Add noise to the images

            # 1. Train Discriminator
            dis_error = train_discriminator(
                                            dis_opti,
                                            get_noise(model_parameter["batch_size"]).to(device),
                                            images.to(device),
                                            get_fake_data_target(model_parameter["batch_size"]).to(device),
                                            get_real_data_target(model_parameter["batch_size"]).to(device),
                                            get_hot_vector_encode(labels, device),
                                            get_hot_vector_encode(
                                                get_rand_labels(model_parameter["batch_size"]), device),
                                            device
                                            )

            # 2. Train Generator
            fake_image, pred_dis_fake, gen_error = train_generator(
                                                                  gen_opti,
                                                                  get_noise(model_parameter["batch_size"]).to(device),
                                                                  get_real_data_target(model_parameter["batch_size"]).to(device),
                                                                  get_hot_vector_encode(
                                                                      get_rand_labels(model_parameter["batch_size"]),
                                                                      device),
                                                                  device
                                                                  )


            # Store a random point for experience replay
            perm = torch.randperm(fake_image.size(0))
            r_idx = perm[:max(1, int(model_parameter["batch_size"] / r_frequent))]
            r_samples = add_noise_to_image(fake_image[r_idx], device)
            exp_replay = torch.cat((exp_replay, r_samples), 0).detach()

            if exp_replay.size(0) >= model_parameter["batch_size"]:
                # Train on experienced data
                dis_opti.zero_grad()

                r_label = get_hot_vector_encode(torch.zeros(exp_replay.size(0)).numpy(), device)
                pred_dis_real = discriminator(exp_replay, r_label)
                error_real = loss(pred_dis_real,  get_fake_data_target(exp_replay.size(0)).to(device))

                error_real.backward()

                dis_opti.step()

                print(f'Epoch: [{num_epoch}/{model_parameter["Epochs"]}] '
                      f'Batch: Replay/Experience batch '
                      f'Loss_D: {error_real.data.cpu()}, '
                      )

                exp_replay = torch.tensor([]).to(device)

            logger.display_stats(epoch=num_epoch, batch_num=batch_num, dis_error=dis_error, gen_error=gen_error)

            if batch_num % 100 == 0:
                logger.log_image(fake_image[:sample_save_size], num_epoch, batch_num)

        logger.log(num_epoch, dis_error, gen_error)
        if num_epoch % num_epoch_log == 0:
            logger.log_model(num_epoch)
            logger.log_histogramm()
    logger.close(logger, fake_image[:sample_save_size], num_epoch, dis_error, gen_error)
  • 条件半监督学习模型的建立
  • 使用批处理范数
  • 在发生器和鉴别器的输入/输出层之外的每一层上使用漏极
  • 消除过度自信的标签平滑
  • 向图像添加噪波(我猜您称之为实例噪波)以获得更好的数据分布
  • 使用泄漏的relu避免渐变消失
  • 使用重放缓冲区来防止遗忘学习内容和过度拟合
  • 玩超参数游戏
  • 将其与PyTorch教程中的模型进行比较
我的模型生成的图像:

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms, datasets
import torch.nn.functional as F
from torch import optim as optim
from torch.utils.tensorboard import SummaryWriter

import numpy as np

import os
import time


class Discriminator(torch.nn.Module):
    def __init__(self, ndf=16, dropout_value=0.5):  # ndf feature map discriminator
        super().__init__()
        self.ndf = ndf
        self.droupout_value = dropout_value

        self.condi = nn.Sequential(
            nn.Linear(in_features=10, out_features=64 * 64)
        )

        self.hidden0 = nn.Sequential(
            nn.Conv2d(in_channels=2, out_channels=self.ndf, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2),
        )
        self.hidden1 = nn.Sequential(
            nn.Conv2d(in_channels=self.ndf, out_channels=self.ndf * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(self.ndf * 2),
            nn.LeakyReLU(0.2),
            nn.Dropout(self.droupout_value)
        )
        self.hidden2 = nn.Sequential(
            nn.Conv2d(in_channels=self.ndf * 2, out_channels=self.ndf * 4, kernel_size=4, stride=2, padding=1, bias=False),
            #nn.BatchNorm2d(self.ndf * 4),
            nn.LeakyReLU(0.2),
            nn.Dropout(self.droupout_value)
        )
        self.hidden3 = nn.Sequential(
            nn.Conv2d(in_channels=self.ndf * 4, out_channels=self.ndf * 8, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(self.ndf * 8),
            nn.LeakyReLU(0.2),
            nn.Dropout(self.droupout_value)
        )
        self.out = nn.Sequential(
            nn.Conv2d(in_channels=self.ndf * 8, out_channels=1, kernel_size=4, stride=1, padding=0, bias=False),
            torch.nn.Sigmoid()
        )

    def forward(self, x, y):
        y = self.condi(y.view(-1, 10))
        y = y.view(-1, 1, 64, 64)

        x = torch.cat((x, y), dim=1)

        x = self.hidden0(x)
        x = self.hidden1(x)
        x = self.hidden2(x)
        x = self.hidden3(x)
        x = self.out(x)

        return x


class Generator(torch.nn.Module):
    def __init__(self, n_features=100, ngf=16, c_channels=1, dropout_value=0.5):  # ngf feature map of generator
        super().__init__()
        self.ngf = ngf
        self.n_features = n_features
        self.c_channels = c_channels
        self.droupout_value = dropout_value

        self.hidden0 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=self.n_features + 10, out_channels=self.ngf * 8,
                               kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(self.ngf * 8),
            nn.LeakyReLU(0.2)
        )

        self.hidden1 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=self.ngf * 8, out_channels=self.ngf * 4,
                               kernel_size=4, stride=2, padding=1, bias=False),
            #nn.BatchNorm2d(self.ngf * 4),
            nn.LeakyReLU(0.2),
            nn.Dropout(self.droupout_value)
        )

        self.hidden2 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=self.ngf * 4, out_channels=self.ngf * 2,
                               kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(self.ngf * 2),
            nn.LeakyReLU(0.2),
            nn.Dropout(self.droupout_value)
        )

        self.hidden3 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=self.ngf * 2, out_channels=self.ngf,
                               kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(self.ngf),
            nn.LeakyReLU(0.2),
            nn.Dropout(self.droupout_value)
        )

        self.out = nn.Sequential(
            # "out_channels=1" because gray scale
            nn.ConvTranspose2d(in_channels=self.ngf, out_channels=1, kernel_size=4,
                               stride=2, padding=1, bias=False),
            nn.Tanh()
        )

    def forward(self, x, y):
        x_cond = torch.cat((x, y), dim=1)  # Combine flatten image with conditional input (class labels)

        x = self.hidden0(x_cond)           # Image goes into a "ConvTranspose2d" layer
        x = self.hidden1(x)
        x = self.hidden2(x)
        x = self.hidden3(x)
        x = self.out(x)

        return x


class Logger:
    def __init__(self, model_name, model1, model2, m1_optimizer, m2_optimizer, model_parameter, train_loader):
        self.out_dir = "data"
        self.model_name = model_name
        self.train_loader = train_loader
        self.model1 = model1
        self.model2 = model2
        self.model_parameter = model_parameter
        self.m1_optimizer = m1_optimizer
        self.m2_optimizer = m2_optimizer

        # Exclude Epochs of the model name. This make sense e.g. when we stop a training progress and continue later on.
        self.experiment_name = '_'.join("{!s}={!r}".format(k, v) for (k, v) in model_parameter.items())\
            .replace("Epochs" + "=" + str(model_parameter["Epochs"]), "")

        self.d_error = 0
        self.g_error = 0

        self.tb = SummaryWriter(log_dir=str(self.out_dir + "/log/" + self.model_name + "/runs/" + self.experiment_name))

        self.path_image = os.path.join(os.getcwd(), f'{self.out_dir}/log/{self.model_name}/images/{self.experiment_name}')
        self.path_model = os.path.join(os.getcwd(), f'{self.out_dir}/log/{self.model_name}/model/{self.experiment_name}')

        try:
            os.makedirs(self.path_image)
        except Exception as e:
            print("WARNING: ", str(e))

        try:
            os.makedirs(self.path_model)
        except Exception as e:
            print("WARNING: ", str(e))

    def log_graph(self, model1_input, model2_input, model1_label, model2_label):
        self.tb.add_graph(self.model1, input_to_model=(model1_input, model1_label))
        self.tb.add_graph(self.model2, input_to_model=(model2_input, model2_label))

    def log(self, num_epoch, d_error, g_error):
        self.d_error = d_error
        self.g_error = g_error

        self.tb.add_scalar("Discriminator Train Error", self.d_error, num_epoch)
        self.tb.add_scalar("Generator Train Error", self.g_error, num_epoch)

    def log_image(self, images, epoch, batch_num):
        grid = torchvision.utils.make_grid(images)
        torchvision.utils.save_image(grid, f'{self.path_image}\\Epoch_{epoch}_batch_{batch_num}.png')

        self.tb.add_image("Generator Image", grid)

    def log_histogramm(self):
        for name, param in self.model2.named_parameters():
            self.tb.add_histogram(name, param, self.model_parameter["Epochs"])
            self.tb.add_histogram(f'gen_{name}.grad', param.grad, self.model_parameter["Epochs"])

        for name, param in self.model1.named_parameters():
            self.tb.add_histogram(name, param, self.model_parameter["Epochs"])
            self.tb.add_histogram(f'dis_{name}.grad', param.grad, self.model_parameter["Epochs"])

    def log_model(self, num_epoch):
        torch.save({
            "epoch": num_epoch,
            "model_generator_state_dict": self.model1.state_dict(),
            "model_discriminator_state_dict": self.model2.state_dict(),
            "optimizer_generator_state_dict":  self.m1_optimizer.state_dict(),
            "optimizer_discriminator_state_dict":  self.m2_optimizer.state_dict(),
        }, str(self.path_model + f'\\{time.time()}_epoch{num_epoch}.pth'))

    def close(self, logger, images, num_epoch,  d_error, g_error):
        logger.log_model(num_epoch)
        logger.log_histogramm()
        logger.log(num_epoch, d_error, g_error)
        self.tb.close()

    def display_stats(self, epoch, batch_num, dis_error, gen_error):
        print(f'Epoch: [{epoch}/{self.model_parameter["Epochs"]}] '
              f'Batch: [{batch_num}/{len(self.train_loader)}] '
              f'Loss_D: {dis_error.data.cpu()}, '
              f'Loss_G: {gen_error.data.cpu()}')


def get_MNIST_dataset(num_workers_loader, model_parameter, out_dir="data"):
    compose = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.CenterCrop((64, 64)),
        transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.5], std=[0.5])
    ])

    dataset = datasets.MNIST(
        root=out_dir,
        train=True,
        download=True,
        transform=compose
    )

    train_loader = torch.utils.data.DataLoader(dataset,
                                               batch_size=model_parameter["batch_size"],
                                               num_workers=num_workers_loader,
                                               shuffle=model_parameter["shuffle"])

    return dataset, train_loader


def train_discriminator(p_optimizer, p_noise, p_images, p_fake_target, p_real_target, p_images_labels, p_fake_labels, device):
    p_optimizer.zero_grad()

    # 1.1 Train on real data
    pred_dis_real = discriminator(p_images, p_images_labels)
    error_real = loss(pred_dis_real, p_real_target)

    error_real.backward()

    # 1.2 Train on fake data
    fake_data = generator(p_noise, p_fake_labels).detach()
    fake_data = add_noise_to_image(fake_data, device)
    pred_dis_fake = discriminator(fake_data, p_fake_labels)
    error_fake = loss(pred_dis_fake, p_fake_target)

    error_fake.backward()

    p_optimizer.step()

    return error_fake + error_real


def train_generator(p_optimizer, p_noise, p_real_target, p_fake_labels, device):
    p_optimizer.zero_grad()

    fake_images = generator(p_noise, p_fake_labels)
    fake_images = add_noise_to_image(fake_images, device)
    pred_dis_fake = discriminator(fake_images, p_fake_labels)
    error_fake = loss(pred_dis_fake, p_real_target)  # because
    """
    We use "p_real_target" instead of "p_fake_target" because we want to 
    maximize that the discriminator is wrong.
    """

    error_fake.backward()

    p_optimizer.step()

    return fake_images, pred_dis_fake, error_fake


# TODO change to a Truncated normal distribution
def get_noise(batch_size, n_features=100):
    return torch.FloatTensor(batch_size, n_features, 1, 1).uniform_(-1, 1)


# We flip label of real and fate data. Better gradient flow I have told
def get_real_data_target(batch_size):
    return torch.FloatTensor(batch_size, 1, 1, 1).uniform_(0.0, 0.2)


def get_fake_data_target(batch_size):
    return torch.FloatTensor(batch_size, 1, 1, 1).uniform_(0.8, 1.1)


def image_to_vector(images):
    return torch.flatten(images, start_dim=1, end_dim=-1)


def vector_to_image(images):
    return images.view(images.size(0), 1, 28, 28)


def get_rand_labels(batch_size):
    return torch.randint(low=0, high=9, size=(batch_size,))


def load_model(model_load_path):
    if model_load_path:
        checkpoint = torch.load(model_load_path)

        discriminator.load_state_dict(checkpoint["model_discriminator_state_dict"])
        generator.load_state_dict(checkpoint["model_generator_state_dict"])

        dis_opti.load_state_dict(checkpoint["optimizer_discriminator_state_dict"])
        gen_opti.load_state_dict(checkpoint["optimizer_generator_state_dict"])

        return checkpoint["epoch"]

    else:
        return 0


def init_model_optimizer(model_parameter, device):
    # Initialize the Models
    discriminator = Discriminator(ndf=model_parameter["ndf"], dropout_value=model_parameter["dropout"]).to(device)
    generator = Generator(ngf=model_parameter["ngf"], dropout_value=model_parameter["dropout"]).to(device)

    # train
    dis_opti = optim.Adam(discriminator.parameters(), lr=model_parameter["learning_rate_dis"], betas=(0.5, 0.999))
    gen_opti = optim.Adam(generator.parameters(), lr=model_parameter["learning_rate_gen"], betas=(0.5, 0.999))

    return discriminator, generator, dis_opti, gen_opti


def get_hot_vector_encode(labels, device):
    return torch.eye(10)[labels].view(-1, 10, 1, 1).to(device)


def add_noise_to_image(images, device, level_of_noise=0.1):
    return images[0].to(device) + (level_of_noise) * torch.randn(images.shape).to(device)


if __name__ == "__main__":
    # Hyperparameter
    model_parameter = {
        "batch_size": 500,
        "learning_rate_dis": 0.0002,
        "learning_rate_gen": 0.0002,
        "shuffle": False,
        "Epochs": 10,
        "ndf": 64,
        "ngf": 64,
        "dropout": 0.5
    }

    # Parameter
    r_frequent = 10        # How many samples we save for replay per batch (batch_size / r_frequent).
    model_name = "CDCGAN"   # The name of you model e.g. "Gan"
    num_workers_loader = 1  # How many workers should load the data
    sample_save_size = 16   # How many numbers your saved imaged should show
    device = "cuda"         # Which device should be used to train the neural network
    model_load_path = ""    # If set load model instead of training from new
    num_epoch_log = 1       # How frequent you want to log/
    torch.manual_seed(43)   # Sets a seed for torch for reproducibility

    dataset_train, train_loader = get_MNIST_dataset(num_workers_loader, model_parameter)  # Get dataset

    # Initialize the Models and optimizer
    discriminator, generator, dis_opti, gen_opti = init_model_optimizer(model_parameter, device)  # Init model/Optimizer

    start_epoch = load_model(model_load_path)  # when we want to load a model

    # Init Logger
    logger = Logger(model_name, generator, discriminator, gen_opti, dis_opti, model_parameter, train_loader)

    loss = nn.BCELoss()

    images, labels = next(iter(train_loader))  # For logging

    # For testing
    # pred = generator(get_noise(model_parameter["batch_size"]).to(device), get_hot_vector_encode(get_rand_labels(model_parameter["batch_size"]), device))
    # dis = discriminator(images.to(device), get_hot_vector_encode(labels, device))

    logger.log_graph(get_noise(model_parameter["batch_size"]).to(device), images.to(device),
                     get_hot_vector_encode(get_rand_labels(model_parameter["batch_size"]), device),
                     get_hot_vector_encode(labels, device))


    # Array to store
    exp_replay = torch.tensor([]).to(device)

    for num_epoch in range(start_epoch, model_parameter["Epochs"]):
        for batch_num, data_loader in enumerate(train_loader):
            images, labels = data_loader
            images = add_noise_to_image(images, device)  # Add noise to the images

            # 1. Train Discriminator
            dis_error = train_discriminator(
                                            dis_opti,
                                            get_noise(model_parameter["batch_size"]).to(device),
                                            images.to(device),
                                            get_fake_data_target(model_parameter["batch_size"]).to(device),
                                            get_real_data_target(model_parameter["batch_size"]).to(device),
                                            get_hot_vector_encode(labels, device),
                                            get_hot_vector_encode(
                                                get_rand_labels(model_parameter["batch_size"]), device),
                                            device
                                            )

            # 2. Train Generator
            fake_image, pred_dis_fake, gen_error = train_generator(
                                                                  gen_opti,
                                                                  get_noise(model_parameter["batch_size"]).to(device),
                                                                  get_real_data_target(model_parameter["batch_size"]).to(device),
                                                                  get_hot_vector_encode(
                                                                      get_rand_labels(model_parameter["batch_size"]),
                                                                      device),
                                                                  device
                                                                  )


            # Store a random point for experience replay
            perm = torch.randperm(fake_image.size(0))
            r_idx = perm[:max(1, int(model_parameter["batch_size"] / r_frequent))]
            r_samples = add_noise_to_image(fake_image[r_idx], device)
            exp_replay = torch.cat((exp_replay, r_samples), 0).detach()

            if exp_replay.size(0) >= model_parameter["batch_size"]:
                # Train on experienced data
                dis_opti.zero_grad()

                r_label = get_hot_vector_encode(torch.zeros(exp_replay.size(0)).numpy(), device)
                pred_dis_real = discriminator(exp_replay, r_label)
                error_real = loss(pred_dis_real,  get_fake_data_target(exp_replay.size(0)).to(device))

                error_real.backward()

                dis_opti.step()

                print(f'Epoch: [{num_epoch}/{model_parameter["Epochs"]}] '
                      f'Batch: Replay/Experience batch '
                      f'Loss_D: {error_real.data.cpu()}, '
                      )

                exp_replay = torch.tensor([]).to(device)

            logger.display_stats(epoch=num_epoch, batch_num=batch_num, dis_error=dis_error, gen_error=gen_error)

            if batch_num % 100 == 0:
                logger.log_image(fake_image[:sample_save_size], num_epoch, batch_num)

        logger.log(num_epoch, dis_error, gen_error)
        if num_epoch % num_epoch_log == 0:
            logger.log_model(num_epoch)
            logger.log_histogramm()
    logger.close(logger, fake_image[:sample_save_size], num_epoch, dis_error, gen_error)
超参数:

批大小=50,学习率=0.0001,学习率=0.0003,随机数=True,ndf=64,ngf=64,Droopout=0.5

批量大小=50,学习率鉴别器=0.0003,学习率生成器=0.0003,随机播放=True,ndf=64,ngf=64,辍学=0

生成的图像:

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms, datasets
import torch.nn.functional as F
from torch import optim as optim
from torch.utils.tensorboard import SummaryWriter

import numpy as np

import os
import time


class Discriminator(torch.nn.Module):
    def __init__(self, ndf=16, dropout_value=0.5):  # ndf feature map discriminator
        super().__init__()
        self.ndf = ndf
        self.droupout_value = dropout_value

        self.condi = nn.Sequential(
            nn.Linear(in_features=10, out_features=64 * 64)
        )

        self.hidden0 = nn.Sequential(
            nn.Conv2d(in_channels=2, out_channels=self.ndf, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2),
        )
        self.hidden1 = nn.Sequential(
            nn.Conv2d(in_channels=self.ndf, out_channels=self.ndf * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(self.ndf * 2),
            nn.LeakyReLU(0.2),
            nn.Dropout(self.droupout_value)
        )
        self.hidden2 = nn.Sequential(
            nn.Conv2d(in_channels=self.ndf * 2, out_channels=self.ndf * 4, kernel_size=4, stride=2, padding=1, bias=False),
            #nn.BatchNorm2d(self.ndf * 4),
            nn.LeakyReLU(0.2),
            nn.Dropout(self.droupout_value)
        )
        self.hidden3 = nn.Sequential(
            nn.Conv2d(in_channels=self.ndf * 4, out_channels=self.ndf * 8, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(self.ndf * 8),
            nn.LeakyReLU(0.2),
            nn.Dropout(self.droupout_value)
        )
        self.out = nn.Sequential(
            nn.Conv2d(in_channels=self.ndf * 8, out_channels=1, kernel_size=4, stride=1, padding=0, bias=False),
            torch.nn.Sigmoid()
        )

    def forward(self, x, y):
        y = self.condi(y.view(-1, 10))
        y = y.view(-1, 1, 64, 64)

        x = torch.cat((x, y), dim=1)

        x = self.hidden0(x)
        x = self.hidden1(x)
        x = self.hidden2(x)
        x = self.hidden3(x)
        x = self.out(x)

        return x


class Generator(torch.nn.Module):
    def __init__(self, n_features=100, ngf=16, c_channels=1, dropout_value=0.5):  # ngf feature map of generator
        super().__init__()
        self.ngf = ngf
        self.n_features = n_features
        self.c_channels = c_channels
        self.droupout_value = dropout_value

        self.hidden0 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=self.n_features + 10, out_channels=self.ngf * 8,
                               kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(self.ngf * 8),
            nn.LeakyReLU(0.2)
        )

        self.hidden1 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=self.ngf * 8, out_channels=self.ngf * 4,
                               kernel_size=4, stride=2, padding=1, bias=False),
            #nn.BatchNorm2d(self.ngf * 4),
            nn.LeakyReLU(0.2),
            nn.Dropout(self.droupout_value)
        )

        self.hidden2 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=self.ngf * 4, out_channels=self.ngf * 2,
                               kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(self.ngf * 2),
            nn.LeakyReLU(0.2),
            nn.Dropout(self.droupout_value)
        )

        self.hidden3 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=self.ngf * 2, out_channels=self.ngf,
                               kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(self.ngf),
            nn.LeakyReLU(0.2),
            nn.Dropout(self.droupout_value)
        )

        self.out = nn.Sequential(
            # "out_channels=1" because gray scale
            nn.ConvTranspose2d(in_channels=self.ngf, out_channels=1, kernel_size=4,
                               stride=2, padding=1, bias=False),
            nn.Tanh()
        )

    def forward(self, x, y):
        x_cond = torch.cat((x, y), dim=1)  # Combine flatten image with conditional input (class labels)

        x = self.hidden0(x_cond)           # Image goes into a "ConvTranspose2d" layer
        x = self.hidden1(x)
        x = self.hidden2(x)
        x = self.hidden3(x)
        x = self.out(x)

        return x


class Logger:
    def __init__(self, model_name, model1, model2, m1_optimizer, m2_optimizer, model_parameter, train_loader):
        self.out_dir = "data"
        self.model_name = model_name
        self.train_loader = train_loader
        self.model1 = model1
        self.model2 = model2
        self.model_parameter = model_parameter
        self.m1_optimizer = m1_optimizer
        self.m2_optimizer = m2_optimizer

        # Exclude Epochs of the model name. This make sense e.g. when we stop a training progress and continue later on.
        self.experiment_name = '_'.join("{!s}={!r}".format(k, v) for (k, v) in model_parameter.items())\
            .replace("Epochs" + "=" + str(model_parameter["Epochs"]), "")

        self.d_error = 0
        self.g_error = 0

        self.tb = SummaryWriter(log_dir=str(self.out_dir + "/log/" + self.model_name + "/runs/" + self.experiment_name))

        self.path_image = os.path.join(os.getcwd(), f'{self.out_dir}/log/{self.model_name}/images/{self.experiment_name}')
        self.path_model = os.path.join(os.getcwd(), f'{self.out_dir}/log/{self.model_name}/model/{self.experiment_name}')

        try:
            os.makedirs(self.path_image)
        except Exception as e:
            print("WARNING: ", str(e))

        try:
            os.makedirs(self.path_model)
        except Exception as e:
            print("WARNING: ", str(e))

    def log_graph(self, model1_input, model2_input, model1_label, model2_label):
        self.tb.add_graph(self.model1, input_to_model=(model1_input, model1_label))
        self.tb.add_graph(self.model2, input_to_model=(model2_input, model2_label))

    def log(self, num_epoch, d_error, g_error):
        self.d_error = d_error
        self.g_error = g_error

        self.tb.add_scalar("Discriminator Train Error", self.d_error, num_epoch)
        self.tb.add_scalar("Generator Train Error", self.g_error, num_epoch)

    def log_image(self, images, epoch, batch_num):
        grid = torchvision.utils.make_grid(images)
        torchvision.utils.save_image(grid, f'{self.path_image}\\Epoch_{epoch}_batch_{batch_num}.png')

        self.tb.add_image("Generator Image", grid)

    def log_histogramm(self):
        for name, param in self.model2.named_parameters():
            self.tb.add_histogram(name, param, self.model_parameter["Epochs"])
            self.tb.add_histogram(f'gen_{name}.grad', param.grad, self.model_parameter["Epochs"])

        for name, param in self.model1.named_parameters():
            self.tb.add_histogram(name, param, self.model_parameter["Epochs"])
            self.tb.add_histogram(f'dis_{name}.grad', param.grad, self.model_parameter["Epochs"])

    def log_model(self, num_epoch):
        torch.save({
            "epoch": num_epoch,
            "model_generator_state_dict": self.model1.state_dict(),
            "model_discriminator_state_dict": self.model2.state_dict(),
            "optimizer_generator_state_dict":  self.m1_optimizer.state_dict(),
            "optimizer_discriminator_state_dict":  self.m2_optimizer.state_dict(),
        }, str(self.path_model + f'\\{time.time()}_epoch{num_epoch}.pth'))

    def close(self, logger, images, num_epoch,  d_error, g_error):
        logger.log_model(num_epoch)
        logger.log_histogramm()
        logger.log(num_epoch, d_error, g_error)
        self.tb.close()

    def display_stats(self, epoch, batch_num, dis_error, gen_error):
        print(f'Epoch: [{epoch}/{self.model_parameter["Epochs"]}] '
              f'Batch: [{batch_num}/{len(self.train_loader)}] '
              f'Loss_D: {dis_error.data.cpu()}, '
              f'Loss_G: {gen_error.data.cpu()}')


def get_MNIST_dataset(num_workers_loader, model_parameter, out_dir="data"):
    compose = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.CenterCrop((64, 64)),
        transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.5], std=[0.5])
    ])

    dataset = datasets.MNIST(
        root=out_dir,
        train=True,
        download=True,
        transform=compose
    )

    train_loader = torch.utils.data.DataLoader(dataset,
                                               batch_size=model_parameter["batch_size"],
                                               num_workers=num_workers_loader,
                                               shuffle=model_parameter["shuffle"])

    return dataset, train_loader


def train_discriminator(p_optimizer, p_noise, p_images, p_fake_target, p_real_target, p_images_labels, p_fake_labels, device):
    p_optimizer.zero_grad()

    # 1.1 Train on real data
    pred_dis_real = discriminator(p_images, p_images_labels)
    error_real = loss(pred_dis_real, p_real_target)

    error_real.backward()

    # 1.2 Train on fake data
    fake_data = generator(p_noise, p_fake_labels).detach()
    fake_data = add_noise_to_image(fake_data, device)
    pred_dis_fake = discriminator(fake_data, p_fake_labels)
    error_fake = loss(pred_dis_fake, p_fake_target)

    error_fake.backward()

    p_optimizer.step()

    return error_fake + error_real


def train_generator(p_optimizer, p_noise, p_real_target, p_fake_labels, device):
    p_optimizer.zero_grad()

    fake_images = generator(p_noise, p_fake_labels)
    fake_images = add_noise_to_image(fake_images, device)
    pred_dis_fake = discriminator(fake_images, p_fake_labels)
    error_fake = loss(pred_dis_fake, p_real_target)  # because
    """
    We use "p_real_target" instead of "p_fake_target" because we want to 
    maximize that the discriminator is wrong.
    """

    error_fake.backward()

    p_optimizer.step()

    return fake_images, pred_dis_fake, error_fake


# TODO change to a Truncated normal distribution
def get_noise(batch_size, n_features=100):
    return torch.FloatTensor(batch_size, n_features, 1, 1).uniform_(-1, 1)


# We flip label of real and fate data. Better gradient flow I have told
def get_real_data_target(batch_size):
    return torch.FloatTensor(batch_size, 1, 1, 1).uniform_(0.0, 0.2)


def get_fake_data_target(batch_size):
    return torch.FloatTensor(batch_size, 1, 1, 1).uniform_(0.8, 1.1)


def image_to_vector(images):
    return torch.flatten(images, start_dim=1, end_dim=-1)


def vector_to_image(images):
    return images.view(images.size(0), 1, 28, 28)


def get_rand_labels(batch_size):
    return torch.randint(low=0, high=9, size=(batch_size,))


def load_model(model_load_path):
    if model_load_path:
        checkpoint = torch.load(model_load_path)

        discriminator.load_state_dict(checkpoint["model_discriminator_state_dict"])
        generator.load_state_dict(checkpoint["model_generator_state_dict"])

        dis_opti.load_state_dict(checkpoint["optimizer_discriminator_state_dict"])
        gen_opti.load_state_dict(checkpoint["optimizer_generator_state_dict"])

        return checkpoint["epoch"]

    else:
        return 0


def init_model_optimizer(model_parameter, device):
    # Initialize the Models
    discriminator = Discriminator(ndf=model_parameter["ndf"], dropout_value=model_parameter["dropout"]).to(device)
    generator = Generator(ngf=model_parameter["ngf"], dropout_value=model_parameter["dropout"]).to(device)

    # train
    dis_opti = optim.Adam(discriminator.parameters(), lr=model_parameter["learning_rate_dis"], betas=(0.5, 0.999))
    gen_opti = optim.Adam(generator.parameters(), lr=model_parameter["learning_rate_gen"], betas=(0.5, 0.999))

    return discriminator, generator, dis_opti, gen_opti


def get_hot_vector_encode(labels, device):
    return torch.eye(10)[labels].view(-1, 10, 1, 1).to(device)


def add_noise_to_image(images, device, level_of_noise=0.1):
    return images[0].to(device) + (level_of_noise) * torch.randn(images.shape).to(device)


if __name__ == "__main__":
    # Hyperparameter
    model_parameter = {
        "batch_size": 500,
        "learning_rate_dis": 0.0002,
        "learning_rate_gen": 0.0002,
        "shuffle": False,
        "Epochs": 10,
        "ndf": 64,
        "ngf": 64,
        "dropout": 0.5
    }

    # Parameter
    r_frequent = 10        # How many samples we save for replay per batch (batch_size / r_frequent).
    model_name = "CDCGAN"   # The name of you model e.g. "Gan"
    num_workers_loader = 1  # How many workers should load the data
    sample_save_size = 16   # How many numbers your saved imaged should show
    device = "cuda"         # Which device should be used to train the neural network
    model_load_path = ""    # If set load model instead of training from new
    num_epoch_log = 1       # How frequent you want to log/
    torch.manual_seed(43)   # Sets a seed for torch for reproducibility

    dataset_train, train_loader = get_MNIST_dataset(num_workers_loader, model_parameter)  # Get dataset

    # Initialize the Models and optimizer
    discriminator, generator, dis_opti, gen_opti = init_model_optimizer(model_parameter, device)  # Init model/Optimizer

    start_epoch = load_model(model_load_path)  # when we want to load a model

    # Init Logger
    logger = Logger(model_name, generator, discriminator, gen_opti, dis_opti, model_parameter, train_loader)

    loss = nn.BCELoss()

    images, labels = next(iter(train_loader))  # For logging

    # For testing
    # pred = generator(get_noise(model_parameter["batch_size"]).to(device), get_hot_vector_encode(get_rand_labels(model_parameter["batch_size"]), device))
    # dis = discriminator(images.to(device), get_hot_vector_encode(labels, device))

    logger.log_graph(get_noise(model_parameter["batch_size"]).to(device), images.to(device),
                     get_hot_vector_encode(get_rand_labels(model_parameter["batch_size"]), device),
                     get_hot_vector_encode(labels, device))


    # Array to store
    exp_replay = torch.tensor([]).to(device)

    for num_epoch in range(start_epoch, model_parameter["Epochs"]):
        for batch_num, data_loader in enumerate(train_loader):
            images, labels = data_loader
            images = add_noise_to_image(images, device)  # Add noise to the images

            # 1. Train Discriminator
            dis_error = train_discriminator(
                                            dis_opti,
                                            get_noise(model_parameter["batch_size"]).to(device),
                                            images.to(device),
                                            get_fake_data_target(model_parameter["batch_size"]).to(device),
                                            get_real_data_target(model_parameter["batch_size"]).to(device),
                                            get_hot_vector_encode(labels, device),
                                            get_hot_vector_encode(
                                                get_rand_labels(model_parameter["batch_size"]), device),
                                            device
                                            )

            # 2. Train Generator
            fake_image, pred_dis_fake, gen_error = train_generator(
                                                                  gen_opti,
                                                                  get_noise(model_parameter["batch_size"]).to(device),
                                                                  get_real_data_target(model_parameter["batch_size"]).to(device),
                                                                  get_hot_vector_encode(
                                                                      get_rand_labels(model_parameter["batch_size"]),
                                                                      device),
                                                                  device
                                                                  )


            # Store a random point for experience replay
            perm = torch.randperm(fake_image.size(0))
            r_idx = perm[:max(1, int(model_parameter["batch_size"] / r_frequent))]
            r_samples = add_noise_to_image(fake_image[r_idx], device)
            exp_replay = torch.cat((exp_replay, r_samples), 0).detach()

            if exp_replay.size(0) >= model_parameter["batch_size"]:
                # Train on experienced data
                dis_opti.zero_grad()

                r_label = get_hot_vector_encode(torch.zeros(exp_replay.size(0)).numpy(), device)
                pred_dis_real = discriminator(exp_replay, r_label)
                error_real = loss(pred_dis_real,  get_fake_data_target(exp_replay.size(0)).to(device))

                error_real.backward()

                dis_opti.step()

                print(f'Epoch: [{num_epoch}/{model_parameter["Epochs"]}] '
                      f'Batch: Replay/Experience batch '
                      f'Loss_D: {error_real.data.cpu()}, '
                      )

                exp_replay = torch.tensor([]).to(device)

            logger.display_stats(epoch=num_epoch, batch_num=batch_num, dis_error=dis_error, gen_error=gen_error)

            if batch_num % 100 == 0:
                logger.log_image(fake_image[:sample_save_size], num_epoch, batch_num)

        logger.log(num_epoch, dis_error, gen_error)
        if num_epoch % num_epoch_log == 0:
            logger.log_model(num_epoch)
            logger.log_histogramm()
    logger.close(logger, fake_image[:sample_save_size], num_epoch, dis_error, gen_error)

作为比较,这里是pytorch turoial的DCGAN图像:

我的代码:

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms, datasets
import torch.nn.functional as F
from torch import optim as optim
from torch.utils.tensorboard import SummaryWriter

import numpy as np

import os
import time


class Discriminator(torch.nn.Module):
    def __init__(self, ndf=16, dropout_value=0.5):  # ndf feature map discriminator
        super().__init__()
        self.ndf = ndf
        self.droupout_value = dropout_value

        self.condi = nn.Sequential(
            nn.Linear(in_features=10, out_features=64 * 64)
        )

        self.hidden0 = nn.Sequential(
            nn.Conv2d(in_channels=2, out_channels=self.ndf, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2),
        )
        self.hidden1 = nn.Sequential(
            nn.Conv2d(in_channels=self.ndf, out_channels=self.ndf * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(self.ndf * 2),
            nn.LeakyReLU(0.2),
            nn.Dropout(self.droupout_value)
        )
        self.hidden2 = nn.Sequential(
            nn.Conv2d(in_channels=self.ndf * 2, out_channels=self.ndf * 4, kernel_size=4, stride=2, padding=1, bias=False),
            #nn.BatchNorm2d(self.ndf * 4),
            nn.LeakyReLU(0.2),
            nn.Dropout(self.droupout_value)
        )
        self.hidden3 = nn.Sequential(
            nn.Conv2d(in_channels=self.ndf * 4, out_channels=self.ndf * 8, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(self.ndf * 8),
            nn.LeakyReLU(0.2),
            nn.Dropout(self.droupout_value)
        )
        self.out = nn.Sequential(
            nn.Conv2d(in_channels=self.ndf * 8, out_channels=1, kernel_size=4, stride=1, padding=0, bias=False),
            torch.nn.Sigmoid()
        )

    def forward(self, x, y):
        y = self.condi(y.view(-1, 10))
        y = y.view(-1, 1, 64, 64)

        x = torch.cat((x, y), dim=1)

        x = self.hidden0(x)
        x = self.hidden1(x)
        x = self.hidden2(x)
        x = self.hidden3(x)
        x = self.out(x)

        return x


class Generator(torch.nn.Module):
    def __init__(self, n_features=100, ngf=16, c_channels=1, dropout_value=0.5):  # ngf feature map of generator
        super().__init__()
        self.ngf = ngf
        self.n_features = n_features
        self.c_channels = c_channels
        self.droupout_value = dropout_value

        self.hidden0 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=self.n_features + 10, out_channels=self.ngf * 8,
                               kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(self.ngf * 8),
            nn.LeakyReLU(0.2)
        )

        self.hidden1 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=self.ngf * 8, out_channels=self.ngf * 4,
                               kernel_size=4, stride=2, padding=1, bias=False),
            #nn.BatchNorm2d(self.ngf * 4),
            nn.LeakyReLU(0.2),
            nn.Dropout(self.droupout_value)
        )

        self.hidden2 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=self.ngf * 4, out_channels=self.ngf * 2,
                               kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(self.ngf * 2),
            nn.LeakyReLU(0.2),
            nn.Dropout(self.droupout_value)
        )

        self.hidden3 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=self.ngf * 2, out_channels=self.ngf,
                               kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(self.ngf),
            nn.LeakyReLU(0.2),
            nn.Dropout(self.droupout_value)
        )

        self.out = nn.Sequential(
            # "out_channels=1" because gray scale
            nn.ConvTranspose2d(in_channels=self.ngf, out_channels=1, kernel_size=4,
                               stride=2, padding=1, bias=False),
            nn.Tanh()
        )

    def forward(self, x, y):
        x_cond = torch.cat((x, y), dim=1)  # Combine flatten image with conditional input (class labels)

        x = self.hidden0(x_cond)           # Image goes into a "ConvTranspose2d" layer
        x = self.hidden1(x)
        x = self.hidden2(x)
        x = self.hidden3(x)
        x = self.out(x)

        return x


class Logger:
    def __init__(self, model_name, model1, model2, m1_optimizer, m2_optimizer, model_parameter, train_loader):
        self.out_dir = "data"
        self.model_name = model_name
        self.train_loader = train_loader
        self.model1 = model1
        self.model2 = model2
        self.model_parameter = model_parameter
        self.m1_optimizer = m1_optimizer
        self.m2_optimizer = m2_optimizer

        # Exclude Epochs of the model name. This make sense e.g. when we stop a training progress and continue later on.
        self.experiment_name = '_'.join("{!s}={!r}".format(k, v) for (k, v) in model_parameter.items())\
            .replace("Epochs" + "=" + str(model_parameter["Epochs"]), "")

        self.d_error = 0
        self.g_error = 0

        self.tb = SummaryWriter(log_dir=str(self.out_dir + "/log/" + self.model_name + "/runs/" + self.experiment_name))

        self.path_image = os.path.join(os.getcwd(), f'{self.out_dir}/log/{self.model_name}/images/{self.experiment_name}')
        self.path_model = os.path.join(os.getcwd(), f'{self.out_dir}/log/{self.model_name}/model/{self.experiment_name}')

        try:
            os.makedirs(self.path_image)
        except Exception as e:
            print("WARNING: ", str(e))

        try:
            os.makedirs(self.path_model)
        except Exception as e:
            print("WARNING: ", str(e))

    def log_graph(self, model1_input, model2_input, model1_label, model2_label):
        self.tb.add_graph(self.model1, input_to_model=(model1_input, model1_label))
        self.tb.add_graph(self.model2, input_to_model=(model2_input, model2_label))

    def log(self, num_epoch, d_error, g_error):
        self.d_error = d_error
        self.g_error = g_error

        self.tb.add_scalar("Discriminator Train Error", self.d_error, num_epoch)
        self.tb.add_scalar("Generator Train Error", self.g_error, num_epoch)

    def log_image(self, images, epoch, batch_num):
        grid = torchvision.utils.make_grid(images)
        torchvision.utils.save_image(grid, f'{self.path_image}\\Epoch_{epoch}_batch_{batch_num}.png')

        self.tb.add_image("Generator Image", grid)

    def log_histogramm(self):
        for name, param in self.model2.named_parameters():
            self.tb.add_histogram(name, param, self.model_parameter["Epochs"])
            self.tb.add_histogram(f'gen_{name}.grad', param.grad, self.model_parameter["Epochs"])

        for name, param in self.model1.named_parameters():
            self.tb.add_histogram(name, param, self.model_parameter["Epochs"])
            self.tb.add_histogram(f'dis_{name}.grad', param.grad, self.model_parameter["Epochs"])

    def log_model(self, num_epoch):
        torch.save({
            "epoch": num_epoch,
            "model_generator_state_dict": self.model1.state_dict(),
            "model_discriminator_state_dict": self.model2.state_dict(),
            "optimizer_generator_state_dict":  self.m1_optimizer.state_dict(),
            "optimizer_discriminator_state_dict":  self.m2_optimizer.state_dict(),
        }, str(self.path_model + f'\\{time.time()}_epoch{num_epoch}.pth'))

    def close(self, logger, images, num_epoch,  d_error, g_error):
        logger.log_model(num_epoch)
        logger.log_histogramm()
        logger.log(num_epoch, d_error, g_error)
        self.tb.close()

    def display_stats(self, epoch, batch_num, dis_error, gen_error):
        print(f'Epoch: [{epoch}/{self.model_parameter["Epochs"]}] '
              f'Batch: [{batch_num}/{len(self.train_loader)}] '
              f'Loss_D: {dis_error.data.cpu()}, '
              f'Loss_G: {gen_error.data.cpu()}')


def get_MNIST_dataset(num_workers_loader, model_parameter, out_dir="data"):
    compose = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.CenterCrop((64, 64)),
        transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.5], std=[0.5])
    ])

    dataset = datasets.MNIST(
        root=out_dir,
        train=True,
        download=True,
        transform=compose
    )

    train_loader = torch.utils.data.DataLoader(dataset,
                                               batch_size=model_parameter["batch_size"],
                                               num_workers=num_workers_loader,
                                               shuffle=model_parameter["shuffle"])

    return dataset, train_loader


def train_discriminator(p_optimizer, p_noise, p_images, p_fake_target, p_real_target, p_images_labels, p_fake_labels, device):
    p_optimizer.zero_grad()

    # 1.1 Train on real data
    pred_dis_real = discriminator(p_images, p_images_labels)
    error_real = loss(pred_dis_real, p_real_target)

    error_real.backward()

    # 1.2 Train on fake data
    fake_data = generator(p_noise, p_fake_labels).detach()
    fake_data = add_noise_to_image(fake_data, device)
    pred_dis_fake = discriminator(fake_data, p_fake_labels)
    error_fake = loss(pred_dis_fake, p_fake_target)

    error_fake.backward()

    p_optimizer.step()

    return error_fake + error_real


def train_generator(p_optimizer, p_noise, p_real_target, p_fake_labels, device):
    p_optimizer.zero_grad()

    fake_images = generator(p_noise, p_fake_labels)
    fake_images = add_noise_to_image(fake_images, device)
    pred_dis_fake = discriminator(fake_images, p_fake_labels)
    error_fake = loss(pred_dis_fake, p_real_target)  # because
    """
    We use "p_real_target" instead of "p_fake_target" because we want to 
    maximize that the discriminator is wrong.
    """

    error_fake.backward()

    p_optimizer.step()

    return fake_images, pred_dis_fake, error_fake


# TODO change to a Truncated normal distribution
def get_noise(batch_size, n_features=100):
    return torch.FloatTensor(batch_size, n_features, 1, 1).uniform_(-1, 1)


# We flip label of real and fate data. Better gradient flow I have told
def get_real_data_target(batch_size):
    return torch.FloatTensor(batch_size, 1, 1, 1).uniform_(0.0, 0.2)


def get_fake_data_target(batch_size):
    return torch.FloatTensor(batch_size, 1, 1, 1).uniform_(0.8, 1.1)


def image_to_vector(images):
    return torch.flatten(images, start_dim=1, end_dim=-1)


def vector_to_image(images):
    return images.view(images.size(0), 1, 28, 28)


def get_rand_labels(batch_size):
    return torch.randint(low=0, high=9, size=(batch_size,))


def load_model(model_load_path):
    if model_load_path:
        checkpoint = torch.load(model_load_path)

        discriminator.load_state_dict(checkpoint["model_discriminator_state_dict"])
        generator.load_state_dict(checkpoint["model_generator_state_dict"])

        dis_opti.load_state_dict(checkpoint["optimizer_discriminator_state_dict"])
        gen_opti.load_state_dict(checkpoint["optimizer_generator_state_dict"])

        return checkpoint["epoch"]

    else:
        return 0


def init_model_optimizer(model_parameter, device):
    # Initialize the Models
    discriminator = Discriminator(ndf=model_parameter["ndf"], dropout_value=model_parameter["dropout"]).to(device)
    generator = Generator(ngf=model_parameter["ngf"], dropout_value=model_parameter["dropout"]).to(device)

    # train
    dis_opti = optim.Adam(discriminator.parameters(), lr=model_parameter["learning_rate_dis"], betas=(0.5, 0.999))
    gen_opti = optim.Adam(generator.parameters(), lr=model_parameter["learning_rate_gen"], betas=(0.5, 0.999))

    return discriminator, generator, dis_opti, gen_opti


def get_hot_vector_encode(labels, device):
    return torch.eye(10)[labels].view(-1, 10, 1, 1).to(device)


def add_noise_to_image(images, device, level_of_noise=0.1):
    return images[0].to(device) + (level_of_noise) * torch.randn(images.shape).to(device)


if __name__ == "__main__":
    # Hyperparameter
    model_parameter = {
        "batch_size": 500,
        "learning_rate_dis": 0.0002,
        "learning_rate_gen": 0.0002,
        "shuffle": False,
        "Epochs": 10,
        "ndf": 64,
        "ngf": 64,
        "dropout": 0.5
    }

    # Parameter
    r_frequent = 10        # How many samples we save for replay per batch (batch_size / r_frequent).
    model_name = "CDCGAN"   # The name of you model e.g. "Gan"
    num_workers_loader = 1  # How many workers should load the data
    sample_save_size = 16   # How many numbers your saved imaged should show
    device = "cuda"         # Which device should be used to train the neural network
    model_load_path = ""    # If set load model instead of training from new
    num_epoch_log = 1       # How frequent you want to log/
    torch.manual_seed(43)   # Sets a seed for torch for reproducibility

    dataset_train, train_loader = get_MNIST_dataset(num_workers_loader, model_parameter)  # Get dataset

    # Initialize the Models and optimizer
    discriminator, generator, dis_opti, gen_opti = init_model_optimizer(model_parameter, device)  # Init model/Optimizer

    start_epoch = load_model(model_load_path)  # when we want to load a model

    # Init Logger
    logger = Logger(model_name, generator, discriminator, gen_opti, dis_opti, model_parameter, train_loader)

    loss = nn.BCELoss()

    images, labels = next(iter(train_loader))  # For logging

    # For testing
    # pred = generator(get_noise(model_parameter["batch_size"]).to(device), get_hot_vector_encode(get_rand_labels(model_parameter["batch_size"]), device))
    # dis = discriminator(images.to(device), get_hot_vector_encode(labels, device))

    logger.log_graph(get_noise(model_parameter["batch_size"]).to(device), images.to(device),
                     get_hot_vector_encode(get_rand_labels(model_parameter["batch_size"]), device),
                     get_hot_vector_encode(labels, device))


    # Array to store
    exp_replay = torch.tensor([]).to(device)

    for num_epoch in range(start_epoch, model_parameter["Epochs"]):
        for batch_num, data_loader in enumerate(train_loader):
            images, labels = data_loader
            images = add_noise_to_image(images, device)  # Add noise to the images

            # 1. Train Discriminator
            dis_error = train_discriminator(
                                            dis_opti,
                                            get_noise(model_parameter["batch_size"]).to(device),
                                            images.to(device),
                                            get_fake_data_target(model_parameter["batch_size"]).to(device),
                                            get_real_data_target(model_parameter["batch_size"]).to(device),
                                            get_hot_vector_encode(labels, device),
                                            get_hot_vector_encode(
                                                get_rand_labels(model_parameter["batch_size"]), device),
                                            device
                                            )

            # 2. Train Generator
            fake_image, pred_dis_fake, gen_error = train_generator(
                                                                  gen_opti,
                                                                  get_noise(model_parameter["batch_size"]).to(device),
                                                                  get_real_data_target(model_parameter["batch_size"]).to(device),
                                                                  get_hot_vector_encode(
                                                                      get_rand_labels(model_parameter["batch_size"]),
                                                                      device),
                                                                  device
                                                                  )


            # Store a random point for experience replay
            perm = torch.randperm(fake_image.size(0))
            r_idx = perm[:max(1, int(model_parameter["batch_size"] / r_frequent))]
            r_samples = add_noise_to_image(fake_image[r_idx], device)
            exp_replay = torch.cat((exp_replay, r_samples), 0).detach()

            if exp_replay.size(0) >= model_parameter["batch_size"]:
                # Train on experienced data
                dis_opti.zero_grad()

                r_label = get_hot_vector_encode(torch.zeros(exp_replay.size(0)).numpy(), device)
                pred_dis_real = discriminator(exp_replay, r_label)
                error_real = loss(pred_dis_real,  get_fake_data_target(exp_replay.size(0)).to(device))

                error_real.backward()

                dis_opti.step()

                print(f'Epoch: [{num_epoch}/{model_parameter["Epochs"]}] '
                      f'Batch: Replay/Experience batch '
                      f'Loss_D: {error_real.data.cpu()}, '
                      )

                exp_replay = torch.tensor([]).to(device)

            logger.display_stats(epoch=num_epoch, batch_num=batch_num, dis_error=dis_error, gen_error=gen_error)

            if batch_num % 100 == 0:
                logger.log_image(fake_image[:sample_save_size], num_epoch, batch_num)

        logger.log(num_epoch, dis_error, gen_error)
        if num_epoch % num_epoch_log == 0:
            logger.log_model(num_epoch)
            logger.log_histogramm()
    logger.close(logger, fake_image[:sample_save_size], num_epoch, dis_error, gen_error)

结论:

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms, datasets
import torch.nn.functional as F
from torch import optim as optim
from torch.utils.tensorboard import SummaryWriter

import numpy as np

import os
import time


class Discriminator(torch.nn.Module):
    def __init__(self, ndf=16, dropout_value=0.5):  # ndf feature map discriminator
        super().__init__()
        self.ndf = ndf
        self.droupout_value = dropout_value

        self.condi = nn.Sequential(
            nn.Linear(in_features=10, out_features=64 * 64)
        )

        self.hidden0 = nn.Sequential(
            nn.Conv2d(in_channels=2, out_channels=self.ndf, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2),
        )
        self.hidden1 = nn.Sequential(
            nn.Conv2d(in_channels=self.ndf, out_channels=self.ndf * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(self.ndf * 2),
            nn.LeakyReLU(0.2),
            nn.Dropout(self.droupout_value)
        )
        self.hidden2 = nn.Sequential(
            nn.Conv2d(in_channels=self.ndf * 2, out_channels=self.ndf * 4, kernel_size=4, stride=2, padding=1, bias=False),
            #nn.BatchNorm2d(self.ndf * 4),
            nn.LeakyReLU(0.2),
            nn.Dropout(self.droupout_value)
        )
        self.hidden3 = nn.Sequential(
            nn.Conv2d(in_channels=self.ndf * 4, out_channels=self.ndf * 8, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(self.ndf * 8),
            nn.LeakyReLU(0.2),
            nn.Dropout(self.droupout_value)
        )
        self.out = nn.Sequential(
            nn.Conv2d(in_channels=self.ndf * 8, out_channels=1, kernel_size=4, stride=1, padding=0, bias=False),
            torch.nn.Sigmoid()
        )

    def forward(self, x, y):
        y = self.condi(y.view(-1, 10))
        y = y.view(-1, 1, 64, 64)

        x = torch.cat((x, y), dim=1)

        x = self.hidden0(x)
        x = self.hidden1(x)
        x = self.hidden2(x)
        x = self.hidden3(x)
        x = self.out(x)

        return x


class Generator(torch.nn.Module):
    def __init__(self, n_features=100, ngf=16, c_channels=1, dropout_value=0.5):  # ngf feature map of generator
        super().__init__()
        self.ngf = ngf
        self.n_features = n_features
        self.c_channels = c_channels
        self.droupout_value = dropout_value

        self.hidden0 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=self.n_features + 10, out_channels=self.ngf * 8,
                               kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(self.ngf * 8),
            nn.LeakyReLU(0.2)
        )

        self.hidden1 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=self.ngf * 8, out_channels=self.ngf * 4,
                               kernel_size=4, stride=2, padding=1, bias=False),
            #nn.BatchNorm2d(self.ngf * 4),
            nn.LeakyReLU(0.2),
            nn.Dropout(self.droupout_value)
        )

        self.hidden2 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=self.ngf * 4, out_channels=self.ngf * 2,
                               kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(self.ngf * 2),
            nn.LeakyReLU(0.2),
            nn.Dropout(self.droupout_value)
        )

        self.hidden3 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=self.ngf * 2, out_channels=self.ngf,
                               kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(self.ngf),
            nn.LeakyReLU(0.2),
            nn.Dropout(self.droupout_value)
        )

        self.out = nn.Sequential(
            # "out_channels=1" because gray scale
            nn.ConvTranspose2d(in_channels=self.ngf, out_channels=1, kernel_size=4,
                               stride=2, padding=1, bias=False),
            nn.Tanh()
        )

    def forward(self, x, y):
        x_cond = torch.cat((x, y), dim=1)  # Combine flatten image with conditional input (class labels)

        x = self.hidden0(x_cond)           # Image goes into a "ConvTranspose2d" layer
        x = self.hidden1(x)
        x = self.hidden2(x)
        x = self.hidden3(x)
        x = self.out(x)

        return x


class Logger:
    def __init__(self, model_name, model1, model2, m1_optimizer, m2_optimizer, model_parameter, train_loader):
        self.out_dir = "data"
        self.model_name = model_name
        self.train_loader = train_loader
        self.model1 = model1
        self.model2 = model2
        self.model_parameter = model_parameter
        self.m1_optimizer = m1_optimizer
        self.m2_optimizer = m2_optimizer

        # Exclude Epochs of the model name. This make sense e.g. when we stop a training progress and continue later on.
        self.experiment_name = '_'.join("{!s}={!r}".format(k, v) for (k, v) in model_parameter.items())\
            .replace("Epochs" + "=" + str(model_parameter["Epochs"]), "")

        self.d_error = 0
        self.g_error = 0

        self.tb = SummaryWriter(log_dir=str(self.out_dir + "/log/" + self.model_name + "/runs/" + self.experiment_name))

        self.path_image = os.path.join(os.getcwd(), f'{self.out_dir}/log/{self.model_name}/images/{self.experiment_name}')
        self.path_model = os.path.join(os.getcwd(), f'{self.out_dir}/log/{self.model_name}/model/{self.experiment_name}')

        try:
            os.makedirs(self.path_image)
        except Exception as e:
            print("WARNING: ", str(e))

        try:
            os.makedirs(self.path_model)
        except Exception as e:
            print("WARNING: ", str(e))

    def log_graph(self, model1_input, model2_input, model1_label, model2_label):
        self.tb.add_graph(self.model1, input_to_model=(model1_input, model1_label))
        self.tb.add_graph(self.model2, input_to_model=(model2_input, model2_label))

    def log(self, num_epoch, d_error, g_error):
        self.d_error = d_error
        self.g_error = g_error

        self.tb.add_scalar("Discriminator Train Error", self.d_error, num_epoch)
        self.tb.add_scalar("Generator Train Error", self.g_error, num_epoch)

    def log_image(self, images, epoch, batch_num):
        grid = torchvision.utils.make_grid(images)
        torchvision.utils.save_image(grid, f'{self.path_image}\\Epoch_{epoch}_batch_{batch_num}.png')

        self.tb.add_image("Generator Image", grid)

    def log_histogramm(self):
        for name, param in self.model2.named_parameters():
            self.tb.add_histogram(name, param, self.model_parameter["Epochs"])
            self.tb.add_histogram(f'gen_{name}.grad', param.grad, self.model_parameter["Epochs"])

        for name, param in self.model1.named_parameters():
            self.tb.add_histogram(name, param, self.model_parameter["Epochs"])
            self.tb.add_histogram(f'dis_{name}.grad', param.grad, self.model_parameter["Epochs"])

    def log_model(self, num_epoch):
        torch.save({
            "epoch": num_epoch,
            "model_generator_state_dict": self.model1.state_dict(),
            "model_discriminator_state_dict": self.model2.state_dict(),
            "optimizer_generator_state_dict":  self.m1_optimizer.state_dict(),
            "optimizer_discriminator_state_dict":  self.m2_optimizer.state_dict(),
        }, str(self.path_model + f'\\{time.time()}_epoch{num_epoch}.pth'))

    def close(self, logger, images, num_epoch,  d_error, g_error):
        logger.log_model(num_epoch)
        logger.log_histogramm()
        logger.log(num_epoch, d_error, g_error)
        self.tb.close()

    def display_stats(self, epoch, batch_num, dis_error, gen_error):
        print(f'Epoch: [{epoch}/{self.model_parameter["Epochs"]}] '
              f'Batch: [{batch_num}/{len(self.train_loader)}] '
              f'Loss_D: {dis_error.data.cpu()}, '
              f'Loss_G: {gen_error.data.cpu()}')


def get_MNIST_dataset(num_workers_loader, model_parameter, out_dir="data"):
    compose = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.CenterCrop((64, 64)),
        transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.5], std=[0.5])
    ])

    dataset = datasets.MNIST(
        root=out_dir,
        train=True,
        download=True,
        transform=compose
    )

    train_loader = torch.utils.data.DataLoader(dataset,
                                               batch_size=model_parameter["batch_size"],
                                               num_workers=num_workers_loader,
                                               shuffle=model_parameter["shuffle"])

    return dataset, train_loader


def train_discriminator(p_optimizer, p_noise, p_images, p_fake_target, p_real_target, p_images_labels, p_fake_labels, device):
    p_optimizer.zero_grad()

    # 1.1 Train on real data
    pred_dis_real = discriminator(p_images, p_images_labels)
    error_real = loss(pred_dis_real, p_real_target)

    error_real.backward()

    # 1.2 Train on fake data
    fake_data = generator(p_noise, p_fake_labels).detach()
    fake_data = add_noise_to_image(fake_data, device)
    pred_dis_fake = discriminator(fake_data, p_fake_labels)
    error_fake = loss(pred_dis_fake, p_fake_target)

    error_fake.backward()

    p_optimizer.step()

    return error_fake + error_real


def train_generator(p_optimizer, p_noise, p_real_target, p_fake_labels, device):
    p_optimizer.zero_grad()

    fake_images = generator(p_noise, p_fake_labels)
    fake_images = add_noise_to_image(fake_images, device)
    pred_dis_fake = discriminator(fake_images, p_fake_labels)
    error_fake = loss(pred_dis_fake, p_real_target)  # because
    """
    We use "p_real_target" instead of "p_fake_target" because we want to 
    maximize that the discriminator is wrong.
    """

    error_fake.backward()

    p_optimizer.step()

    return fake_images, pred_dis_fake, error_fake


# TODO change to a Truncated normal distribution
def get_noise(batch_size, n_features=100):
    return torch.FloatTensor(batch_size, n_features, 1, 1).uniform_(-1, 1)


# We flip label of real and fate data. Better gradient flow I have told
def get_real_data_target(batch_size):
    return torch.FloatTensor(batch_size, 1, 1, 1).uniform_(0.0, 0.2)


def get_fake_data_target(batch_size):
    return torch.FloatTensor(batch_size, 1, 1, 1).uniform_(0.8, 1.1)


def image_to_vector(images):
    return torch.flatten(images, start_dim=1, end_dim=-1)


def vector_to_image(images):
    return images.view(images.size(0), 1, 28, 28)


def get_rand_labels(batch_size):
    return torch.randint(low=0, high=9, size=(batch_size,))


def load_model(model_load_path):
    if model_load_path:
        checkpoint = torch.load(model_load_path)

        discriminator.load_state_dict(checkpoint["model_discriminator_state_dict"])
        generator.load_state_dict(checkpoint["model_generator_state_dict"])

        dis_opti.load_state_dict(checkpoint["optimizer_discriminator_state_dict"])
        gen_opti.load_state_dict(checkpoint["optimizer_generator_state_dict"])

        return checkpoint["epoch"]

    else:
        return 0


def init_model_optimizer(model_parameter, device):
    # Initialize the Models
    discriminator = Discriminator(ndf=model_parameter["ndf"], dropout_value=model_parameter["dropout"]).to(device)
    generator = Generator(ngf=model_parameter["ngf"], dropout_value=model_parameter["dropout"]).to(device)

    # train
    dis_opti = optim.Adam(discriminator.parameters(), lr=model_parameter["learning_rate_dis"], betas=(0.5, 0.999))
    gen_opti = optim.Adam(generator.parameters(), lr=model_parameter["learning_rate_gen"], betas=(0.5, 0.999))

    return discriminator, generator, dis_opti, gen_opti


def get_hot_vector_encode(labels, device):
    return torch.eye(10)[labels].view(-1, 10, 1, 1).to(device)


def add_noise_to_image(images, device, level_of_noise=0.1):
    return images[0].to(device) + (level_of_noise) * torch.randn(images.shape).to(device)


if __name__ == "__main__":
    # Hyperparameter
    model_parameter = {
        "batch_size": 500,
        "learning_rate_dis": 0.0002,
        "learning_rate_gen": 0.0002,
        "shuffle": False,
        "Epochs": 10,
        "ndf": 64,
        "ngf": 64,
        "dropout": 0.5
    }

    # Parameter
    r_frequent = 10        # How many samples we save for replay per batch (batch_size / r_frequent).
    model_name = "CDCGAN"   # The name of you model e.g. "Gan"
    num_workers_loader = 1  # How many workers should load the data
    sample_save_size = 16   # How many numbers your saved imaged should show
    device = "cuda"         # Which device should be used to train the neural network
    model_load_path = ""    # If set load model instead of training from new
    num_epoch_log = 1       # How frequent you want to log/
    torch.manual_seed(43)   # Sets a seed for torch for reproducibility

    dataset_train, train_loader = get_MNIST_dataset(num_workers_loader, model_parameter)  # Get dataset

    # Initialize the Models and optimizer
    discriminator, generator, dis_opti, gen_opti = init_model_optimizer(model_parameter, device)  # Init model/Optimizer

    start_epoch = load_model(model_load_path)  # when we want to load a model

    # Init Logger
    logger = Logger(model_name, generator, discriminator, gen_opti, dis_opti, model_parameter, train_loader)

    loss = nn.BCELoss()

    images, labels = next(iter(train_loader))  # For logging

    # For testing
    # pred = generator(get_noise(model_parameter["batch_size"]).to(device), get_hot_vector_encode(get_rand_labels(model_parameter["batch_size"]), device))
    # dis = discriminator(images.to(device), get_hot_vector_encode(labels, device))

    logger.log_graph(get_noise(model_parameter["batch_size"]).to(device), images.to(device),
                     get_hot_vector_encode(get_rand_labels(model_parameter["batch_size"]), device),
                     get_hot_vector_encode(labels, device))


    # Array to store
    exp_replay = torch.tensor([]).to(device)

    for num_epoch in range(start_epoch, model_parameter["Epochs"]):
        for batch_num, data_loader in enumerate(train_loader):
            images, labels = data_loader
            images = add_noise_to_image(images, device)  # Add noise to the images

            # 1. Train Discriminator
            dis_error = train_discriminator(
                                            dis_opti,
                                            get_noise(model_parameter["batch_size"]).to(device),
                                            images.to(device),
                                            get_fake_data_target(model_parameter["batch_size"]).to(device),
                                            get_real_data_target(model_parameter["batch_size"]).to(device),
                                            get_hot_vector_encode(labels, device),
                                            get_hot_vector_encode(
                                                get_rand_labels(model_parameter["batch_size"]), device),
                                            device
                                            )

            # 2. Train Generator
            fake_image, pred_dis_fake, gen_error = train_generator(
                                                                  gen_opti,
                                                                  get_noise(model_parameter["batch_size"]).to(device),
                                                                  get_real_data_target(model_parameter["batch_size"]).to(device),
                                                                  get_hot_vector_encode(
                                                                      get_rand_labels(model_parameter["batch_size"]),
                                                                      device),
                                                                  device
                                                                  )


            # Store a random point for experience replay
            perm = torch.randperm(fake_image.size(0))
            r_idx = perm[:max(1, int(model_parameter["batch_size"] / r_frequent))]
            r_samples = add_noise_to_image(fake_image[r_idx], device)
            exp_replay = torch.cat((exp_replay, r_samples), 0).detach()

            if exp_replay.size(0) >= model_parameter["batch_size"]:
                # Train on experienced data
                dis_opti.zero_grad()

                r_label = get_hot_vector_encode(torch.zeros(exp_replay.size(0)).numpy(), device)
                pred_dis_real = discriminator(exp_replay, r_label)
                error_real = loss(pred_dis_real,  get_fake_data_target(exp_replay.size(0)).to(device))

                error_real.backward()

                dis_opti.step()

                print(f'Epoch: [{num_epoch}/{model_parameter["Epochs"]}] '
                      f'Batch: Replay/Experience batch '
                      f'Loss_D: {error_real.data.cpu()}, '
                      )

                exp_replay = torch.tensor([]).to(device)

            logger.display_stats(epoch=num_epoch, batch_num=batch_num, dis_error=dis_error, gen_error=gen_error)

            if batch_num % 100 == 0:
                logger.log_image(fake_image[:sample_save_size], num_epoch, batch_num)

        logger.log(num_epoch, dis_error, gen_error)
        if num_epoch % num_epoch_log == 0:
            logger.log_model(num_epoch)
            logger.log_histogramm()
    logger.close(logger, fake_image[:sample_save_size], num_epoch, dis_error, gen_error)
因为我实现了所有这些被认为对GAN/DCGAN有益的事情(例如标签平滑)。
我的模型的性能仍然比PyTorch的教程DCGAN差,我想我的代码中可能有一个bug,但我似乎找不到它

再现性:

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms, datasets
import torch.nn.functional as F
from torch import optim as optim
from torch.utils.tensorboard import SummaryWriter

import numpy as np

import os
import time


class Discriminator(torch.nn.Module):
    def __init__(self, ndf=16, dropout_value=0.5):  # ndf feature map discriminator
        super().__init__()
        self.ndf = ndf
        self.droupout_value = dropout_value

        self.condi = nn.Sequential(
            nn.Linear(in_features=10, out_features=64 * 64)
        )

        self.hidden0 = nn.Sequential(
            nn.Conv2d(in_channels=2, out_channels=self.ndf, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2),
        )
        self.hidden1 = nn.Sequential(
            nn.Conv2d(in_channels=self.ndf, out_channels=self.ndf * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(self.ndf * 2),
            nn.LeakyReLU(0.2),
            nn.Dropout(self.droupout_value)
        )
        self.hidden2 = nn.Sequential(
            nn.Conv2d(in_channels=self.ndf * 2, out_channels=self.ndf * 4, kernel_size=4, stride=2, padding=1, bias=False),
            #nn.BatchNorm2d(self.ndf * 4),
            nn.LeakyReLU(0.2),
            nn.Dropout(self.droupout_value)
        )
        self.hidden3 = nn.Sequential(
            nn.Conv2d(in_channels=self.ndf * 4, out_channels=self.ndf * 8, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(self.ndf * 8),
            nn.LeakyReLU(0.2),
            nn.Dropout(self.droupout_value)
        )
        self.out = nn.Sequential(
            nn.Conv2d(in_channels=self.ndf * 8, out_channels=1, kernel_size=4, stride=1, padding=0, bias=False),
            torch.nn.Sigmoid()
        )

    def forward(self, x, y):
        y = self.condi(y.view(-1, 10))
        y = y.view(-1, 1, 64, 64)

        x = torch.cat((x, y), dim=1)

        x = self.hidden0(x)
        x = self.hidden1(x)
        x = self.hidden2(x)
        x = self.hidden3(x)
        x = self.out(x)

        return x


class Generator(torch.nn.Module):
    def __init__(self, n_features=100, ngf=16, c_channels=1, dropout_value=0.5):  # ngf feature map of generator
        super().__init__()
        self.ngf = ngf
        self.n_features = n_features
        self.c_channels = c_channels
        self.droupout_value = dropout_value

        self.hidden0 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=self.n_features + 10, out_channels=self.ngf * 8,
                               kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(self.ngf * 8),
            nn.LeakyReLU(0.2)
        )

        self.hidden1 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=self.ngf * 8, out_channels=self.ngf * 4,
                               kernel_size=4, stride=2, padding=1, bias=False),
            #nn.BatchNorm2d(self.ngf * 4),
            nn.LeakyReLU(0.2),
            nn.Dropout(self.droupout_value)
        )

        self.hidden2 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=self.ngf * 4, out_channels=self.ngf * 2,
                               kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(self.ngf * 2),
            nn.LeakyReLU(0.2),
            nn.Dropout(self.droupout_value)
        )

        self.hidden3 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=self.ngf * 2, out_channels=self.ngf,
                               kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(self.ngf),
            nn.LeakyReLU(0.2),
            nn.Dropout(self.droupout_value)
        )

        self.out = nn.Sequential(
            # "out_channels=1" because gray scale
            nn.ConvTranspose2d(in_channels=self.ngf, out_channels=1, kernel_size=4,
                               stride=2, padding=1, bias=False),
            nn.Tanh()
        )

    def forward(self, x, y):
        x_cond = torch.cat((x, y), dim=1)  # Combine flatten image with conditional input (class labels)

        x = self.hidden0(x_cond)           # Image goes into a "ConvTranspose2d" layer
        x = self.hidden1(x)
        x = self.hidden2(x)
        x = self.hidden3(x)
        x = self.out(x)

        return x


class Logger:
    def __init__(self, model_name, model1, model2, m1_optimizer, m2_optimizer, model_parameter, train_loader):
        self.out_dir = "data"
        self.model_name = model_name
        self.train_loader = train_loader
        self.model1 = model1
        self.model2 = model2
        self.model_parameter = model_parameter
        self.m1_optimizer = m1_optimizer
        self.m2_optimizer = m2_optimizer

        # Exclude Epochs of the model name. This make sense e.g. when we stop a training progress and continue later on.
        self.experiment_name = '_'.join("{!s}={!r}".format(k, v) for (k, v) in model_parameter.items())\
            .replace("Epochs" + "=" + str(model_parameter["Epochs"]), "")

        self.d_error = 0
        self.g_error = 0

        self.tb = SummaryWriter(log_dir=str(self.out_dir + "/log/" + self.model_name + "/runs/" + self.experiment_name))

        self.path_image = os.path.join(os.getcwd(), f'{self.out_dir}/log/{self.model_name}/images/{self.experiment_name}')
        self.path_model = os.path.join(os.getcwd(), f'{self.out_dir}/log/{self.model_name}/model/{self.experiment_name}')

        try:
            os.makedirs(self.path_image)
        except Exception as e:
            print("WARNING: ", str(e))

        try:
            os.makedirs(self.path_model)
        except Exception as e:
            print("WARNING: ", str(e))

    def log_graph(self, model1_input, model2_input, model1_label, model2_label):
        self.tb.add_graph(self.model1, input_to_model=(model1_input, model1_label))
        self.tb.add_graph(self.model2, input_to_model=(model2_input, model2_label))

    def log(self, num_epoch, d_error, g_error):
        self.d_error = d_error
        self.g_error = g_error

        self.tb.add_scalar("Discriminator Train Error", self.d_error, num_epoch)
        self.tb.add_scalar("Generator Train Error", self.g_error, num_epoch)

    def log_image(self, images, epoch, batch_num):
        grid = torchvision.utils.make_grid(images)
        torchvision.utils.save_image(grid, f'{self.path_image}\\Epoch_{epoch}_batch_{batch_num}.png')

        self.tb.add_image("Generator Image", grid)

    def log_histogramm(self):
        for name, param in self.model2.named_parameters():
            self.tb.add_histogram(name, param, self.model_parameter["Epochs"])
            self.tb.add_histogram(f'gen_{name}.grad', param.grad, self.model_parameter["Epochs"])

        for name, param in self.model1.named_parameters():
            self.tb.add_histogram(name, param, self.model_parameter["Epochs"])
            self.tb.add_histogram(f'dis_{name}.grad', param.grad, self.model_parameter["Epochs"])

    def log_model(self, num_epoch):
        torch.save({
            "epoch": num_epoch,
            "model_generator_state_dict": self.model1.state_dict(),
            "model_discriminator_state_dict": self.model2.state_dict(),
            "optimizer_generator_state_dict":  self.m1_optimizer.state_dict(),
            "optimizer_discriminator_state_dict":  self.m2_optimizer.state_dict(),
        }, str(self.path_model + f'\\{time.time()}_epoch{num_epoch}.pth'))

    def close(self, logger, images, num_epoch,  d_error, g_error):
        logger.log_model(num_epoch)
        logger.log_histogramm()
        logger.log(num_epoch, d_error, g_error)
        self.tb.close()

    def display_stats(self, epoch, batch_num, dis_error, gen_error):
        print(f'Epoch: [{epoch}/{self.model_parameter["Epochs"]}] '
              f'Batch: [{batch_num}/{len(self.train_loader)}] '
              f'Loss_D: {dis_error.data.cpu()}, '
              f'Loss_G: {gen_error.data.cpu()}')


def get_MNIST_dataset(num_workers_loader, model_parameter, out_dir="data"):
    compose = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.CenterCrop((64, 64)),
        transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.5], std=[0.5])
    ])

    dataset = datasets.MNIST(
        root=out_dir,
        train=True,
        download=True,
        transform=compose
    )

    train_loader = torch.utils.data.DataLoader(dataset,
                                               batch_size=model_parameter["batch_size"],
                                               num_workers=num_workers_loader,
                                               shuffle=model_parameter["shuffle"])

    return dataset, train_loader


def train_discriminator(p_optimizer, p_noise, p_images, p_fake_target, p_real_target, p_images_labels, p_fake_labels, device):
    p_optimizer.zero_grad()

    # 1.1 Train on real data
    pred_dis_real = discriminator(p_images, p_images_labels)
    error_real = loss(pred_dis_real, p_real_target)

    error_real.backward()

    # 1.2 Train on fake data
    fake_data = generator(p_noise, p_fake_labels).detach()
    fake_data = add_noise_to_image(fake_data, device)
    pred_dis_fake = discriminator(fake_data, p_fake_labels)
    error_fake = loss(pred_dis_fake, p_fake_target)

    error_fake.backward()

    p_optimizer.step()

    return error_fake + error_real


def train_generator(p_optimizer, p_noise, p_real_target, p_fake_labels, device):
    p_optimizer.zero_grad()

    fake_images = generator(p_noise, p_fake_labels)
    fake_images = add_noise_to_image(fake_images, device)
    pred_dis_fake = discriminator(fake_images, p_fake_labels)
    error_fake = loss(pred_dis_fake, p_real_target)  # because
    """
    We use "p_real_target" instead of "p_fake_target" because we want to 
    maximize that the discriminator is wrong.
    """

    error_fake.backward()

    p_optimizer.step()

    return fake_images, pred_dis_fake, error_fake


# TODO change to a Truncated normal distribution
def get_noise(batch_size, n_features=100):
    return torch.FloatTensor(batch_size, n_features, 1, 1).uniform_(-1, 1)


# We flip label of real and fate data. Better gradient flow I have told
def get_real_data_target(batch_size):
    return torch.FloatTensor(batch_size, 1, 1, 1).uniform_(0.0, 0.2)


def get_fake_data_target(batch_size):
    return torch.FloatTensor(batch_size, 1, 1, 1).uniform_(0.8, 1.1)


def image_to_vector(images):
    return torch.flatten(images, start_dim=1, end_dim=-1)


def vector_to_image(images):
    return images.view(images.size(0), 1, 28, 28)


def get_rand_labels(batch_size):
    return torch.randint(low=0, high=9, size=(batch_size,))


def load_model(model_load_path):
    if model_load_path:
        checkpoint = torch.load(model_load_path)

        discriminator.load_state_dict(checkpoint["model_discriminator_state_dict"])
        generator.load_state_dict(checkpoint["model_generator_state_dict"])

        dis_opti.load_state_dict(checkpoint["optimizer_discriminator_state_dict"])
        gen_opti.load_state_dict(checkpoint["optimizer_generator_state_dict"])

        return checkpoint["epoch"]

    else:
        return 0


def init_model_optimizer(model_parameter, device):
    # Initialize the Models
    discriminator = Discriminator(ndf=model_parameter["ndf"], dropout_value=model_parameter["dropout"]).to(device)
    generator = Generator(ngf=model_parameter["ngf"], dropout_value=model_parameter["dropout"]).to(device)

    # train
    dis_opti = optim.Adam(discriminator.parameters(), lr=model_parameter["learning_rate_dis"], betas=(0.5, 0.999))
    gen_opti = optim.Adam(generator.parameters(), lr=model_parameter["learning_rate_gen"], betas=(0.5, 0.999))

    return discriminator, generator, dis_opti, gen_opti


def get_hot_vector_encode(labels, device):
    return torch.eye(10)[labels].view(-1, 10, 1, 1).to(device)


def add_noise_to_image(images, device, level_of_noise=0.1):
    return images[0].to(device) + (level_of_noise) * torch.randn(images.shape).to(device)


if __name__ == "__main__":
    # Hyperparameter
    model_parameter = {
        "batch_size": 500,
        "learning_rate_dis": 0.0002,
        "learning_rate_gen": 0.0002,
        "shuffle": False,
        "Epochs": 10,
        "ndf": 64,
        "ngf": 64,
        "dropout": 0.5
    }

    # Parameter
    r_frequent = 10        # How many samples we save for replay per batch (batch_size / r_frequent).
    model_name = "CDCGAN"   # The name of you model e.g. "Gan"
    num_workers_loader = 1  # How many workers should load the data
    sample_save_size = 16   # How many numbers your saved imaged should show
    device = "cuda"         # Which device should be used to train the neural network
    model_load_path = ""    # If set load model instead of training from new
    num_epoch_log = 1       # How frequent you want to log/
    torch.manual_seed(43)   # Sets a seed for torch for reproducibility

    dataset_train, train_loader = get_MNIST_dataset(num_workers_loader, model_parameter)  # Get dataset

    # Initialize the Models and optimizer
    discriminator, generator, dis_opti, gen_opti = init_model_optimizer(model_parameter, device)  # Init model/Optimizer

    start_epoch = load_model(model_load_path)  # when we want to load a model

    # Init Logger
    logger = Logger(model_name, generator, discriminator, gen_opti, dis_opti, model_parameter, train_loader)

    loss = nn.BCELoss()

    images, labels = next(iter(train_loader))  # For logging

    # For testing
    # pred = generator(get_noise(model_parameter["batch_size"]).to(device), get_hot_vector_encode(get_rand_labels(model_parameter["batch_size"]), device))
    # dis = discriminator(images.to(device), get_hot_vector_encode(labels, device))

    logger.log_graph(get_noise(model_parameter["batch_size"]).to(device), images.to(device),
                     get_hot_vector_encode(get_rand_labels(model_parameter["batch_size"]), device),
                     get_hot_vector_encode(labels, device))


    # Array to store
    exp_replay = torch.tensor([]).to(device)

    for num_epoch in range(start_epoch, model_parameter["Epochs"]):
        for batch_num, data_loader in enumerate(train_loader):
            images, labels = data_loader
            images = add_noise_to_image(images, device)  # Add noise to the images

            # 1. Train Discriminator
            dis_error = train_discriminator(
                                            dis_opti,
                                            get_noise(model_parameter["batch_size"]).to(device),
                                            images.to(device),
                                            get_fake_data_target(model_parameter["batch_size"]).to(device),
                                            get_real_data_target(model_parameter["batch_size"]).to(device),
                                            get_hot_vector_encode(labels, device),
                                            get_hot_vector_encode(
                                                get_rand_labels(model_parameter["batch_size"]), device),
                                            device
                                            )

            # 2. Train Generator
            fake_image, pred_dis_fake, gen_error = train_generator(
                                                                  gen_opti,
                                                                  get_noise(model_parameter["batch_size"]).to(device),
                                                                  get_real_data_target(model_parameter["batch_size"]).to(device),
                                                                  get_hot_vector_encode(
                                                                      get_rand_labels(model_parameter["batch_size"]),
                                                                      device),
                                                                  device
                                                                  )


            # Store a random point for experience replay
            perm = torch.randperm(fake_image.size(0))
            r_idx = perm[:max(1, int(model_parameter["batch_size"] / r_frequent))]
            r_samples = add_noise_to_image(fake_image[r_idx], device)
            exp_replay = torch.cat((exp_replay, r_samples), 0).detach()

            if exp_replay.size(0) >= model_parameter["batch_size"]:
                # Train on experienced data
                dis_opti.zero_grad()

                r_label = get_hot_vector_encode(torch.zeros(exp_replay.size(0)).numpy(), device)
                pred_dis_real = discriminator(exp_replay, r_label)
                error_real = loss(pred_dis_real,  get_fake_data_target(exp_replay.size(0)).to(device))

                error_real.backward()

                dis_opti.step()

                print(f'Epoch: [{num_epoch}/{model_parameter["Epochs"]}] '
                      f'Batch: Replay/Experience batch '
                      f'Loss_D: {error_real.data.cpu()}, '
                      )

                exp_replay = torch.tensor([]).to(device)

            logger.display_stats(epoch=num_epoch, batch_num=batch_num, dis_error=dis_error, gen_error=gen_error)

            if batch_num % 100 == 0:
                logger.log_image(fake_image[:sample_save_size], num_epoch, batch_num)

        logger.log(num_epoch, dis_error, gen_error)
        if num_epoch % num_epoch_log == 0:
            logger.log_model(num_epoch)
            logger.log_histogramm()
    logger.close(logger, fake_image[:sample_save_size], num_epoch, dis_error, gen_error)
如果您安装了我导入的库,您应该能够复制代码并运行它,以便自己查找,如果您能找到任何东西的话


非常感谢您的反馈。

所以我不久前解决了这个问题,但忘了在堆栈溢出问题上发布答案。因此,我将简单地在这里发布我的代码,这可能会很好地工作。 一些免责声明:

  • 我不太确定它是否有效,因为我一年前就这样做了
  • 它用于128x128px图像MNIST
  • 这不是一个简单的问题,我使用了各种优化技术
  • 如果要使用它,需要更改各种详细信息,例如培训数据集
资源:

``


``

您是否同时重新实现了代码并添加到模型中?您应该重新实现教程中完全相同的GAN版本,然后测试它,然后如果它有效,则使标签平滑addition@TiagoMartinsPeres不完全是。我希望有经验的人能够识别图片中的模式和相应的错误,或者建议如何调试它,因为调试神经网络相对困难,我不知道如何调试。@ThomaS I正在研究这个问题。这个问题不是更适合数据科学堆栈交换吗?你的结果看起来不错。我的意思是他们是数字型的,这正是我所期望的。Pytorch版本看起来只是训练了更长时间。正如@ThomaS指出的,这可能是由于模型的变化。Pytorch版本将根据对其有效的方式进行优化,任何偏离这一点都将恶化结果。