2 回答
TA贡献1843条经验 获得超7个赞
import sklearn
import pandas as pd
def tree_to_df(reg_tree, feature_names):
    """Flatten a fitted regression tree into a table of split rules.

    The tree is walked depth-first, left branch before right branch. Every
    leaf produces one row. A row holds only the rules that were added since
    the previous leaf was reached (not the full root-to-leaf path), so the
    table mirrors the nesting of the tree when read top to bottom.

    Args:
        reg_tree: Fitted estimator exposing ``tree_`` with the arrays
            ``children_left``, ``children_right``, ``feature``,
            ``threshold`` and ``value``.
        feature_names: Sequence mapping a feature index to its display name.

    Returns:
        pandas.DataFrame with integer-named rule columns (missing cells are
        empty) and a final ``Return`` column holding ``"return <value>"``
        for each leaf. A tree consisting of a single leaf yields one row
        with only the ``Return`` column.
    """
    tree_ = reg_tree.tree_
    rows = []      # one list of rule strings per leaf
    returns = []   # "return <value>" per leaf, parallel to ``rows``
    pending = []   # rules collected since the previous leaf

    def recurse(node):
        left = tree_.children_left[node]
        right = tree_.children_right[node]
        if left == right:
            # Leaf: both child ids carry the same sentinel. This is the
            # publicly documented leaf test and avoids reaching into the
            # private sklearn.tree._tree module for TREE_UNDEFINED.
            # value[node][0][0] is the first value of the first output.
            returns.append("return " + str(tree_.value[node][0][0]))
            rows.append(pending.copy())
            pending.clear()
            return

        name = feature_names[tree_.feature[node]]
        threshold = tree_.threshold[node]
        # Add the "<=" rule and explore the left branch.
        pending.append(f"{name} <= {threshold!s}")
        recurse(left)
        # Add the ">" rule and explore the right branch.
        pending.append(f"{name} > {threshold!s}")
        recurse(right)

    recurse(0)

    # Rows are emitted only at leaves, so there is no trailing placeholder
    # row to drop and a root-only tree still keeps its single row.
    df = pd.DataFrame(rows)
    df["Return"] = returns
    return df
这将返回一个 pandas 数据框(每一行只列出自上一个叶节点之后新增的条件,最后一列 Return 是该叶节点的返回值):
0 1 2 3 Return
0 feature <= 20750.0 feature <= 7000.0 feature <= 1000.0 feature <= 300.0 return 1000.0
1 feature > 300.0 None None None return 3000.0
2 feature > 1000.0 feature <= 2500.0 None None return 5000.0
3 feature > 2500.0 feature <= 4250.0 None None return 8000.0
4 feature > 4250.0 feature <= 5500.0 None None return 6500.0
5 feature > 5500.0 None None None return 7000.0
6 feature > 7000.0 feature <= 13000.0 feature <= 8750.0 None return 15000.0
7 feature > 8750.0 feature <= 10750.0 None None return 20000.0
8 feature > 10750.0 None None None return 21000.0
9 feature > 13000.0 feature <= 16000.0 feature <= 14750.0 None return 25000.0
10 feature > 14750.0 None None None return 27000.0
11 feature > 16000.0 None None None return 30000.0
12 feature > 20750.0 feature <= 27500.0 None None return 52000.0
13 feature > 27500.0 None None None return 80000.0
TA贡献1852条经验 获得超7个赞
如果您正在处理分类决策树,可以尝试下面的代码:
import pandas as pd
# Sample input for tree_parser: a text dump of a classification tree, one
# node per line, where the number of '|' markers on a line gives its depth
# and "class: N" lines are the leaves.
# The literal starts with a newline, so the first parsed line is empty.
# NOTE(review): this looks like sklearn's export_text output on scaled
# features (negative thresholds) -- confirm against the producing code.
text="""
|--- Age <= 0.63
| |--- EstimatedSalary <= 0.61
| | |--- Age <= -0.16
| | | |--- class: 0
| | |--- Age > -0.16
| | | |--- EstimatedSalary <= -0.06
| | | | |--- class: 0
| | | |--- EstimatedSalary > -0.06
| | | | |--- EstimatedSalary <= 0.40
| | | | | |--- EstimatedSalary <= 0.03
| | | | | | |--- class: 1
"""
def tree_parser(text):
    """Turn an indented text dump of a decision tree into column lists.

    Each line of ``text`` is expected to look like ``"| | |--- rule"``,
    where the number of ``|`` markers before the ``---`` connector is the
    node's depth. Every leaf becomes one row holding the full path of rules
    from the root down to that leaf.

    A node is a leaf when the next node is not deeper than it (or it is the
    last node). This is decided from the structure alone, so feature names
    that contain "class" are not mistaken for leaves and regression dumps
    with "value:" leaves work too.

    Only the ``|`` markers and the ``---`` connector are removed from a
    line. Minus signs in thresholds and hyphens in feature names are kept,
    and the remaining leading whitespace is left in place.

    Lines without any ``|`` marker (blank lines, for instance) are skipped,
    so the text may or may not start with an empty line.

    Args:
        text: Multi-line string containing the tree dump.

    Returns:
        dict mapping ``'Column0'`` .. ``'Column<max depth>'`` to lists of
        equal length (one entry per leaf). ``'Column0'`` holds only empty
        strings and exists for compatibility with callers that drop it;
        ``'Column<k>'`` holds the rule at depth ``k`` or ``None`` when the
        leaf is shallower than ``k``. Text without tree lines yields
        ``{'Column0': []}``.
    """
    nodes = []
    for line in text.splitlines():
        # Split at the first connector so that only '|' markers in the
        # prefix count towards the depth.
        prefix, _, body = line.partition('---')
        level = prefix.count('|')
        if level == 0:
            continue
        nodes.append((level, prefix.replace('|', '') + body))

    max_levels = max((level for level, _ in nodes), default=0)
    result = {'Column' + str(i): [] for i in range(max_levels + 1)}

    path = []  # rules from depth 1 down to the current node
    for index, (level, value) in enumerate(nodes):
        # Discard the previous branch below this depth, then descend.
        del path[level - 1:]
        # Tolerate dumps that skip a depth by leaving the gap empty.
        path.extend([None] * (level - 1 - len(path)))
        path.append(value)

        is_last = index + 1 == len(nodes)
        if not is_last and nodes[index + 1][0] > level:
            continue  # an inner node: its children follow

        result['Column0'].append('')
        for depth in range(1, max_levels + 1):
            cell = path[depth - 1] if depth <= level else None
            result['Column' + str(depth)].append(cell)
    return result
# Parse the sample dump into per-depth column lists.
result = tree_parser(text)

# Build the table; Column0 carries no tree conditions, so it is dropped
# before the rules are written out as CSV (without the index column).
df = pd.DataFrame(result).drop(columns=['Column0'])
df.to_csv('treeout1.csv', index=False)
添加回答
举报