为了账号安全,请及时绑定邮箱和手机立即绑定

如何从scikit-learn决策树中提取决策规则?

如何从scikit-learn决策树中提取决策规则?

慕沐林林 2019-07-25 20:14:47
如何从scikit-learn决策树中提取决策规则?我可以从决策树中的受过训练的树中提取基础决策规则(或“决策路径”)作为文本列表吗?就像是:if A>0.4 then if B<0.2 then if C>0.8 then class='X'谢谢你的帮助。
查看完整描述

3 回答

?
慕容3067478

TA贡献1773条经验 获得超3个赞

我相信这个答案比其他答案更正确:

from sklearn.tree import _treedef tree_to_code(tree, feature_names):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    print "def tree({}):".format(", ".join(feature_names))
    def recurse(node, depth):
        indent = "  " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            print "{}if {} <= {}:".format(indent, name, threshold)
            recurse(tree_.children_left[node], depth + 1)
            print "{}else:  # if {} > {}".format(indent, name, threshold)
            recurse(tree_.children_right[node], depth + 1)
        else:
            print "{}return {}".format(indent, tree_.value[node])
    recurse(0, 1)

这将打印出有效的Python函数。以下是尝试返回其输入的树的示例输出,该数字介于0和10之间。

def tree(f0):
  if f0 <= 6.0:
    if f0 <= 1.5:
      return [[ 0.]]
    else:  # if f0 > 1.5
      if f0 <= 4.5:
        if f0 <= 3.5:
          return [[ 3.]]
        else:  # if f0 > 3.5
          return [[ 4.]]
      else:  # if f0 > 4.5
        return [[ 5.]]
  else:  # if f0 > 6.0
    if f0 <= 8.5:
      if f0 <= 7.5:
        return [[ 7.]]
      else:  # if f0 > 7.5
        return [[ 8.]]
    else:  # if f0 > 8.5
      return [[ 9.]]

以下是我在其他答案中看到的一些绊脚石:

  1. 使用tree_.threshold == -2来决定一个节点是否为叶是不是一个好主意。如果它是一个阈值为-2的真实决策节点怎么办?相反,你应该看看tree.featuretree.children_*

  2. 该行features = [feature_names[i] for i in tree_.feature]与我的sklearn版本崩溃,因为某些值为tree.tree_.feature-2(特别是对于叶节点)。

  3. 递归函数中不需要多个if语句,只需一个就可以了。


查看完整回答
反对 回复 2019-07-25
?
翻翻过去那场雪

TA贡献2065条经验 获得超14个赞

我修改了Zelazny7提交的代码来打印一些伪代码:

def get_code(tree, feature_names):
        left      = tree.tree_.children_left
        right     = tree.tree_.children_right
        threshold = tree.tree_.threshold
        features  = [feature_names[i] for i in tree.tree_.feature]
        value = tree.tree_.value
        def recurse(left, right, threshold, features, node):
                if (threshold[node] != -2):
                        print "if ( " + features[node] + " <= " + str(threshold[node]) + " ) {"
                        if left[node] != -1:
                                recurse (left, right, threshold, features,left[node])
                        print "} else {"
                        if right[node] != -1:
                                recurse (left, right, threshold, features,right[node])
                        print "}"
                else:
                        print "return " + str(value[node])
        recurse(left, right, threshold, features, 0)

如果您get_code(dt, df.columns)使用相同的示例,您将获得:

if ( col1 <= 0.5 ) {

return [[ 1.  0.]]

} else {

if ( col2 <= 4.5 ) {

return [[ 0.  1.]]

} else {

if ( col1 <= 2.5 ) {

return [[ 1.  0.]]

} else {

return [[ 0.  1.]]

}

}

}


查看完整回答
反对 回复 2019-07-25
  • 3 回答
  • 0 关注
  • 3666 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信