为了账号安全,请及时绑定邮箱和手机立即绑定

tf.keras.losses 中的“减少”参数

tf.keras.losses 中的“减少”参数

回首忆惘然 2023-07-05 15:44:19
根据文档,该Reduction参数有 3 个值 - SUM_OVER_BATCH_SIZE、SUM和NONE。y_true = [[0., 2.], [0., 0.]]y_pred = [[3., 1.], [2., 5.]]mae = tf.keras.losses.MeanAbsoluteError(reduction=tf.keras.losses.Reduction.SUM)mae(y_true, y_pred).numpy()> 5.5mae = tf.keras.losses.MeanAbsoluteError()mae(y_true, y_pred).numpy()> 2.75经过各种试验后我可以推断出的计算结果是:-当REDUCTION = SUM,Loss = Sum over all samples {(Sum of differences between y_pred and y_target vector of each sample / No of element in y_target of the sample )} = { (abs(3-0) + abs(1-2))/2 } + { (abs(2-0) + abs(5-0))/2 } = {4/2} + {7/2} = 5.5.当REDUCTION = SUM_OVER_BATCH_SIZE,Loss = [Sum over all samples {(Sum of differences between y_pred and y_target vector of each sample / No of element in y_target of the sample )}] / Batch_size or No of Samples  = [ { (abs(3-0)} + abs(1-2))/2 } + { (abs(2-0) + abs(5-0))/2 } ]/2 = [ {4/2} + {7/2} ]/2 = [5.5]/2 = 2.75.结果,SUM_OVER_BATCH_SIZE无非是SUM/batch_size。那么,为什么要调用它呢SUM_OVER_BATCH_SIZE?实际上是SUM将整个批次的损失相加,同时SUM_OVER_BATCH_SIZE计算该批次的平均损失。SUM_OVER_BATCH_SIZE我关于和的运作的假设是否SUM正确?
查看完整描述

1 回答

?
犯罪嫌疑人X

TA贡献2080条经验 获得超4个赞

据我了解,您的假设是正确的。


如果您检查 github [keras/losses_utils.py][1] 第 260-269 行,您将看到它确实按预期执行。 SUM将总结批量维度中的损失,并SUM_OVER_BATCH_SIZE除以总SUM损失数(批量大小)。


def reduce_weighted_loss(weighted_losses,

                     reduction=ReductionV2.SUM_OVER_BATCH_SIZE):

  if reduction == ReductionV2.NONE:

     loss = weighted_losses

  else:

     loss = tf.reduce_sum(weighted_losses)

     if reduction == ReductionV2.SUM_OVER_BATCH_SIZE:

        loss = _safe_mean(loss, _num_elements(weighted_losses))

  return loss

您只需添加一对损失为零的输出即可对前面的示例进行简单检查。


y_true = [[0., 2.], [0., 0.],[1.,1.]]

y_pred = [[3., 1.], [2., 5.],[1.,1.]]


mae = tf.keras.losses.MeanAbsoluteError(reduction=tf.keras.losses.Reduction.SUM)

mae(y_true, y_pred).numpy()

> 5.5


mae = tf.keras.losses.MeanAbsoluteError()

mae(y_true, y_pred).numpy()

> 1.8333

所以,你的假设是正确的。[1]:https://github.com/keras-team/keras/blob/v2.7.0/keras/utils/losses_utils.py#L25-L84


查看完整回答
反对 回复 2023-07-05
  • 1 回答
  • 0 关注
  • 106 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信