为了账号安全,请及时绑定邮箱和手机立即绑定

WGAN-GP 训练损失大

WGAN-GP 训练损失大

慕哥6287543 2021-08-24 16:32:57
这是WGAN-GP的损失函数gen_sample = model.generator(input_gen)disc_real = model.discriminator(real_image, reuse=False)disc_fake = model.discriminator(gen_sample, reuse=True)disc_concat = tf.concat([disc_real, disc_fake], axis=0)# Gradient penaltyalpha = tf.random_uniform(    shape=[BATCH_SIZE, 1, 1, 1],    minval=0.,    maxval=1.)differences = gen_sample - real_imageinterpolates = real_image + (alpha * differences)gradients = tf.gradients(model.discriminator(interpolates, reuse=True), [interpolates])[0]    # why [0]slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1]))gradient_penalty = tf.reduce_mean((slopes-1.)**2)d_loss_real = tf.reduce_mean(disc_real)d_loss_fake = tf.reduce_mean(disc_fake)disc_loss = -(d_loss_real - d_loss_fake) + LAMBDA * gradient_penaltygen_loss = - d_loss_fake发电机损耗震荡,值这么大。我的问题是:发电机损耗是正常的还是异常的?
查看完整描述

1 回答

?
喵喔喔

TA贡献1735条经验 获得超5个赞

需要注意的一件事是您的梯度惩罚计算是错误的。以下行:

slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1]))

实际上应该是:

slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1,2,3]))

您在第一个轴上减少,但渐变基于 alpha 值显示的图像,因此您必须在轴上减少[1,2,3]

代码中的另一个错误是生成器损失是:

gen_loss = d_loss_real - d_loss_fake

对于梯度计算,这没有区别,因为生成器的参数仅包含在 d_loss_fake 中。然而,对于发电机损失的价值,这在世界上造成了很大的不同,这也是为什么会如此震荡的原因。

归根结底,您应该查看您关心的实际性能指标,以确定 GAN 的质量,例如初始分数或 Fréchet 初始距离 (FID),因为鉴别器和生成器的损失仅具有轻微的描述性。


查看完整回答
反对 回复 2021-08-24
  • 1 回答
  • 0 关注
  • 951 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
微信客服

购课补贴
联系客服咨询优惠详情

帮助反馈 APP下载

慕课网APP
您的移动学习伙伴

公众号

扫描二维码
关注慕课网微信公众号