如何从scikit-learn决策树中提取决策规则?我可以从决策树中的受过训练的树中提取基础决策规则(或“决策路径”)作为文本列表吗?就像是:if A>0.4 then if B<0.2 then if C>0.8 then class='X'谢谢你的帮助。
3 回答

慕容3067478
TA贡献1773条经验 获得超3个赞
我相信这个答案比其他答案更正确:
from sklearn.tree import _treedef tree_to_code(tree, feature_names): tree_ = tree.tree_ feature_name = [ feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!" for i in tree_.feature ] print "def tree({}):".format(", ".join(feature_names)) def recurse(node, depth): indent = " " * depth if tree_.feature[node] != _tree.TREE_UNDEFINED: name = feature_name[node] threshold = tree_.threshold[node] print "{}if {} <= {}:".format(indent, name, threshold) recurse(tree_.children_left[node], depth + 1) print "{}else: # if {} > {}".format(indent, name, threshold) recurse(tree_.children_right[node], depth + 1) else: print "{}return {}".format(indent, tree_.value[node]) recurse(0, 1)
这将打印出有效的Python函数。以下是尝试返回其输入的树的示例输出,该数字介于0和10之间。
def tree(f0): if f0 <= 6.0: if f0 <= 1.5: return [[ 0.]] else: # if f0 > 1.5 if f0 <= 4.5: if f0 <= 3.5: return [[ 3.]] else: # if f0 > 3.5 return [[ 4.]] else: # if f0 > 4.5 return [[ 5.]] else: # if f0 > 6.0 if f0 <= 8.5: if f0 <= 7.5: return [[ 7.]] else: # if f0 > 7.5 return [[ 8.]] else: # if f0 > 8.5 return [[ 9.]]
以下是我在其他答案中看到的一些绊脚石:
使用
tree_.threshold == -2
来决定一个节点是否为叶是不是一个好主意。如果它是一个阈值为-2的真实决策节点怎么办?相反,你应该看看tree.feature
或tree.children_*
。该行
features = [feature_names[i] for i in tree_.feature]
与我的sklearn版本崩溃,因为某些值为tree.tree_.feature
-2(特别是对于叶节点)。递归函数中不需要多个if语句,只需一个就可以了。

翻翻过去那场雪
TA贡献2065条经验 获得超14个赞
我修改了Zelazny7提交的代码来打印一些伪代码:
def get_code(tree, feature_names): left = tree.tree_.children_left right = tree.tree_.children_right threshold = tree.tree_.threshold features = [feature_names[i] for i in tree.tree_.feature] value = tree.tree_.value def recurse(left, right, threshold, features, node): if (threshold[node] != -2): print "if ( " + features[node] + " <= " + str(threshold[node]) + " ) {" if left[node] != -1: recurse (left, right, threshold, features,left[node]) print "} else {" if right[node] != -1: recurse (left, right, threshold, features,right[node]) print "}" else: print "return " + str(value[node]) recurse(left, right, threshold, features, 0)
如果您get_code(dt, df.columns)
使用相同的示例,您将获得:
if ( col1 <= 0.5 ) {
return [[ 1. 0.]]
} else {
if ( col2 <= 4.5 ) {
return [[ 0. 1.]]
} else {
if ( col1 <= 2.5 ) {
return [[ 1. 0.]]
} else {
return [[ 0. 1.]]
}
}
}
添加回答
举报
0/150
提交
取消