为了账号安全,请及时绑定邮箱和手机立即绑定

在 lightgbm 中实现自定义 huber loss

在 lightgbm 中实现自定义 huber loss

MM们 2022-01-18 16:36:29
我正在尝试实现 Huber 损失,以便为 lightgbm 中的 MAPE 损失进行定制。下面是我的代码。但是,当我尝试运行它时,所有预测都为零。代码有什么问题?似乎一些缩放可能对学习有所帮助,但我在互联网上没有看到任何关于如何在自定义损失中应用它的指南。你能帮我解决这个问题吗?def my_loss(preds, dtrain):   y_true = dtrain.get_label()   d = (preds - y_true)   h = 1  #h is delta in the graphic   scale = 1 + (d / h) ** 2   scale_sqrt = np.sqrt(scale)   grad = d / scale_sqrt    hess = 1 / scale / scale_sqrt    hess = np.ones(len(preds))return grad, hessmetrics = []for i in my_cv:   X_train = X.loc[i[0],:]   y_train = y.loc[i[0]]   X_test = X.loc[i[1],:]   y_test = y.loc[i[1]]   dtrain = xgb.Dataset(X_train, label=y_train, free_raw_data =False)   params = {'max_depth': 10, 'learning_rate':0.05,'objective':None,         'num_leaves':150, 'min_child_samples':5, 'nround':100,         'monotone_constraints':lst_mon}   mm = xgb.train(params, dtrain, fobj = my_loss)   y_pred = mm.predict(X_train)
查看完整描述

3 回答

?
FFIVE

TA贡献1797条经验 获得超6个赞

正确的功能:


def my_loss(preds, dtrain):


   y_true = dtrain.get_label()

   d = (preds - y_true)

   h = 1  #h is delta in the graphic

   scale = 1 + (d / h) ** 2

   scale_sqrt = np.sqrt(scale)

   grad = d / scale_sqrt 

   hess = 1 / scale / scale_sqrt 


   return grad, hess

移除 hess = np.ones(len(preds))


查看完整回答
反对 回复 2022-01-18
?
HUX布斯

TA贡献1876条经验 获得超6个赞

可能是强制的效果monotone_constraints。仅当您获得可接受的结果并希望改进时才应设置它们。经过对数据和结果的深入分析。

此外(可能只是将代码复制到 SO 期间出现的错误),在您的损失函数中,由于hess = np.ones(len(preds)).


查看完整回答
反对 回复 2022-01-18
?
哈士奇WWW

TA贡献1799条经验 获得超6个赞

Huber损失定义

在此处输入图像描述

您实现的损失是它的平滑近似,即 Pseudo-Huber 损失: 在此处输入图像描述

这种损失的问题在于它的二阶导数太接近于零。为了加快他们的算法,lightgbm 使用牛顿方法的近似来找到最佳叶值:

y = - L' / L''

(有关详细信息,请参阅此博文)。

即,他们找到具有相同梯度和二阶导数的抛物线将达到最小值的点。如果损失函数是二次的,这给了我们精确的最优值。然而,对于 Pseudo-Huber 损失,牛顿的方法到处发散:

|- L'(a) / L''(a)| = (1 + (a/delta)**2) * |a| > |一个|,

所以你得到的近似值总是比你开始的值更远。

当您使用np.ones粗麻布时,您会得到 -L'(a) 作为零的估计值,它也不会收敛到零。

要使用 Pseudo-Huber 损失正确实现梯度提升,您必须放弃使用 hessian 并使用正常梯度下降来找到最佳叶值。你不能在 lightgbm 的自定义损失中做到这一点,但 lightgbm 有一个内置的 huber 损失,所以你可以使用它。


查看完整回答
反对 回复 2022-01-18
  • 3 回答
  • 0 关注
  • 323 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信