
tensorflow - What does the losses property of TensorFlow Probability's Bayesian layers represent?

Reposted. Author: 行者123. Updated: 2023-12-01 23:21:08

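I am running the example code for a Bayesian Neural Network implemented with TensorFlow Probability.

My question is about the implementation of the ELBO loss used for variational inference. The ELBO loss equals the sum of two terms, implemented in the code as neg_log_likelihood and kl. I am having trouble understanding the implementation of the kl term.

Here is how the model is defined: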

with tf.name_scope("bayesian_neural_net", values=[images]):
    neural_net = tf.keras.Sequential()
    for units in FLAGS.layer_sizes:
        layer = tfp.layers.DenseFlipout(units, activation=FLAGS.activation)
        neural_net.add(layer)
    neural_net.add(tfp.layers.DenseFlipout(10))
    logits = neural_net(images)
    labels_distribution = tfd.Categorical(logits=logits)

Here is how the kl term is defined:

kl = sum(neural_net.losses) / mnist_data.train.num_examples

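I am not sure what neural_net.losses returns here, since no loss function is defined for neural_net. Clearly neural_net.losses returns some values, but I don't know what they mean. Any comments on this?

My guess is that it is the L2 norm, but I am not sure. If so, something is still missing. According to Appendix B of the VAE paper, the authors derive the KL term for the case where the prior is a standard normal. It turns out to be very close to the L2 norm of the variational parameters, apart from an additional log-variance term and a constant term. Any comments on this?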
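
For reference, the closed-form expression this paragraph alludes to can be written out as follows. This is the standard result for a diagonal Gaussian posterior against a standard normal prior; the symbols \mu_j and \sigma_j (mean and standard deviation of the j-th weight) and J (number of weights) are notation chosen here, not taken from the example code.

```latex
\mathrm{KL}\bigl(\mathcal{N}(\mu, \operatorname{diag}(\sigma^2)) \,\|\, \mathcal{N}(0, I)\bigr)
  = \frac{1}{2} \sum_{j=1}^{J} \left( \mu_j^2 + \sigma_j^2 - \log \sigma_j^2 - 1 \right)
```

The \mu_j^2 + \sigma_j^2 part is the squared L2 norm of the variational parameters; the -\log \sigma_j^2 and -1 parts are the log-variance and constant terms mentioned above.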

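Best answer

The losses property of a TensorFlow Keras Layer represents side-effect computations such as regularizer penalties. Unlike regularization penalties on specific TensorFlow variables, the losses here represent the KL divergence computation. See the implementation and the docstring's example: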

We illustrate a Bayesian neural network with variational inference, assuming a dataset of features and labels.

import tensorflow as tf
import tensorflow_probability as tfp

model = tf.keras.Sequential([
    tfp.layers.DenseFlipout(512, activation=tf.nn.relu),
    tfp.layers.DenseFlipout(10),
])
logits = model(features)
neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits(
    labels=labels, logits=logits)
# Each DenseFlipout layer adds its KL divergence term to model.losses.
kl = sum(model.losses)
loss = neg_log_likelihood + kl
train_op = tf.train.AdamOptimizer().minimize(loss)

It uses the Flipout gradient estimator to minimize the Kullback-Leibler divergence up to a constant, also known as the negative Evidence Lower Bound. It consists of the sum of two terms: the expected negative log-likelihood, which we approximate via Monte Carlo; and the KL divergence, which is added via regularizer terms which are arguments to the layer.
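
To make the two pieces concrete, here is a minimal pure-Python sketch that does not call TensorFlow Probability at all. It assumes each layer's entry in losses is the closed-form KL between a mean-field normal posterior and a standard normal prior (believed to be the DenseFlipout defaults, but treat that as an assumption), and that the negative log-likelihood is averaged over the batch, which is why the question's code divides the KL sum by the number of training examples. The function names and the numbers are hypothetical, chosen only for illustration.

```python
import math


def kl_diag_gaussian_to_std_normal(means, stddevs):
    """KL(N(mean, stddev^2) || N(0, 1)), summed over independent weights."""
    total = 0.0
    for mu, sigma in zip(means, stddevs):
        variance = sigma * sigma
        # Squared-L2 part (mu^2 + sigma^2) plus log-variance and constant terms.
        total += 0.5 * (mu * mu + variance - math.log(variance) - 1.0)
    return total


def elbo_loss(per_example_nll, layer_kls, num_train_examples):
    """Per-example negative ELBO: batch-mean NLL plus total KL divided by N."""
    mean_nll = sum(per_example_nll) / len(per_example_nll)
    kl = sum(layer_kls) / num_train_examples
    return mean_nll + kl


# Hypothetical variational parameters for two tiny "layers".
layer1_kl = kl_diag_gaussian_to_std_normal([0.5, -1.0], [0.8, 1.2])
layer2_kl = kl_diag_gaussian_to_std_normal([0.0], [1.0])  # posterior equals prior

# Hypothetical per-example NLL values for a batch of three examples.
loss = elbo_loss([2.3, 1.9, 2.1], [layer1_kl, layer2_kl], num_train_examples=60000)
print(layer1_kl, layer2_kl, loss)
```

When the posterior equals the prior the KL term vanishes, and the penalty grows as the means move away from zero or the standard deviations move away from one. That makes it behave like an L2 penalty on the means with an extra term that keeps the variances from collapsing.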

Regarding "tensorflow - What does the losses property of TensorFlow Probability's Bayesian layers represent?", we found a similar question on Stack Overflow: https://stackoverflow.com/questions/50064792/
