[机器学习实战]决策树

来源:互联网 发布:淘宝卖家真实现状 编辑:IT博客网 时间:2019/07/20 17:24

    • 原理
    • 步骤分解
    • 完整代码

原理

决策树图示
通过提问的方式,根据不同的答案选择不同的分支, 完成不同的分类

步骤分解

1.遍历数据集, 循环计算提取每个特征的香农熵和信息增益, 选取信息增益最大的特征。 再递归计算剩余的特征顺序。 将特征排序。 并将分类结果序列化保存到磁盘当中

def chooseBestFeatureToSplit(dataSet):  # 选择最好的分类特征    """    :param dataSet: 原数据集    :return: 最好的划分特征的索引值    """    numFeatures = len(dataSet[0]) - 1   # 获取特征数    baseEntropy = calcShannonEnt(dataSet)   # 计算数据集的信息熵    bestInfoGain = 0.0      # 初始化最好的信息熵    bestFeature = -1        # 初始化最好的用于分割的特征    for i in range(numFeatures):        # 创建唯一的分类标签列表        featList= [example[i] for example in dataSet]   # 获取每个元素的第i个特征        uniqueVals = set(featList)  # 数据特征去重 (此特征有几种情况)        newEntropy = 0.0        # 计算每种划分方式的信息熵        for value in uniqueVals:            subDataSet = splitDataSet(dataSet, i, value)            prob = len(subDataSet) / float(len(dataSet))    # probability,概率,可理解为权重            newEntropy += prob * calcShannonEnt(subDataSet)        infoGain = baseEntropy - newEntropy     # 新的熵越小即新划分的数据集混乱程度越小,与原熵的差值就越大, 即信息增益就越大        # 计算最好的信息增益        if(infoGain > bestInfoGain):    # 若新的信息增益大于之前的信息增益,则替换            bestInfoGain = infoGain            bestFeature = i     # 表示最好的划分特征的索引值    return bestFeature

2.递归构建决策树

def createTree(dataSet, labels):    """    :param dataSet: 数据集    :param labels: 标签列表, 包含了数据集中的所有特征的标签    :return:    """    classList = [example[-1] for example in dataSet]    # 类别完全相同则停止继续划分    if classList.count(classList[0]) == len(classList):        return classList[0]    # 遍历完所有特征,仍然不能将数据集划分成包含唯一类别的分组时,返回出现次数最多的类别    if len(dataSet[0]) == 1:        return majorityCnt(classList)    bestFeat = chooseBestFeatureToSplit(dataSet)    # 选取的最好特征    bestFeatLabel = labels[bestFeat]    # 最好特征标签    myTree = {bestFeatLabel:{ }}    # 使用字典存储树信息    # 得到列表包含的所有属性值    del(labels[bestFeat])    featValues = [example[bestFeat] for example in dataSet]    uniqueVals = set(featValues)    for value in uniqueVals:        subLabels = labels[:]   # 因为下一步传参数时是引用传参        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)    return myTree

3.使用Matplotlib注解绘制树形图

import matplotlib.pyplot as pltimport trees# 定义文本框和箭头格式decisionNode = dict(boxstyle="sawtooth", fc="0.8")leafNode = dict(boxstyle="round4", fc="0.8")    # 设置箭头的样式和背景色arrow_args = dict(arrowstyle="<-")  # 设置箭头的样式def plotNode(nodeTxt, centerPt, parentPt, nodeType): # 绘制带箭头的注解    createPlot.axl.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',                            xytext=centerPt, textcoords='axes fraction',                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)def createPlot(inTree):    fig = plt.figure(1, facecolor='white')  # 设置背景色    fig.clf()   # 清空画布    axprops = dict(xticks=[], yticks=[])    createPlot.axl = plt.subplot(111, frameon=False, **axprops) #表示图中有1行1列,绘图放在第几列, 有无边框    plotTree.totalW = float(trees.getNumLeafs(inTree))    plotTree.totalD = float(trees.getTreeDepth(inTree))    plotTree.xOff = -0.5/plotTree.totalW    plotTree.yOff = 1.0    plotTree(inTree, (0.5, 1.0), ' ')    # plotNode('a decision node', (0.5, 0.5), (0.1, 0.5), decisionNode)   # 第一个坐标是注解的坐标 第二个坐标是点的坐标    # plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)    plt.show()def plotMidText(cntrPt, parentPt, txtString): # 在父子节点间填充文本信息    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]    createPlot.axl.text(xMid, yMid, txtString)def plotTree(myTree, parentPt, nodeTxt):    #计算宽与高    numLeafs = trees.getNumLeafs(myTree)    depth = trees.getTreeDepth(myTree)    firstStr = list(myTree.keys())[0]   # 根节点    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)    # 标记子节点属性值    plotMidText(cntrPt, parentPt, nodeTxt)    plotNode(firstStr, cntrPt, parentPt, decisionNode)    secondDict = myTree[firstStr]    # 减少y偏移    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD    for key in secondDict.keys():        if type(secondDict[key]).__name__ == 'dict':            plotTree(secondDict[key], cntrPt, str(key))        else:            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalDdef retrieveTree(i): # 创建树    listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},                  {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}]    return listOfTrees[i]

