pytorch代码中的KL散度与公式有何关系?

pytorch代码中的KL散度与公式有何关系?,pytorch,autoencoder,loss-function,Pytorch,Autoencoder,Loss Function,在VAE教程中,两个正态分布的kl散度定义为: 在许多代码中,如和,代码的实现方式如下: KL_loss = -0.5 * torch.sum(1 + logv - mean.pow(2) - logv.exp()) 或 它们有什么关系?为什么代码中没有任何“tr”或“.transpose()”?您发布的代码中的表达式假设X是一个不相关的多变量高斯随机变量。协方差矩阵行列式中缺少交叉项,这一点很明显。因此,平均向量和协方差矩阵的形式如下 使用此方法,我们可以快速导出原始表达式组件的以下等

在VAE教程中,两个正态分布的kl散度定义为:

在许多代码中,如和,代码的实现方式如下:

 KL_loss = -0.5 * torch.sum(1 + logv - mean.pow(2) - logv.exp())


它们有什么关系?为什么代码中没有任何“tr”或“.transpose()”?

您发布的代码中的表达式假设X是一个不相关的多变量高斯随机变量。协方差矩阵行列式中缺少交叉项,这一点很明显。因此,平均向量和协方差矩阵的形式如下

使用此方法,我们可以快速导出原始表达式组件的以下等效表示

将这些替换回原始表达式中,会得到


这是由Kingma()在附录B中的原始VAE文件中得出的。注意,第二个版本中有一个额外的缩放,它使用
torch.mean
而不是
torch.sum
,这不是一个问题,因为缩放不会改变最佳点(尽管这可能意味着您需要不同的学习速率)@jodag非常有帮助,thanks@jodag关于torch.sum和torch.mean,你说“这可能意味着你需要不同的学习速度”,但KL损失不是唯一的损失术语,损失=KL_损失+侦察损失,这是否意味着损失实际上是具有不同权重的加权和?是的,如果使用平均值而不是总和,则kl_损失分量的权重将隐含地小于原始公式,这可能会影响损失函数的最佳点,并可能影响最终结果。
def latent_loss(z_mean, z_stddev):
    mean_sq = z_mean * z_mean
    stddev_sq = z_stddev * z_stddev
    return 0.5 * torch.mean(mean_sq + stddev_sq - torch.log(stddev_sq) - 1)