gpt4 book ai didi

python - WGAN-GP 列车损失较大

转载 作者:行者123 更新时间:2023-12-01 08:43:15 26 4
gpt4 key购买 nike

这是WGAN-GP的损失函数

gen_sample = model.generator(input_gen)
disc_real = model.discriminator(real_image, reuse=False)
disc_fake = model.discriminator(gen_sample, reuse=True)
disc_concat = tf.concat([disc_real, disc_fake], axis=0)
# Gradient penalty
alpha = tf.random_uniform(
shape=[BATCH_SIZE, 1, 1, 1],
minval=0.,
maxval=1.)
differences = gen_sample - real_image
interpolates = real_image + (alpha * differences)
gradients = tf.gradients(model.discriminator(interpolates, reuse=True), [interpolates])[0] # why [0]
slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1]))
gradient_penalty = tf.reduce_mean((slopes-1.)**2)

d_loss_real = tf.reduce_mean(disc_real)
d_loss_fake = tf.reduce_mean(disc_fake)

disc_loss = -(d_loss_real - d_loss_fake) + LAMBDA * gradient_penalty
gen_loss = - d_loss_fake

This is the training loss

发电机损耗是振荡的,而且数值很大。我的问题是:发电机损耗是正常还是异常?

最佳答案

需要注意的一点是你的梯度惩罚计算是错误的。以下行:

slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1]))

实际上应该是:

slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1,2,3]))

您正在第一个轴上减少,但渐变是基于 Alpha 值所示的图像,因此您必须在轴 [1,2,3] 上减少。

代码中的另一个错误是生成器损失是:

gen_loss = d_loss_real - d_loss_fake

对于梯度计算来说,这没有什么区别,因为生成器的参数仅包含在 d_loss_fake 中。然而,对于发电机损失的值(value)来说,这造成了世界上的所有差异,也是其波动如此之大的原因。

最终,您应该查看您关心的实际性能指标,以确定 GAN 的质量,例如起始分数或 Fréchet 起始距离 (FID),因为鉴别器和生成器的损失只是轻微的描述性的。

关于python - WGAN-GP 列车损失较大,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53413706/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com