完整代码

trees.py

from math import logimport operatorimport treePlotterdef calcShannonEnt(dataSet):    # 计算给定数据集的香农熵    numEntries = len(dataSet)    labelCounts = {}    # 为所有可能的分类创建字典    for featVec in dataSet:        currentLabel = featVec[-1]        # if currentLabel not in labelCounts.keys():        #     labelCounts[currentLabel] = 0        # labelCounts[currentLabel] += 1        labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1    shannonEnt = 0.0    for key in labelCounts:        # 以2为底求对数        prob = float(labelCounts[key]) / numEntries        shannonEnt -= prob * log(prob, 2)    return shannonEntdef splitDataSet(dataSet, axis, value): # 按照给定特征划分数据集    """    :param dataSet: 待划分的数据集    :param axis: 划分数据集的特征    :param value: 特征的返回值    :return:    """    # 创建新的list对象    retDataSet = []    for featVec in dataSet:        if featVec[axis] == value:  # 抽取            reducedFratVec = featVec[:axis]            reducedFratVec.extend(featVec[axis+1:])            retDataSet.append(reducedFratVec)    return retDataSetdef chooseBestFeatureToSplit(dataSet):  # 选择最好的分类特征    """    :param dataSet: 原数据集    :return: 最好的划分特征的索引值    """    numFeatures = len(dataSet[0]) - 1   # 获取特征数    baseEntropy = calcShannonEnt(dataSet)   # 计算数据集的信息熵    bestInfoGain = 0.0      # 初始化最好的信息熵    bestFeature = -1        # 初始化最好的用于分割的特征    for i in range(numFeatures):        # 创建唯一的分类标签列表        featList= [example[i] for example in dataSet]   # 获取每个元素的第i个特征        uniqueVals = set(featList)  # 数据特征去重 (此特征有几种情况)        newEntropy = 0.0        # 计算每种划分方式的信息熵        for value in uniqueVals:            subDataSet = splitDataSet(dataSet, i, value)            prob = len(subDataSet) / float(len(dataSet))    # probability,概率,可理解为权重            newEntropy += prob * calcShannonEnt(subDataSet)        infoGain = baseEntropy - newEntropy     # 新的熵越小即新划分的数据集混乱程度越小,与原熵的差值就越大, 即信息增益就越大        # 计算最好的信息增益        if(infoGain > bestInfoGain):    # 若新的信息增益大于之前的信息增益,则替换            bestInfoGain = infoGain            bestFeature = i     # 表示最好的划分特征的索引值    return bestFeaturedef majorityCnt(classList): # 多数表决决定叶子节点的分类    """    :param classList: 类别列表    :return: 出现次数最多的分类名称    """    classCount = {}    for vote in classList:  # 统计分类列表中个类别出现的次数        # if vote not in classCount.keys(): classCount[vote] = 0        # classCount[vote] += 1        classCount[vote] = classCount.get(vote, 0) + 1    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)  # 根据出现次数排序    return sortedClassCount[0][0]def createTree(dataSet, labels):    """    :param dataSet: 数据集    :param labels: 标签列表, 包含了数据集中的所有特征的标签    :return:    """    classList = [example[-1] for example in dataSet]    # 类别完全相同则停止继续划分    if classList.count(classList[0]) == len(classList):        return classList[0]    # 遍历完所有特征,仍然不能将数据集划分成包含唯一类别的分组时,返回出现次数最多的类别    if len(dataSet[0]) == 1:        return majorityCnt(classList)    bestFeat = chooseBestFeatureToSplit(dataSet)    # 选取的最好特征    bestFeatLabel = labels[bestFeat]    # 最好特征标签    myTree = {bestFeatLabel:{ }}    # 使用字典存储树信息    # 得到列表包含的所有属性值    del(labels[bestFeat])    featValues = [example[bestFeat] for example in dataSet]    uniqueVals = set(featValues)    for value in uniqueVals:        subLabels = labels[:]   # 因为下一步传参数时是引用传参        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)    return myTreedef getNumLeafs(myTree):    numLeafs = 0    firstStr = list(myTree.keys())[0]    secondDict = myTree[firstStr]    for key in secondDict.keys():        # 测试节点的数据类型是否为字典        if type(secondDict[key]).__name__ == 'dict':            numLeafs += getNumLeafs(secondDict[key])        else:            numLeafs += 1    return numLeafsdef getTreeDepth(myTree):    maxDepth = 0    firstStr = list(myTree.keys())[0]    secondDict = myTree[firstStr]    for key in secondDict.keys():        if type(secondDict[key]).__name__ == 'dict':            thisDepth = 1 + getTreeDepth(secondDict[key])        else:            thisDepth = 1        if thisDepth > maxDepth:            maxDepth = thisDepth    return maxDepthdef createDataSet():    # 创建数据集    dataSet = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]    labels = ['no surfacing', 'flippers']    return dataSet, labelsdef classify(inputTree, featLabels, testVec):   # 分类器    """    :param inputTree: 树,即数据集    :param featLabels: 特征标签    :param testVec: 待测向量    :return: 类别    """    firstStr = list(inputTree.keys())[0]    # 将标签字符串转换为索引    secondDict = inputTree[firstStr]    featIndex = featLabels.index(firstStr)    for key in secondDict.keys():        if testVec[featIndex] == key:            if type(secondDict[key]).__name__ == 'dict':                classLabel = classify(secondDict[key], featLabels, testVec)     # 若未到叶子节点,则继续往下递归,直到叶子节点            else:                classLabel = secondDict[key]        # 如果已到叶子节点, 则直接取dict当前key的value    return classLabeldef storeTree(inputTree, filename):     # 序列化保存树(分类信息)    import pickle    fw = open(filename, 'wb+')    pickle.dump(inputTree, fw)    fw.close()def grabTree(filename):     # 读取序列化文件    import pickle    fr = open(filename, "rb+")    return pickle.load(fr)if __name__ == "__main__":    myDat, labels = createDataSet()    # myTree = createTree(myDat, labels)    # print(myTree)    print(myDat)    myTree = treePlotter.retrieveTree(0)    print(myTree)    print(classify(myTree, labels, [1, 0]))    print(classify(myTree, labels, [1, 1]))    print("===========store tree============")    storeTree(myTree, 'classifierStorafe.txt')    print(grabTree('classifierStorafe.txt'))

