决策树剪枝算法的python实现方法详解
本文实例讲述了决策树剪枝算法的python实现方法。分享给大家供大家参考,具体如下:
决策树是一种依托决策而建立起来的一种树。在机器学习中,决策树是一种预测模型,代表的是一种对象属性与对象值之间的一种映射关系,每一个节点代表某个对象,树中的每一个分叉路径代表某个可能的属性值,而每一个叶子节点则对应从根节点到该叶子节点所经历的路径所表示的对象的值。决策树仅有单一输出,如果有多个输出,可以分别建立独立的决策树以处理不同的输出。
ID3算法:ID3算法是决策树的一种,是基于奥卡姆剃刀原理的,即用尽量用较少的东西做更多的事。ID3算法,即Iterative Dichotomiser 3,迭代二叉树3代,是Ross Quinlan发明的一种决策树算法,这个算法的基础就是上面提到的奥卡姆剃刀原理,越是小型的决策树越优于大的决策树,尽管如此,也不总是生成最小的树型结构,而是一个启发式算法。在信息论中,期望信息越小,那么信息增益就越大,从而纯度就越高。ID3算法的核心思想就是以信息增益来度量属性的选择,选择分裂后信息增益最大的属性进行分裂。该算法采用自顶向下的贪婪搜索遍历可能的决策空间。
信息熵,将其定义为离散随机事件出现的概率,一个系统越是有序,信息熵就越低,反之一个系统越是混乱,它的信息熵就越高。所以信息熵可以被认为是系统有序化程度的一个度量。
基尼指数:在CART里面划分决策树的条件是采用Gini Index,定义如下:gini(T)=1"htmlcode">
def calcShannonEnt(dataSet): numEntries = len(dataSet) labelCounts = {} # 给所有可能分类创建字典 for featVec in dataSet: currentLabel = featVec[-1] if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0 labelCounts[currentLabel] += 1 shannonEnt = 0.0 # 以2为底数计算香农熵 for key in labelCounts: prob = float(labelCounts[key]) / numEntries shannonEnt -= prob * log(prob, 2) return shannonEnt
# 对离散变量划分数据集,取出该特征取值为value的所有样本 def splitDataSet(dataSet, axis, value): retDataSet = [] for featVec in dataSet: if featVec[axis] == value: reducedFeatVec = featVec[:axis] reducedFeatVec.extend(featVec[axis + 1:]) retDataSet.append(reducedFeatVec) return retDataSet
对连续变量划分数据集,direction规定划分的方向, 决定是划分出小于value的数据样本还是大于value的数据样本集
numFeatures = len(dataSet[0]) - 1 baseEntropy = calcShannonEnt(dataSet) bestInfoGain = 0.0 bestFeature = -1 bestSplitDict = {} for i in range(numFeatures): featList = [example[i] for example in dataSet] # 对连续型特征进行处理 if type(featList[0]).__name__ == 'float' or type(featList[0]).__name__ == 'int': # 产生n-1个候选划分点 sortfeatList = sorted(featList) splitList = [] for j in range(len(sortfeatList) - 1): splitList.append((sortfeatList[j] + sortfeatList[j + 1]) / 2.0) bestSplitEntropy = 10000 slen = len(splitList) # 求用第j个候选划分点划分时,得到的信息熵,并记录最佳划分点 for j in range(slen): value = splitList[j] newEntropy = 0.0 subDataSet0 = splitContinuousDataSet(dataSet, i, value, 0) subDataSet1 = splitContinuousDataSet(dataSet, i, value, 1) prob0 = len(subDataSet0) / float(len(dataSet)) newEntropy += prob0 * calcShannonEnt(subDataSet0) prob1 = len(subDataSet1) / float(len(dataSet)) newEntropy += prob1 * calcShannonEnt(subDataSet1) if newEntropy < bestSplitEntropy: bestSplitEntropy = newEntropy bestSplit = j # 用字典记录当前特征的最佳划分点 bestSplitDict[labels[i]] = splitList[bestSplit] infoGain = baseEntropy - bestSplitEntropy # 对离散型特征进行处理 else: uniqueVals = set(featList) newEntropy = 0.0 # 计算该特征下每种划分的信息熵 for value in uniqueVals: subDataSet = splitDataSet(dataSet, i, value) prob = len(subDataSet) / float(len(dataSet)) newEntropy += prob * calcShannonEnt(subDataSet) infoGain = baseEntropy - newEntropy if infoGain > bestInfoGain: bestInfoGain = infoGain bestFeature = i # 若当前节点的最佳划分特征为连续特征,则将其以之前记录的划分点为界进行二值化处理 # 即是否小于等于bestSplitValue if type(dataSet[0][bestFeature]).__name__ == 'float' or type(dataSet[0][bestFeature]).__name__ == 'int': bestSplitValue = bestSplitDict[labels[bestFeature]] labels[bestFeature] = labels[bestFeature] + '<=' + str(bestSplitValue) for i in range(shape(dataSet)[0]): if dataSet[i][bestFeature] <= bestSplitValue: dataSet[i][bestFeature] = 1 else: dataSet[i][bestFeature] = 0 return bestFeature
def chooseBestFeatureToSplit(dataSet, labels): numFeatures = len(dataSet[0]) - 1 baseEntropy = calcShannonEnt(dataSet) bestInfoGain = 0.0 bestFeature = -1 bestSplitDict = {} for i in range(numFeatures): featList = [example[i] for example in dataSet] # 对连续型特征进行处理 if type(featList[0]).__name__ == 'float' or type(featList[0]).__name__ == 'int': # 产生n-1个候选划分点 sortfeatList = sorted(featList) splitList = [] for j in range(len(sortfeatList) - 1): splitList.append((sortfeatList[j] + sortfeatList[j + 1]) / 2.0) bestSplitEntropy = 10000 slen = len(splitList) # 求用第j个候选划分点划分时,得到的信息熵,并记录最佳划分点 for j in range(slen): value = splitList[j] newEntropy = 0.0 subDataSet0 = splitContinuousDataSet(dataSet, i, value, 0) subDataSet1 = splitContinuousDataSet(dataSet, i, value, 1) prob0 = len(subDataSet0) / float(len(dataSet)) newEntropy += prob0 * calcShannonEnt(subDataSet0) prob1 = len(subDataSet1) / float(len(dataSet)) newEntropy += prob1 * calcShannonEnt(subDataSet1) if newEntropy < bestSplitEntropy: bestSplitEntropy = newEntropy bestSplit = j # 用字典记录当前特征的最佳划分点 bestSplitDict[labels[i]] = splitList[bestSplit] infoGain = baseEntropy - bestSplitEntropy # 对离散型特征进行处理 else: uniqueVals = set(featList) newEntropy = 0.0 # 计算该特征下每种划分的信息熵 for value in uniqueVals: subDataSet = splitDataSet(dataSet, i, value) prob = len(subDataSet) / float(len(dataSet)) newEntropy += prob * calcShannonEnt(subDataSet) infoGain = baseEntropy - newEntropy if infoGain > bestInfoGain: bestInfoGain = infoGain bestFeature = i # 若当前节点的最佳划分特征为连续特征,则将其以之前记录的划分点为界进行二值化处理 # 即是否小于等于bestSplitValue if type(dataSet[0][bestFeature]).__name__ == 'float' or type(dataSet[0][bestFeature]).__name__ == 'int': bestSplitValue = bestSplitDict[labels[bestFeature]] labels[bestFeature] = labels[bestFeature] + '<=' + str(bestSplitValue) for i in range(shape(dataSet)[0]): if dataSet[i][bestFeature] <= bestSplitValue: dataSet[i][bestFeature] = 1 else: dataSet[i][bestFeature] = 0 return bestFeature ``def classify(inputTree, featLabels, testVec): firstStr = inputTree.keys()[0] if u'<=' in firstStr: featvalue = float(firstStr.split(u"<=")[1]) featkey = firstStr.split(u"<=")[0] secondDict = inputTree[firstStr] featIndex = featLabels.index(featkey) if testVec[featIndex] <= featvalue: judge = 1 else: judge = 0 for key in secondDict.keys(): if judge == int(key): if type(secondDict[key]).__name__ == 'dict': classLabel = classify(secondDict[key], featLabels, testVec) else: classLabel = secondDict[key] else: secondDict = inputTree[firstStr] featIndex = featLabels.index(firstStr) for key in secondDict.keys(): if testVec[featIndex] == key: if type(secondDict[key]).__name__ == 'dict': classLabel = classify(secondDict[key], featLabels, testVec) else: classLabel = secondDict[key] return classLabel
def majorityCnt(classList): classCount={} for vote in classList: if vote not in classCount.keys(): classCount[vote]=0 classCount[vote]+=1 return max(classCount) def testing_feat(feat, train_data, test_data, labels): class_list = [example[-1] for example in train_data] bestFeatIndex = labels.index(feat) train_data = [example[bestFeatIndex] for example in train_data] test_data = [(example[bestFeatIndex], example[-1]) for example in test_data] all_feat = set(train_data) error = 0.0 for value in all_feat: class_feat = [class_list[i] for i in range(len(class_list)) if train_data[i] == value] major = majorityCnt(class_feat) for data in test_data: if data[0] == value and data[1] != major: error += 1.0 # print 'myTree %d' % error return error
测试
error = 0.0 for i in range(len(data_test)): if classify(myTree, labels, data_test[i]) != data_test[i][-1]: error += 1 # print 'myTree %d' % error return float(error) def testingMajor(major, data_test): error = 0.0 for i in range(len(data_test)): if major != data_test[i][-1]: error += 1 # print 'major %d' % error return float(error) **递归产生决策树** ```def createTree(dataSet,labels,data_full,labels_full,test_data,mode): classList=[example[-1] for example in dataSet] if classList.count(classList[0])==len(classList): return classList[0] if len(dataSet[0])==1: return majorityCnt(classList) labels_copy = copy.deepcopy(labels) bestFeat=chooseBestFeatureToSplit(dataSet,labels) bestFeatLabel=labels[bestFeat] if mode == "unpro" or mode == "post": myTree = {bestFeatLabel: {}} elif mode == "prev": if testing_feat(bestFeatLabel, dataSet, test_data, labels_copy) < testingMajor(majorityCnt(classList), test_data): myTree = {bestFeatLabel: {}} else: return majorityCnt(classList) featValues=[example[bestFeat] for example in dataSet] uniqueVals=set(featValues) if type(dataSet[0][bestFeat]).__name__ == 'unicode': currentlabel = labels_full.index(labels[bestFeat]) featValuesFull = [example[currentlabel] for example in data_full] uniqueValsFull = set(featValuesFull) del (labels[bestFeat]) for value in uniqueVals: subLabels = labels[:] if type(dataSet[0][bestFeat]).__name__ == 'unicode': uniqueValsFull.remove(value) myTree[bestFeatLabel][value] = createTree(splitDataSet (dataSet, bestFeat, value), subLabels, data_full, labels_full, splitDataSet (test_data, bestFeat, value), mode=mode) if type(dataSet[0][bestFeat]).__name__ == 'unicode': for value in uniqueValsFull: myTree[bestFeatLabel][value] = majorityCnt(classList) if mode == "post": if testing(myTree, test_data, labels_copy) > testingMajor(majorityCnt(classList), test_data): return majorityCnt(classList) return myTree <div class="se-preview-section-delimiter"></div> ```**读入数据** ```def load_data(file_name): with open(r"dd.csv", 'rb') as f: df = pd.read_csv(f,sep=",") print(df) train_data = df.values[:11, 1:].tolist() print(train_data) test_data = df.values[11:, 1:].tolist() labels = df.columns.values[1:-1].tolist() return train_data, test_data, labels <div class="se-preview-section-delimiter"></div> ```测试并绘制树图 import matplotlib.pyplot as plt decisionNode = dict(boxstyle="round4", color='red') # 定义判断结点形态 leafNode = dict(boxstyle="circle", color='grey') # 定义叶结点形态 arrow_args = dict(arrowstyle="<-", color='blue') # 定义箭头 # 计算树的叶子节点数量 def getNumLeafs(myTree): numLeafs = 0 firstSides = list(myTree.keys()) firstStr = firstSides[0] secondDict = myTree[firstStr] for key in secondDict.keys(): if type(secondDict[key]).__name__ == 'dict': numLeafs += getNumLeafs(secondDict[key]) else: numLeafs += 1 return numLeafs # 计算树的最大深度 def getTreeDepth(myTree): maxDepth = 0 firstSides = list(myTree.keys()) firstStr = firstSides[0] secondDict = myTree[firstStr] for key in secondDict.keys(): if type(secondDict[key]).__name__ == 'dict': thisDepth = 1 + getTreeDepth(secondDict[key]) else: thisDepth = 1 if thisDepth > maxDepth: maxDepth = thisDepth return maxDepth # 画节点 def plotNode(nodeTxt, centerPt, parentPt, nodeType): createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', xytext=centerPt, textcoords='axes fraction', va="center", ha="center", bbox=nodeType, arrowprops=arrow_args) # 画箭头上的文字 def plotMidText(cntrPt, parentPt, txtString): lens = len(txtString) xMid = (parentPt[0] + cntrPt[0]) / 2.0 - lens * 0.002 yMid = (parentPt[1] + cntrPt[1]) / 2.0 createPlot.ax1.text(xMid, yMid, txtString) def plotTree(myTree, parentPt, nodeTxt): numLeafs = getNumLeafs(myTree) depth = getTreeDepth(myTree) firstSides = list(myTree.keys()) firstStr = firstSides[0] cntrPt = (plotTree.x0ff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.y0ff) plotMidText(cntrPt, parentPt, nodeTxt) plotNode(firstStr, cntrPt, parentPt, decisionNode) secondDict = myTree[firstStr] plotTree.y0ff = plotTree.y0ff - 1.0 / plotTree.totalD for key in secondDict.keys(): if type(secondDict[key]).__name__ == 'dict': plotTree(secondDict[key], cntrPt, str(key)) else: plotTree.x0ff = plotTree.x0ff + 1.0 / plotTree.totalW plotNode(secondDict[key], (plotTree.x0ff, plotTree.y0ff), cntrPt, leafNode) plotMidText((plotTree.x0ff, plotTree.y0ff), cntrPt, str(key)) plotTree.y0ff = plotTree.y0ff + 1.0 / plotTree.totalD def createPlot(inTree): fig = plt.figure(1, facecolor='white') fig.clf() axprops = dict(xticks=[], yticks=[]) createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) plotTree.totalW = float(getNumLeafs(inTree)) plotTree.totalD = float(getTreeDepth(inTree)) plotTree.x0ff = -0.5 / plotTree.totalW plotTree.y0ff = 1.0 plotTree(inTree, (0.5, 1.0), '') plt.show()
if __name__ == "__main__": train_data, test_data, labels = load_data("dd.csv") data_full = train_data[:] labels_full = labels[:] mode="post" mode = "prev" mode="post" myTree = createTree(train_data, labels, data_full, labels_full, test_data, mode=mode) createPlot(myTree) print(json.dumps(myTree, ensure_ascii=False, indent=4))
选择mode就可以分别得到三种树图
if __name__ == "__main__": train_data, test_data, labels = load_data("dd.csv") data_full = train_data[:] labels_full = labels[:] mode="post" mode = "prev" mode="post" myTree = createTree(train_data, labels, data_full, labels_full, test_data, mode=mode) createPlot(myTree) print(json.dumps(myTree, ensure_ascii=False, indent=4))
更多关于Python相关内容感兴趣的读者可查看本站专题:《Python数据结构与算法教程》、《Python加密解密算法与技巧总结》、《Python编码操作技巧总结》、《Python函数使用技巧总结》、《Python字符串操作技巧汇总》及《Python入门与进阶经典教程》
希望本文所述对大家Python程序设计有所帮助。
下一篇:python连接PostgreSQL数据库的过程详解