Machine learning PyTorch中生成性对抗网络(GAN)的训练生成器

Machine learning PyTorch中生成性对抗网络(GAN)的训练生成器,machine-learning,deep-learning,pytorch,generative-adversarial-network,Machine Learning,Deep Learning,Pytorch,Generative Adversarial Network,我正在PyTorch 1.5.0中实现生成性对抗网络(GAN) 为了计算生成器的损失,我计算了鉴别器错误分类全真实小批量和全(生成器生成的)假小批量的负概率。然后,我依次反向传播这两个部分,最后应用step函数 计算并反向传播损失部分,该部分是生成的伪数据的错误分类的函数,这似乎是直截了当的,因为在该损失项的反向传播过程中,反向路径通过首先生成伪数据的生成器 但是,所有真实数据小批量的分类不涉及通过生成器传递数据。因此,我想知道以下被剪掉的代码是否仍然会计算生成器的梯度,或者它是否根本不会计算任

我正在PyTorch 1.5.0中实现生成性对抗网络(GAN)

为了计算生成器的损失,我计算了鉴别器错误分类全真实小批量和全(生成器生成的)假小批量的负概率。然后,我依次反向传播这两个部分,最后应用step函数

计算并反向传播损失部分,该部分是生成的伪数据的错误分类的函数,这似乎是直截了当的,因为在该损失项的反向传播过程中,反向路径通过首先生成伪数据的生成器

但是,所有真实数据小批量的分类不涉及通过生成器传递数据。因此,我想知道以下被剪掉的代码是否仍然会计算生成器的梯度,或者它是否根本不会计算任何梯度(因为反向路径不会穿过生成器,并且在更新生成器时鉴别器处于eval模式)


如果这不能按预期工作,我如何才能使其工作?提前谢谢

没有梯度传播到发生器,因为没有对发生器的任何参数进行计算。处于eval模式的鉴别器不会阻止渐变传播到生成器,尽管如果使用的层在eval模式下的行为与在train模式下的行为不同,例如“dropout”,则渐变会略有不同

真实图像的错误分类不是训练生成器的一部分,因为它不会从这些信息中获得任何信息。从概念上讲,如果鉴别器未能正确地对真实图像进行分类,那么生成器应该从这个事实中学到什么?生成器的唯一任务是创建假图像,使鉴别器认为它是真实的,因此生成器的唯一相关信息是鉴别器是否能够识别假图像。如果鉴别器确实能够识别假图像,则生成器需要调整自身以创建更具说服力的假图像


当然,这不是一个二进制的情况,但生成器总是试图改进假图像,以便鉴别器更确信这是一个真实的图像。生成器的目标不是使鉴别器可疑(概率为0.5,它是真的还是假的),而是使鉴别器完全相信它是真的,即使它是假的。这就是为什么它们是对抗性的,而不是合作性的。

我正在扩展以下研究论文中提出的一些方法:在本文中,鉴别器对真实图像和虚假图像的混淆是生成器损失函数的一部分(等式4)。所以,我仍然想知道如何通过发电机反向传播损耗项,而发电机本身并没有参与计算,我认为第二项不应该存在。为了对称,我可以看到它,但只有第一项对W有影响,第二项实际上是常数W.r.t.W,使其导数为0。如果嵌入也得到优化(即“生成器”的一部分),则损失是有意义的,但事实并非如此。这段代码非常古老,使用了过时的PyTorch结构,但您仍然可以看到它们所做的事情。事实上,这段代码表明它们对这两个术语都进行了优化。在第75行中,他们是concat。真目标和假翻译嵌入,在第116行计算损失w.r.t。这两个术语和反支柱通过其生成器
w
在128中计算组合的总损失。因此,在我看来,他们在生成器更新中考虑目标词混淆的方法是简单地添加两个词并反向传播组合/合计损失。你对上述代码的理解是正确的:没有任何东西是靠后支撑的(检查了靠后支撑步骤前后的坡度之和)。是的,但是将这些术语加在一起对W没有任何影响,就像在
(损失+10)中一样。后退()
10没有影响,最好不要使用。不管它是10还是f(10),只要f在计算中不包括W,它就像一个常数。在串联的情况下,它们是批次的独立样本,彼此独立,因此对第一项的唯一实际影响是取平均值时,样本数是两倍。我假设他们最初计划利用这一损失优化嵌入。
# Update generator #
net.generator.train()
net.discriminator.eval()
net.generator.zero_grad()

# All-real minibatch
x_real = get_all_real_minibatch()
y_true = torch.full((batch_size,), label_fake).long()  # Pretend true targets were fake
y_pred = net.discriminator(x_real)  # Produces softmax probability distribution over (0=label_fake,1=label_real)

loss_real = NLLLoss(torch.log(y_pred), y_true) 
loss_real.backward()
optimizer_generator.step()