treePlotter

import matplotlib.pyplot as pltimport trees# 定义文本框和箭头格式decisionNode = dict(boxstyle="sawtooth", fc="0.8")leafNode = dict(boxstyle="round4", fc="0.8")    # 设置箭头的样式和背景色arrow_args = dict(arrowstyle="<-")  # 设置箭头的样式def plotNode(nodeTxt, centerPt, parentPt, nodeType): # 绘制带箭头的注解    createPlot.axl.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',                            xytext=centerPt, textcoords='axes fraction',                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)def createPlot(inTree):    fig = plt.figure(1, facecolor='white')  # 设置背景色    fig.clf()   # 清空画布    axprops = dict(xticks=[], yticks=[])    createPlot.axl = plt.subplot(111, frameon=False, **axprops) #表示图中有1行1列,绘图放在第几列, 有无边框    plotTree.totalW = float(trees.getNumLeafs(inTree))    plotTree.totalD = float(trees.getTreeDepth(inTree))    plotTree.xOff = -0.5/plotTree.totalW    plotTree.yOff = 1.0    plotTree(inTree, (0.5, 1.0), ' ')    # plotNode('a decision node', (0.5, 0.5), (0.1, 0.5), decisionNode)   # 第一个坐标是注解的坐标 第二个坐标是点的坐标    # plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)    plt.show()def plotMidText(cntrPt, parentPt, txtString): # 在父子节点间填充文本信息    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]    createPlot.axl.text(xMid, yMid, txtString)def plotTree(myTree, parentPt, nodeTxt):    #计算宽与高    numLeafs = trees.getNumLeafs(myTree)    depth = trees.getTreeDepth(myTree)    firstStr = list(myTree.keys())[0]   # 根节点    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)    # 标记子节点属性值    plotMidText(cntrPt, parentPt, nodeTxt)    plotNode(firstStr, cntrPt, parentPt, decisionNode)    secondDict = myTree[firstStr]    # 减少y偏移    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD    for key in secondDict.keys():        if type(secondDict[key]).__name__ == 'dict':            plotTree(secondDict[key], cntrPt, str(key))        else:            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalDdef retrieveTree(i): # 创建树    listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},                  {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}]    return listOfTrees[i]if __name__ == "__main__":    # reTree = retrieveTree(1)    # leafs = trees.getNumLeafs(reTree)    # depth = trees.getTreeDepth(reTree)    # print(reTree)    # print(leafs)    # print(depth)    myTree = retrieveTree(0)    myTree['no surfacing'][3] = 'maybe'    createPlot(myTree)
阅读全文
'); })();
0 0
原创粉丝点击
热门IT博客
热门问题 larrycms sp 完美世界 MazePackage php woainixss\"0 Androidlog cca cookie ecshop saleae.logic PackageControl 彩票挂机软件 SAP工资 mvr蒸发器设计软件 uva325 QQ QQ密码 552893585486fgh Fgh TQ2440烧写bin TQ2440裸机 TQ2440led TQ2440led TQ2440点灯 stitching介绍 pythonstitching stitchingpython e-mailenglisg e-mailenglish c语言星行三角 c语言星型三角 大掌门 vfp获取天气 TwoPermutations 发布订阅 发布订阅模式 eeee 大富翁python Sicily guoshanche Sicily过山车 SALESFORCE D Avengerendgame Avengerendgame 水晶報表 Granfurniture Granfurniture Granfurniture brokenimg