为了账号安全,请及时绑定邮箱和手机立即绑定

将回归树输出转换为 pandas 表

将回归树输出转换为 pandas 表

月关宝盒 2023-07-18 13:50:28
这段代码适合 python 中的回归树。我想将此基于文本的输出转换为表格格式。import pandas as pdimport numpy as npfrom sklearn.tree import DecisionTreeRegressorfrom sklearn import treedataset = np.array( [['Asset Flip', 100, 1000], ['Text Based', 500, 3000], ['Visual Novel', 1500, 5000], ['2D Pixel Art', 3500, 8000], ['2D Vector Art', 5000, 6500], ['Strategy', 6000, 7000], ['First Person Shooter', 8000, 15000], ['Simulator', 9500, 20000], ['Racing', 12000, 21000], ['RPG', 14000, 25000], ['Sandbox', 15500, 27000], ['Open-World', 16500, 30000], ['MMOFPS', 25000, 52000], ['MMORPG', 30000, 80000] ]) X = dataset[:, 1:2].astype(int)y = dataset[:, 2].astype(int)  regressor = DecisionTreeRegressor(random_state = 0) regressor.fit(X, y) text_rule = tree.export_text(regressor )print(text_rule)我得到的输出是这样的print(text_rule)|--- feature_0 <= 20750.00|   |--- feature_0 <= 7000.00|   |   |--- feature_0 <= 1000.00|   |   |   |--- feature_0 <= 300.00|   |   |   |   |--- value: [1000.00]|   |   |   |--- feature_0 >  300.00|   |   |   |   |--- value: [3000.00]|   |   |--- feature_0 >  1000.00|   |   |   |--- feature_0 <= 2500.00|   |   |   |   |--- value: [5000.00]|   |   |   |--- feature_0 >  2500.00|   |   |   |   |--- feature_0 <= 4250.00|   |   |   |   |   |--- value: [8000.00]|   |   |   |   |--- feature_0 >  4250.00|   |   |   |   |   |--- feature_0 <= 5500.00|   |   |   |   |   |   |--- value: [6500.00]|   |   |   |   |   |--- feature_0 >  5500.00|   |   |   |   |   |   |--- value: [7000.00]|   |--- feature_0 >  7000.00|   |   |--- feature_0 <= 13000.00|   |   |   |--- feature_0 <= 8750.00|   |   |   |   |--- value: [15000.00]|   |   |   |--- feature_0 >  8750.00我想在 pandas 表中转换此规则,类似于以下形式。这个怎么做 ?规则的情节版本是这样的(供参考)。请注意,在表中我显示了规则的最左边部分。
查看完整描述

2 回答

?
蓝山帝景

TA贡献1843条经验 获得超7个赞

import sklearn

import pandas as pd


def tree_to_df(reg_tree, feature_names):

    tree_ = reg_tree.tree_

    feature_name = [

        feature_names[i] if i != sklearn.tree._tree.TREE_UNDEFINED else "undefined!"

        for i in tree_.feature

    ]

    

    def recurse(node, row, ret):

        if tree_.feature[node] != sklearn.tree._tree.TREE_UNDEFINED:

            name = feature_name[node]

            threshold = tree_.threshold[node]

            # Add rule to row and search left branch

            row[-1].append(name + " <= " +  str(threshold))

            recurse(tree_.children_left[node], row, ret)

            # Add rule to row and search right branch

            row[-1].append(name + " > " +  str(threshold))

            recurse(tree_.children_right[node], row, ret)

        else:

            # Add output rules and start a new row

            label = tree_.value[node]

            ret.append("return " + str(label[0][0]))

            row.append([])

    

    # Initialize

    rules = [[]]

    vals = []

    

    # Call recursive function with initial values

    recurse(0, rules, vals)

    

    # Convert to table and output

    df = pd.DataFrame(rules).dropna(how='all')

    df['Return'] = pd.Series(vals)

    return df

这将返回一个 pandas 数据框:


                     0                   1                   2                 3          Return

0   feature <= 20750.0   feature <= 7000.0   feature <= 1000.0  feature <= 300.0   return 1000.0

1      feature > 300.0                None                None              None   return 3000.0

2     feature > 1000.0   feature <= 2500.0                None              None   return 5000.0

3     feature > 2500.0   feature <= 4250.0                None              None   return 8000.0

4     feature > 4250.0   feature <= 5500.0                None              None   return 6500.0

5     feature > 5500.0                None                None              None   return 7000.0

6     feature > 7000.0  feature <= 13000.0   feature <= 8750.0              None  return 15000.0

7     feature > 8750.0  feature <= 10750.0                None              None  return 20000.0

8    feature > 10750.0                None                None              None  return 21000.0

9    feature > 13000.0  feature <= 16000.0  feature <= 14750.0              None  return 25000.0

10   feature > 14750.0                None                None              None  return 27000.0

11   feature > 16000.0                None                None              None  return 30000.0

12   feature > 20750.0  feature <= 27500.0                None              None  return 52000.0

13   feature > 27500.0                None                None              None  return 80000.0



查看完整回答
反对 回复 2023-07-18
?
慕姐4208626

TA贡献1852条经验 获得超7个赞

如果您正在处理分类决策树,您可以尝试一下


import pandas as pd

text="""

|--- Age <= 0.63

|   |--- EstimatedSalary <= 0.61

|   |   |--- Age <= -0.16

|   |   |   |--- class: 0

|   |   |--- Age >  -0.16

|   |   |   |--- EstimatedSalary <= -0.06

|   |   |   |   |--- class: 0

|   |   |   |--- EstimatedSalary >  -0.06

|   |   |   |   |--- EstimatedSalary <= 0.40

|   |   |   |   |   |--- EstimatedSalary <= 0.03

|   |   |   |   |   |   |--- class: 1

"""



def tree_parser(text):

    lines=text.splitlines()

    max_levels=max([l.count('|') for l in lines])

    result={}


    for i in range(0,max_levels+1):

        result['Column'+str(i)]=[]


    for line in lines:

        level=line.count('|')

        currvalue=result.get('Column'+str(level),[])

        currvalue.append(line.replace('|','').replace('-',''))

        result['Column'+str(level)]=currvalue

        for i in range(0, max_levels + 1):

            if i>level and line.find('class')!=-1:

                result['Column' + str(i)].append(None)

            if i<level:

                parent_value=result.get('Column' + str(i),[])

                if len(parent_value)!=len(currvalue):

                    parent_value.append(parent_value[len(parent_value)-1])

    return result



result=tree_parser(text)

df=pd.DataFrame(result)

df=df.drop(columns=['Column0'])

df.to_csv('treeout1.csv',index=False)


查看完整回答
反对 回复 2023-07-18
  • 2 回答
  • 0 关注
  • 98 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信