1 回答
TA贡献2080条经验 获得超4个赞
据我了解,您的假设是正确的。
如果您检查 github [keras/losses_utils.py][1] 第 260-269 行,您将看到它确实按预期执行。 SUM将总结批量维度中的损失,并SUM_OVER_BATCH_SIZE除以总SUM损失数(批量大小)。
def reduce_weighted_loss(weighted_losses,
reduction=ReductionV2.SUM_OVER_BATCH_SIZE):
if reduction == ReductionV2.NONE:
loss = weighted_losses
else:
loss = tf.reduce_sum(weighted_losses)
if reduction == ReductionV2.SUM_OVER_BATCH_SIZE:
loss = _safe_mean(loss, _num_elements(weighted_losses))
return loss
您只需添加一对损失为零的输出即可对前面的示例进行简单检查。
y_true = [[0., 2.], [0., 0.],[1.,1.]]
y_pred = [[3., 1.], [2., 5.],[1.,1.]]
mae = tf.keras.losses.MeanAbsoluteError(reduction=tf.keras.losses.Reduction.SUM)
mae(y_true, y_pred).numpy()
> 5.5
mae = tf.keras.losses.MeanAbsoluteError()
mae(y_true, y_pred).numpy()
> 1.8333
所以,你的假设是正确的。[1]:https://github.com/keras-team/keras/blob/v2.7.0/keras/utils/losses_utils.py#L25-L84
添加回答
举报