为了账号安全,请及时绑定邮箱和手机立即绑定

铰链损失函数梯度 wrt 输入预测

铰链损失函数梯度 wrt 输入预测

PIPIONE 2021-08-14 13:33:01
对于作业,我必须同时实现铰链损失及其偏导数计算函数。我得到了 Hinge 损失函数本身,但我很难理解如何计算其偏导数 wrt 预测输入。我尝试了不同的方法,但都没有奏效。任何帮助,提示,建议将不胜感激!这是铰链损失函数本身的解析表达式:这是我的铰链损失函数实现:def hinge_forward(target_pred, target_true):    """Compute the value of Hinge loss         for a given prediction and the ground truth    # Arguments        target_pred: predictions - np.array of size `(n_objects,)`        target_true: ground truth - np.array of size `(n_objects,)`    # Output        the value of Hinge loss         for a given prediction and the ground truth        scalar    """    output = np.sum((np.maximum(0, 1 - target_pred * target_true)) / target_pred.size)    return output现在我需要计算这个梯度:这是我尝试的铰链损失梯度计算:def hinge_grad_input(target_pred, target_true):    """Compute the partial derivative         of Hinge loss with respect to its input    # Arguments        target_pred: predictions - np.array of size `(n_objects,)`        target_true: ground truth - np.array of size `(n_objects,)`    # Output        the partial derivative         of Hinge loss with respect to its input        np.array of size `(n_objects,)`    """# ----------------#     try 1# ----------------#     hinge_result = hinge_forward(target_pred, target_true)#     if hinge_result == 0:#         grad_input = 0#     else:#         hinge = np.maximum(0, 1 - target_pred * target_true)#         grad_input = np.zeros_like(hinge)#         grad_input[hinge > 0] = 1#         grad_input = np.sum(np.where(hinge > 0))# ----------------#     try 2# ----------------#     hinge = np.maximum(0, 1 - target_pred * target_true)#     grad_input = np.zeros_like(hinge)#     grad_input[hinge > 0] = 1# ----------------#     try 3# ----------------    hinge_result = hinge_forward(target_pred, target_true)    if hinge_result == 0:        grad_input = 0    else:        loss = np.maximum(0, 1 - target_pred * target_true)        grad_input = np.zeros_like(loss)        grad_input[loss > 0] = 1        grad_input = np.sum(grad_input) * target_pred    return grad_input
查看完整描述

1 回答

  • 1 回答
  • 0 关注
  • 421 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
微信客服

购课补贴
联系客服咨询优惠详情

帮助反馈 APP下载

慕课网APP
您的移动学习伙伴

公众号

扫描二维码
关注慕课网微信公众号