轉載,原文地址:http://www.cnblogs.com/fantasy01/p/4595902.htmlhtml
在看機器學習實戰時候,到第三章的對決策樹畫圖的時候,有一段遞歸函數怎麼都看不懂,由於之後想選這個方向爲本身的職業導向,抱着精看的態度,對這本樹進行地毯式掃描,因此就沒跳過,一直卡了一天多,才差很少搞懂,纔對那個函數中的plotTree.xOff的取值,以及計算cntrPt的方法搞懂,相信也有人和我同樣,但願可以相互交流。node
先把代碼貼在這裏:算法
1 import matplotlib.pyplot as plt 2
3 #這裏是對繪製是圖形屬性的一些定義,能夠不用管,主要是後面的算法
4 decisionNode = dict(boxstyle="sawtooth", fc="0.8") 5 leafNode = dict(boxstyle="round4", fc="0.8") 6 arrow_args = dict(arrowstyle="<-") 7
8 #這是遞歸計算樹的葉子節點個數,比較簡單
9 def getNumLeafs(myTree): 10 numLeafs = 0 11 firstStr = myTree.keys()[0] 12 secondDict = myTree[firstStr] 13 for key in secondDict.keys(): 14 if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes
15 numLeafs += getNumLeafs(secondDict[key]) 16 else: numLeafs +=1
17 return numLeafs 18
19 #這是遞歸計算樹的深度,比較簡單
20 def getTreeDepth(myTree): 21 maxDepth = 0 22 firstStr = myTree.keys()[0] 23 secondDict = myTree[firstStr] 24 for key in secondDict.keys(): 25 if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes
26 thisDepth = 1 + getTreeDepth(secondDict[key]) 27 else: thisDepth = 1
28 if thisDepth > maxDepth: maxDepth = thisDepth 29 return maxDepth 30
31 #這個是用來一註釋形式繪製節點和箭頭線,能夠不用管
32 def plotNode(nodeTxt, centerPt, parentPt, nodeType): 33 createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', 34 xytext=centerPt, textcoords='axes fraction', 35 va="center", ha="center", bbox=nodeType, arrowprops=arrow_args ) 36
37 #這個是用來繪製線上的標註,簡單
38 def plotMidText(cntrPt, parentPt, txtString): 39 xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0] 40 yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1] 41 createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30) 42
43 #重點,遞歸,決定整個樹圖的繪製,難(本身認爲)
44 def plotTree(myTree, parentPt, nodeTxt):#if the first key tells you what feat was split on
45 numLeafs = getNumLeafs(myTree) #this determines the x width of this tree
46 depth = getTreeDepth(myTree) 47 firstStr = myTree.keys()[0] #the text label for this node should be this
48 cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff) 49
50 plotMidText(cntrPt, parentPt, nodeTxt) 51 plotNode(firstStr, cntrPt, parentPt, decisionNode) 52 secondDict = myTree[firstStr] 53 plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD 54 for key in secondDict.keys(): 55 if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes
56 plotTree(secondDict[key],cntrPt,str(key)) #recursion
57 else: #it's a leaf node print the leaf node
58 plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW 59 plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode) 60 plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key)) 61 plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD 62 #if you do get a dictonary you know it's a tree, and the first element will be another dict
63
64 #這個是真正的繪製,上邊是邏輯的繪製
65 def createPlot(inTree): 66 fig = plt.figure(1, facecolor='white') 67 fig.clf() 68 axprops = dict(xticks=[], yticks=[]) 69 createPlot.ax1 = plt.subplot(111, frameon=False) #no ticks
70 plotTree.totalW = float(getNumLeafs(inTree)) 71 plotTree.totalD = float(getTreeDepth(inTree)) 72 plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0; 73 plotTree(inTree, (0.5,1.0), '') 74 plt.show() 75
76 #這個是用來建立數據集即決策樹
77 def retrieveTree(i): 78 listOfTrees =[{'no surfacing': {0:{'flippers': {0: 'no', 1: 'yes'}}, 1: {'flippers': {0: 'no', 1: 'yes'}}, 2:{'flippers': {0: 'no', 1: 'yes'}}}}, 79 {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}} 80 ] 81 return listOfTrees[i] 82
83 createPlot(retrieveTree(0))
繪製出來的圖形以下:機器學習
先導:這裏說一下爲何說一個遞歸樹的繪製爲何會是很難懂,這裏不就是利用遞歸函數來繪圖麼,就如遞歸計算樹的深度、葉子節點同樣,問題不是遞歸的思路,而是這本書中一些座標的起始取值、以及在計算節點座標所做的處理,並且在樹中對這部分並無取講述,因此在看這段代碼的時候可能大致思路明白可是具體的細節卻知之甚少,因此本篇主要是對其中書中說起甚少的做詳細的講述,固然代碼的總體思路也不會放過的函數
準備:這裏說一下具體繪製的時候是利用自定義plotNode函數來繪製,這個函數一次繪製的是一個箭頭和一個節點,以下圖:學習
思路:這裏繪圖,做者選取了一個很聰明的方式,並不會由於樹的節點的增減和深度的增減而致使繪製出來的圖形出現問題,固然不能太密集。這裏利用整棵樹的葉子節點數做爲份數將整個x軸的長度進行平均切分,利用樹的深度做爲份數將y軸長度做平均切分,並利用plotTree.xOff做爲最近繪製的一個葉子節點的x座標,當再一次繪製葉子節點座標的時候纔會plotTree.xOff纔會發生改變;用plotTree.yOff做爲當前繪製的深度,plotTree.yOff是在每遞歸一層就會減一份(上邊所說的按份平均切分),其餘時候是利用這兩個座標點去計算非葉子節點,這兩個參數其實就能夠肯定一個點座標,這個座標肯定的時候就是繪製節點的時候this
總體算法的遞歸思路卻是很容易理解:spa
每一次都分三個步驟:code
(1)繪製自身htm
(2)判斷子節點非葉子節點,遞歸
(3)判斷子節點爲葉子節點,繪製
詳細解析:
1 def plotTree(myTree, parentPt, nodeTxt):#if the first key tells you what feat was split on
2 numLeafs = getNumLeafs(myTree) #this determines the x width of this tree
3 depth = getTreeDepth(myTree) 4 firstStr = myTree.keys()[0] #the text label for this node should be this
5 cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff) 6
7 plotMidText(cntrPt, parentPt, nodeTxt) 8 plotNode(firstStr, cntrPt, parentPt, decisionNode) 9 secondDict = myTree[firstStr] 10 plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD 11 for key in secondDict.keys(): 12 if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes
13 plotTree(secondDict[key],cntrPt,str(key)) #recursion
14 else: #it's a leaf node print the leaf node
15 plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW 16 plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode) 17 plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key)) 18 plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD 19 #if you do get a dictonary you know it's a tree, and the first element will be another dict
20
21 def createPlot(inTree): 22 fig = plt.figure(1, facecolor='white') 23 fig.clf() 24 axprops = dict(xticks=[], yticks=[]) 25 createPlot.ax1 = plt.subplot(111, frameon=False) #no ticks
26 plotTree.totalW = float(getNumLeafs(inTree)) 27 plotTree.totalD = float(getTreeDepth(inTree)) 28 plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;#totalW爲整樹的葉子節點樹,totalD爲深度
29 plotTree(inTree, (0.5,1.0), '') 30 plt.show()
上邊代碼中紅色部分如此處理原理:
首先因爲整個畫布根據葉子節點數和深度進行平均切分,而且x軸的總長度爲1,即如同下圖:
一、其中方形爲非葉子節點的位置,@是葉子節點的位置,所以每份即上圖的一個表格的長度應該爲1/plotTree.totalW,可是葉子節點的位置應該爲@所在位置,則在開始的時候plotTree.xOff的賦值爲-0.5/plotTree.totalW,即意爲開始x位置爲第一個表格左邊的半個表格距離位置,這樣做的好處爲:在之後肯定@位置時候能夠直接加整數倍的1/plotTree.totalW,
二、對於plotTree函數中的紅色部分即以下:
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
plotTree.xOff即爲最近繪製的一個葉子節點的x座標,在肯定當前節點位置時每次只需肯定當前節點有幾個葉子節點,所以其葉子節點所佔的總距離就肯定了即爲float(numLeafs)/plotTree.totalW*1(由於總長度爲1),所以當前節點的位置即爲其全部葉子節點所佔距離的中間即一半爲float(numLeafs)/2.0/plotTree.totalW*1,可是因爲開始plotTree.xOff賦值並不是從0開始,而是左移了半個表格,所以還需加上半個表格距離即爲1/2/plotTree.totalW*1,則加起來便爲(1.0 + float(numLeafs))/2.0/plotTree.totalW*1,所以偏移量肯定,則x位置變爲plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW
三、對於plotTree函數參數賦值爲(0.5, 1.0)
由於開始的根節點並不用劃線,所以父節點和當前節點的位置須要重合,利用2中的肯定當前節點的位置便爲(0.5, 1.0)
總結:利用這樣的逐漸增長x的座標,以及逐漸下降y的座標能可以很好的將樹的葉子節點數和深度考慮進去,所以圖的邏輯比例就很好的肯定了,這樣不用去關心輸出圖形的大小,一旦圖形發生變化,函數會從新繪製,可是假如利用像素爲單位來繪製圖形,這樣縮放圖形就比較有難度了