为了账号安全,请及时绑定邮箱和手机立即绑定

在python中从xgboost中提取决策规则

在python中从xgboost中提取决策规则

慕仙森 2021-10-26 18:09:18
我想在 python 中为我即将推出的模型使用 xgboost。然而,由于我们的生产系统在 SAS 中,我试图从 xgboost 中提取决策规则,然后编写 SAS 评分代码以在 SAS 环境中实现该模型。上面两个链接对xgboost部署特别是Shiutang-Li给出的代码有很大帮助。但是,我的预测分数并不完全匹配。以下是我迄今为止尝试过的代码:import numpy as npimport pandas as pdimport xgboost as xgbfrom sklearn.grid_search import GridSearchCV%matplotlib inlineimport graphvizfrom graphviz import Digraph#Read the sample iris data:iris =pd.read_csv("C:\\Users\\XXXX\\Downloads\\Iris.csv")#Create dependent variable:iris.loc[iris["class"] != 2,"class"] = 0iris.loc[iris["class"] == 2,"class"] = 1#Select independent and dependent variable:X = iris[["sepal_length","sepal_width","petal_length","petal_width"]]Y = iris["class"]xgdmat = xgb.DMatrix(X, Y) # Create our DMatrix to make XGBoost more efficient#Build the sample xgboost Model:our_params = {'eta': 0.1, 'seed':0, 'subsample': 0.8, 'colsample_bytree': 0.8,              'objective': 'binary:logistic', 'max_depth':3, 'min_child_weight':1} Base_Model = xgb.train(our_params, xgdmat, num_boost_round = 10)#Below code reads the dump file created by xgboost and writes a scoring code in SAS:import redef string_parser(s):    if len(re.findall(r":leaf=", s)) == 0:        out  = re.findall(r"[\w.-]+", s)        tabs = re.findall(r"[\t]+", s)        if (out[4] == out[8]):            missing_value_handling = (" or missing(" + out[1] + ")")        else:            missing_value_handling = ""        if len(tabs) > 0:            return (re.findall(r"[\t]+", s)[0].replace('\t', '    ') +                     '        if state = ' + out[0] + ' then do;\n' +                    re.findall(r"[\t]+", s)[0].replace('\t', '    ') +                    '            if ' + out[1] + ' < ' + out[2] + missing_value_handling +所以基本上,我想要做的是,将节点号保存在变量“状态”中,并相应地访问叶节点(我从上面链接中提到的 Shiutang-Li 的文章中了解到)。
查看完整描述

3 回答

?
蓝山帝景

TA贡献1843条经验 获得超7个赞

我在获得匹配分数方面有类似的经验。
我的理解是,除非您修复ntree_limit选项以匹配n_estimators您在模型拟合期间使用的选项,否则评分可能会提前停止。

df['score']= xgclfpkl.predict(df[xg_features], ntree_limit=500)

开始使用后ntree_limit,我开始获得匹配的分数。


查看完整回答
反对 回复 2021-10-26
?
Smart猫小萌

TA贡献1911条经验 获得超7个赞

我有类似的经验,需要将 xgboost 评分代码从 R 提取到 SAS。

最初,我遇到了与您在这里相同的问题,即在较小的树中,R 和 SAS 的分数没有太大差异,一旦树的数量增加到 100 或更多,我开始观察差异.

我做了三件事来缩小差异:

  1. 确保丢失的组朝着正确的方向前进,您需要明确。否则 SAS 会将缺失值视为所有数字中的最小值。规则应该类似于 SAS 中的以下内容。

if sepal_width > 2.95000005 or missing(sepal_width) then state = 1;else state = 2;
或者
if sepal_width <= 2.95000005 and ~missing(sepal_width) then state = 1;else state = 2;

  1. 我使用了一个叫做 R 包float来使分数有更多的小数位。 as.numeric(float::fl(Quality))

  2. 确保 SAS 数据与您在 Python 中训练的数据具有相同的形状。

希望以上有帮助。


查看完整回答
反对 回复 2021-10-26
?
神不在的星期二

TA贡献1963条经验 获得超6个赞

几点——


首先,正则表达式叶返回值匹配并没有捕捉到垃圾堆里的“E-小数”科学记数法(默认)。显式示例(第二个是正确的修改!)-


s = '3:leaf=9.95066429e-09'

out = re.findall(r"[\d.-]+", s)

out2 = re.findall(r"-?[\d.]+(?:e-?\d+)?", s)

out2,out

(易于修复但不易发现,因为我的模型中只有一片叶子受到影响!)


其次,问题是关于二进制的,但在多类目标中,转储中的每个类都有单独的树,因此您T*C总共有树,其中T是提升轮C数,是类数。对于类c(在 {0,1,...,C-1} 中),您需要评估(并求和)树i*C +c的i = 0,...,T-1. 然后将其 softmax 以匹配来自 xgb 的预测。


查看完整回答
反对 回复 2021-10-26
  • 3 回答
  • 0 关注
  • 336 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
微信客服

购课补贴
联系客服咨询优惠详情

帮助反馈 APP下载

慕课网APP
您的移动学习伙伴

公众号

扫描二维码
关注慕课网微信公众号