这是WGAN-GP的损失函数gen_sample = model.generator(input_gen)disc_real = model.discriminator(real_image, reuse=False)disc_fake = model.discriminator(gen_sample, reuse=True)disc_concat = tf.concat([disc_real, disc_fake], axis=0)# Gradient penaltyalpha = tf.random_uniform( shape=[BATCH_SIZE, 1, 1, 1], minval=0., maxval=1.)differences = gen_sample - real_imageinterpolates = real_image + (alpha * differences)gradients = tf.gradients(model.discriminator(interpolates, reuse=True), [interpolates])[0] # why [0]slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1]))gradient_penalty = tf.reduce_mean((slopes-1.)**2)d_loss_real = tf.reduce_mean(disc_real)d_loss_fake = tf.reduce_mean(disc_fake)disc_loss = -(d_loss_real - d_loss_fake) + LAMBDA * gradient_penaltygen_loss = - d_loss_fake发电机损耗震荡,值这么大。我的问题是:发电机损耗是正常的还是异常的?
1 回答

喵喔喔
TA贡献1735条经验 获得超5个赞
需要注意的一件事是您的梯度惩罚计算是错误的。以下行:
slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1]))
实际上应该是:
slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1,2,3]))
您在第一个轴上减少,但渐变基于 alpha 值显示的图像,因此您必须在轴上减少[1,2,3]
。
代码中的另一个错误是生成器损失是:
gen_loss = d_loss_real - d_loss_fake
对于梯度计算,这没有区别,因为生成器的参数仅包含在 d_loss_fake 中。然而,对于发电机损失的价值,这在世界上造成了很大的不同,这也是为什么会如此震荡的原因。
归根结底,您应该查看您关心的实际性能指标,以确定 GAN 的质量,例如初始分数或 Fréchet 初始距离 (FID),因为鉴别器和生成器的损失仅具有轻微的描述性。
添加回答
举报
0/150
提交
取消