Python 条件WGAN-GP在Keras中的实现

Python 条件WGAN-GP在Keras中的实现,python,keras,generative-adversarial-network,Python,Keras,Generative Adversarial Network,我将WGAN-GP扩展为有条件的代码库,可以在这里找到: 当我训练模型时,它似乎不受标签的限制。这就是我构建模型的方式 # The generator takes noise and the target label (states) as input # and generates the corresponding samples of that label noise = Input(shape=(self.latent_size, ), name="noise"

我将WGAN-GP扩展为有条件的代码库,可以在这里找到:

当我训练模型时,它似乎不受标签的限制。这就是我构建模型的方式

    # The generator takes noise and the target label (states) as input
    # and generates the corresponding samples of that label
    noise = Input(shape=(self.latent_size, ), name="noise")
    label = Input(shape=(self.label_size, ), name="labels")
    real_samples = Input(shape=(self.input_size,), name="real")

    self.discriminator = self.build_discriminator()
    self.generator = self.build_generator([noise, label])

    # First we train the discriminator
    self.generator.trainable = False
    fake_samples = self.generator([noise, label])

    fake = self.discriminator([fake_samples, label])
    valid = self.discriminator([real_samples, label])

    interpolated = Lambda(self.random_weighted_average)([real_samples, fake_samples])
    valid_interp = self.discriminator([interpolated, label])

    self.d_model = Model([real_samples, noise, label],
                         [valid, fake, valid_interp],
                         name="discriminator")

    # Time to train the generator
    self.discriminator.trainable = False
    self.generator.trainable = True

    noise_gen = Input(shape=(self.latent_size,), name="noise_gen")

    fake_samples = self.generator([noise_gen, label])
    valid = self.discriminator([fake_samples, label])

    self.g_model = Model([noise_gen, label], valid, name="generator")
    self.g_model.compile(loss=self.wasserstein_loss, optimizer=optimizer)
绘制模型将导致:

我不知道如何解释右边的合并箭头。标签应连接到鉴别器中。我觉得这句话把事情搞砸了:

    self.d_model = Model([real_samples, noise, label],
                         [valid, fake, valid_interp],
                         name="discriminator")
因为我只是传递标签,我不知道Keras如何将输入路由到其他输出