为了账号安全,请及时绑定邮箱和手机立即绑定

Keras 自定义损失函数 huber

Keras 自定义损失函数 huber

眼眸繁星 2023-07-05 16:31:58
我使用 Keras 后端函数编写了 huber loss,效果很好:def huber_loss(y_true, y_pred, clip_delta=1.0):    error = y_true - y_pred    cond  = K.abs(error) < clip_delta             squared_loss = 0.5 * K.square(error)    linear_loss  = clip_delta * (K.abs(error) - 0.5 * clip_delta)    return tf_where(cond, squared_loss, linear_loss)但我需要一个更复杂的损失函数:如果error <= A,使用squared_loss如果A <= error < B,使用线性损失如果error >= B,使用sqrt_loss我是这样写的:def best_loss(y_true, y_pred, A, B):    error = K.abs(y_true - y_pred)    cond  = error <= A    cond2 = tf_logical_and(A < error, error <= B)    squared_loss = 0.5 * K.square(error)    linear_loss  = A * (error - 0.5 * A)    sqrt_loss = A * np.sqrt(B) * K.sqrt(error) - 0.5 * A**2    return tf_where(cond, squared_loss, tf_where(cond2, linear_loss, sqrt_loss))但是不行,用这个损失函数的模型不收敛,有什么bug?
查看完整描述

1 回答

?
呼唤远方

TA贡献1856条经验 获得超11个赞

我喜欢通过使用像 Desmos 这样的程序绘制自定义函数的图形来调试它们。我使用您的实现绘制了 Huber 损失图,它看起来应该是这样的。

当我尝试绘制第二个函数的图表时,它看起来也像是一个有效的损失函数。唯一的问题是当 B 小于 A 时。如果 B 的值大于 A,那么损失函数应该不会有问题。如果这不是问题,那么您可以尝试在目标和输出之间切换减法,因为我不熟悉张量流如何处理微分,但顺序会影响梯度的方向。


查看完整回答
反对 回复 2023-07-05
  • 1 回答
  • 0 关注
  • 174 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信