gpt4 book ai didi

pytorch - pytorch 代码中的 KL 散度与公式有何关系?

转载 作者:行者123 更新时间:2023-12-05 01:36:58 35 4
gpt4 key购买 nike

在 VAE 教程中,两个正态分布的 kl-divergence 定义为: enter image description here

而且在很多代码中,比如here , herehere ,代码实现为:

 KL_loss = -0.5 * torch.sum(1 + logv - mean.pow(2) - logv.exp())

def latent_loss(z_mean, z_stddev):
mean_sq = z_mean * z_mean
stddev_sq = z_stddev * z_stddev
return 0.5 * torch.mean(mean_sq + stddev_sq - torch.log(stddev_sq) - 1)

它们有什么关系?为什么代码中没有任何“tr”或“.transpose()”?

最佳答案

您发布的代码中的表达式假定 X 是一个不相关 多变量高斯随机变量。这在协方差矩阵的行列式中缺少交叉项是显而易见的。因此均值向量和协方差矩阵的形式为

enter image description here

使用它我们可以快速推导出以下原始表达式组件的等价表示

enter image description here

将这些代回原始表达式得到

enter image description here

关于pytorch - pytorch 代码中的 KL 散度与公式有何关系?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/61597340/

35 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com