为了账号安全,请及时绑定邮箱和手机立即绑定

机器学习实战-02-决策树

标签:
数据结构

1、决策树简单介绍

  上一篇的kNN算法虽然可以完成很多分类任务,但它最大的缺点是无法给出数据的内在含义,而决策树的主要优势就在于数据形式非常容易理解。决策树算法能够读取数据集合,决策树的一个重要任务是为了数据所蕴含的知识信息。因此,决策树可以使用不熟悉的数据集合,并从中提取一系列规则,在这些机器根据数据集创建规则就是机器学习的过程。

webp

2. 决策树.png

信息增益公式计算:

(1)计算数据集D的经验熵H(D)

webp

数据集D的熵

(2)计算特征A对数据集D的条件熵H(D|A)

webp

特征A对数据集D的条件熵H(D|A)

(3)计算信息增益

webp

特征A对数据集D划分后的信息增益

举个简单计算例子。

webp

image.png

webp

image.png

其他3个属性计算类似,选取信息增益最大特征的划分。

2、Python3.6 代码实现

webp

决策树的一般流程

全部代码:

# -*- coding: UTF-8 -*-from math import logimport operatorimport matplotlib.pyplot as pltimport pickledef creatDataSet():
    """
        Function:
            创建测试数据集
        Parameters:
            无
        Returns:
            dataSet - 数据集
            labels - 分类属性标签
        Modify:
            2018-08-02
        """
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']    return dataSet, labelsdef calcShannonEnt(dataSet):
    """
        Function:
            计算给定数据集经验熵
        Parameters:
            dataSet - 数据集
        Returns:
            shannonEnt - 经验熵
        Modify:
            2018-08-02
        """
    numEntries = len(dataSet)
    labelCounts = {}    for featVec in dataSet:
        currentLabel = featVec[-1]        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)    return shannonEntdef splitDataSet(dataSet, axis, value):
    """
        Function:
            按照给定特征划分数据集
        Parameters:
            dataSet - 待划分的数据集
            axis - 划分数据集的特征
            value - 特征的取值
        Returns:
            retDataSet - 划分后的数据集
        Modify:
            2018-08-02
        """
    retDataSet = []    for featVec in dataSet:        if featVec[axis] == value:            # 以下两行代表去除该行的featVec[axis]元素
            reduceFeatVec = featVec[:axis]
            reduceFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reduceFeatVec)    return retDataSet# 选择最好的数据划分方式ID3def chooseBestFeatureToSplit(dataSet):
    """
        Function:
            选择最优特征划分方式(计算信息增益)
        Parameters:
            dataSet - 数据集
        Returns:
            bestFeature - 信息增益最大的特征的索引值
        Modify:
            2018-08-02
        """
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):        # 获取特征i的特征值列表
        featList = [example[i] for example in dataSet]        # 利用set集合元素唯一性的性质,得到特征i的取值
        uniqueVals = set(featList)
        newEntropy = 0.0
        # 计算第i特征划分信息增益
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        print('第%d个特征的增益为%.3f' % (i, infoGain))        if (infoGain > bestInfoGain):
            bestInfoGain = infoGain
            bestFeature = i    return bestFeaturedef majorityCnt(classList):
    """
        Function:
            多数表决的方法完成分类
        Parameters:
            classList - 类标签列表
        Returns:
            sortedClassCount[0][0] - 出现次数最多的类标签
        Modify:
            2018-08-02
        """
    classCount = {}    for vote in classList:        if vote not in classCount.keys(): classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)    return sortedClassCount[0][0]def createTree(dataSet, labels):
    """
        Function:
            创建决策树
        Parameters:
            dataSet - 数据集
            labels - 分类属性标签
        Returns:
            myTree - 决策树
        Modify:
            2018-08-02
        """
    classList = [example[-1] for example in dataSet]    # 判断所有类标签是否相同,相同则返回该类标签
    if (classList.count(classList[0]) == len(classList)):        return classList[0]    # 遍历完所有的特征属性,此时数据集的列为1,即只有类标签列
    if len(dataSet[0]) == 1:        return majorityCnt(classList)    # 选择最优特征
    bestFeature = chooseBestFeatureToSplit(dataSet)
    bestFeatureLabel = labels[bestFeature]    # 采用字典嵌套字典的方式,存储分类树信息
    myTree = {bestFeatureLabel: {}}    # 复制当前特征标签列表,防止改变原始列表的内容
    subLabels = labels[:]    del (subLabels[bestFeature])
    featValues = [example[bestFeature] for example in dataSet]
    uniqueVals = set(featValues)    for value in uniqueVals:
        myTree[bestFeatureLabel][value] = createTree(splitDataSet(dataSet, bestFeature, value), subLabels)    return myTreedef getNumLeafs(myTree):
    """
        Function:
            获取叶节点的数目
        Parameters:
            myTree - 决策树
        Returns:
            numLeafs - 叶节点的数目
        Modify:
            2018-08-04
        """
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]    for key in secondDict.keys():        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])        else:
            numLeafs += 1
    return numLeafsdef getTreeDepth(myTree):
    """
        Function:
            获取树的层数
        Parameters:
            myTree - 决策树
        Returns:
            numLeafs - 树的层数
        Modify:
            2018-08-04
        """
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]    for key in secondDict.keys():        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth    return maxDepth# 绘制带箭头的注释def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """
        Function:
            绘制带箭头的注释
        Parameters:
            nodeTxt - 结点名
            centerPt - 文本位置
            parentPt - 标注的箭头位置
            nodeType - 结点格式
        Returns:
            无
        Modify:
            2018-08-04
        """
    # 定义箭头格式
    arrow_args = dict(arrowstyle="<-")    # 绘制结点
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction', va="center",
                            ha="center", bbox=nodeType, arrowprops=arrow_args)def plotMidText(cntrPt, parentPt, txtString):
    """
        Function:
            计算父节点和子节点的中间位置,并在此处添加简单的文本标签信息
        Parameters:
            cntrPt、parentPt - 用于计算标注位置
            txtString - 标注的内容
        Returns:
            无
        Modify:
            2018-08-04
        """
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)# 计算宽与高def plotTree(myTree, parentPt, nodeTxt):
    """
        Function:
            绘制决策树
        Parameters:
            myTree - 字典决策树
            parentPt - 标注的内容
            nodeTxt - 结点名
        Returns:
            无
        Modify:
            2018-08-04
        """
    # 定义文本框和箭头格式
    decisionNode = dict(boxstyle="sawtooth", fc="0.8")
    leafNode = dict(boxstyle="round4", fc="0.8")
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)    # 标记子节点属性值
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]    # 减少y偏移
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD    for key in secondDict.keys():        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalDdef createPlot(inTree):
    """
        Function:
            绘树主函数
        Parameters:
            inTree - 字典决策树
        Returns:
            无
        Modify:
            2018-08-04
        """
    # 创建fig
    fig = plt.figure(1, facecolor='white')    # 清空fig
    fig.clf()    # 设置坐标轴数据
    axprops = dict(xticks=[], yticks=[])    # 去除坐标轴
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))    # 两个全局变量plotTree.xOff和plotTree.yOff追踪已经绘制的节点位置,
    # 以及放置下一个节点的恰当位置
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()def classify(inputTreee, featLabels, testVec):
    """
        Function:
            使用决策树分类
        Parameters:
            inputTree - 训练好的决策树信息
            featLabels - 标签列表
            testVec - 测试向量
        Returns:
            无
        Modify:
            2018-08-04
        """
    # 获取决策树结点
    firstStr = list(inputTreee.keys())[0]    # 下一个字典
    secondDict = inputTreee[firstStr]
    featIndex = featLabels.index(firstStr)    for key in secondDict.keys():        if testVec[featIndex] == key:            if type(secondDict[key]).__name__ == 'dict':
                classLable = classify(secondDict[key], featLabels, testVec)            else:
                classLable = secondDict[key]    return classLabledef storeTree(inputTree, filename):
    """
        Function:
            使用pickle模块存储决策树
        Parameters:
            inputTree - 已经生成的决策树
            filename - 决策树的存储文件名
        Returns:
            无
        Modify:
            2018-08-04
        """
    fw = open(filename, 'wb')
    pickle.dump(inputTree, fw)
    fw.close()def grabTree(filename):
    """
        Function:
            获取保存好的决策树
        Parameters:
            filename - 决策树的存储文件名
        Returns:
            无
        Modify:
            2018-08-04
        """
    fr = open(filename, 'rb')    return pickle.load(fr)def predictLensesType(filename):
    """
        Function:
            使用决策树预测隐形眼镜类型
        Parameters:
            filename - 隐形眼镜数据集文件名
        Returns:
            无
        Modify:
            2018-08-04
        """
    # 打开文本数据
    fr = open(filename)    # 将文本数据的每一个数据行按照tab键分割,并依次存入lenses
    lenses = [inst.strip().split('\t') for inst in fr.readlines()]    # 创建并存入特征标签列表
    lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']    # 根据继续文件得到的数据集和特征标签列表创建决策树
    lensesTree = createTree(lenses, lensesLabels)    return lensesTreeif __name__ == '__main__':
    dataSet, labels = creatDataSet()
    dataSetEnt = calcShannonEnt(dataSet)    # print(dataSetEnt)

    # retDataSet = splitDataSet(dataSet, 0, 0)
    # print(retDataSet)
    #
    # bestFeature = chooseBestFeatureToSplit(dataSet)
    # print('最优特征索引值:', bestFeature)

    # myTree = createTree(dataSet, labels)
    # print(myTree)

    # myTree = createTree(dataSet, labels)
    # print(myTree)
    # createPlot(myTree)

    # myTree = createTree(dataSet, labels)
    # classifyResult1 = classify(myTree, labels, [1, 0])
    # print(classifyResult1)
    # classifyResult2 = classify(myTree, labels, [1, 1])
    # print(classifyResult2)

    myTree = createTree(dataSet, labels)
    print(myTree)
    storeTree(myTree, 'classifierStorage.txt')
    impTree = grabTree('classifierStorage.txt')
    print('impTree:', impTree)
    classifyResult2 = classify(myTree, labels, [1, 1])
    print('[1, 1]的分类为:', classifyResult2)

(1)准备数据集比计算经验熵

webp

计算经验熵运行结果

(2)划分数据集

计算信息增益选择最优划分特征,根据最优特征进行划分数据集。

webp

划分数据集运行结果

(3)递归构建决策树

  构建决策树的工作原理:首先得到原始数据集,然后基于最好的属性划分数据集,由于特征值可能多于两个,因此可能存在大于两个分支的数据集划分。第一次划分之后,数据将向下传递到树分支的下一个结点,在该结点上,我们可以再次划分数据。因此,我们可以采用递归的方法处理数据集,完成决策树构造。
  递归结束的条件是:程序遍历完所有划分数据集的属性,或者每个分支下的所有实例都具有相同的分类。如果所有的实例具有相同的分类,则得到一个叶子结点或者终止块。
  当遍历完所有的特征属性,但是某个或多个分支下实例类标签仍然不唯一,此时我们需要确定出如何定义该叶子结点,在这种情况下,通过会采取多数表决的方法选取分支下实例中类标签种类最多的分类作为该叶子结点的分类。

webp

递归构建决策树运行结果

(4)可视化决策树

  用字典的形式表示决策树非常不易于理解,决策树的主要优点就是直观易于理解,如果不能将其直观显示出来,就无法发挥其优势。本节使用 Matplotlib 库编写代码绘制决策树。Matplotlib 提供了一个非常有用的注解工具annotations,它可以在数据图形上添加文本注解。

webp

可视化运行结果

(5)测试和存储分类器

webp

使用决策树测试分类运行结果

webp

pickle模块存储决策树及读取决策树分类运行结果

3、示例:使用决策树预测隐形眼镜类型

webp

步骤1

webp

步骤2

webp

运行结果1

webp

可视化结果

4、应用scikit-learn实现预测隐形眼镜类型

  如何利用Graphviz可视化决策树,参见我的这篇简书:https://www.jianshu.com/p/dd552f780a40

  因为在fit()函数不能接收string类型的数据,所以在使用fit()函数之前,我们需要对数据集进行编码,可以使用两种方法:

这里对string类型的数据序列化使用的方法是:原始数据->字典->pandas数据

全部代码:

# -*- coding: UTF-8 -*-from sklearn import treeimport pandas as pdfrom sklearn.preprocessing import LabelEncoderfrom sklearn.externals.six import StringIOimport pydotplusif __name__ == '__main__':    # 加载数据文件
    with open('D:/PycharmProjects/Machine/machinelearninginaction/Ch03/lenses.txt') as fr:
        lenses = [inst.strip().split('\t') for inst in fr.readlines()]
    lensesTarget = []    for each in lenses:
        lensesTarget.append(each[-1])

    lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
    lensesList = []
    lensesDict = {}    for eachLabel in lensesLabels:        for each in lenses:
            lensesList.append(each[lensesLabels.index(eachLabel)])
        lensesDict[eachLabel] = lensesList
        lensesList = []
    print(lensesDict)
    lensesPd = pd.DataFrame(lensesDict)    # 创建LabelEncoder()对象,用于序列化
    le = LabelEncoder()    for col in lensesPd.columns:
        lensesPd[col] = le.fit_transform(lensesPd[col])
    print(lensesPd)    # 创建DecisionTreeClassifier()类
    clf = tree.DecisionTreeClassifier(max_depth=4)    # tolist()将数组或者矩阵转换成列表
    # 使用数据,构建决策树
    clf = clf.fit(lensesPd.values.tolist(), lensesTarget)    # 可视化决策树
    dotData = StringIO()
    tree.export_graphviz(clf, out_file=dotData, feature_names=lensesPd.keys(),
                         class_names=clf.classes_, filled=True, rounded=True,
                         special_characters=True)
    graph = pydotplus.graph_from_dot_data(dotData.getvalue())
    graph.write_pdf('lensesTree.pdf')

    print(clf.predict([[1, 1, 1, 0]]))

webp

可视化决策树结果

webp

预测结果

5、小结

  决策树算法可能或出现的过度匹配(过拟合)的问题,当决策树的复杂度较大时,很可能会造成过拟合问题。此时,可以通过裁剪决策树的办法,降低决策树的复杂度,提高决策树的泛化能力。比如,如果决策树的某一叶子结点只能增加很少的信息,那么就可将该节点删掉,将其并入到相邻的结点中去,这样,降低了决策树的复杂度,消除过拟合问题。
  本篇使用的算法为ID3,但它无法直接处理数值型数据,尽管可以通过量化的方法将数值型数据转化为标称型数值,但如果存在太多的特征划分,ID3算法仍然会面临其他问题。



作者:nobodyyang
链接:https://www.jianshu.com/p/4676b36d1cd9


点击查看更多内容
TA 点赞

若觉得本文不错,就分享一下吧!

评论

作者其他优质文章

正在加载中
  • 推荐
  • 评论
  • 收藏
  • 共同学习,写下你的评论
感谢您的支持,我会继续努力的~
扫码打赏,你说多少就多少
赞赏金额会直接到老师账户
支付方式
打开微信扫一扫,即可进行扫码打赏哦
今天注册有机会得

100积分直接送

付费专栏免费学

大额优惠券免费领

立即参与 放弃机会
意见反馈 帮助中心 APP下载
官方微信

举报

0/150
提交
取消