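When constructing a decision tree, the first problem we need to solve is which feature of the current dataset plays the decisive role in dividing the data into classes. To find that decisive feature and obtain the best split, we must evaluate every feature. After this test, the original dataset is divided into several subsets, which are distributed across all branches of the first decision point. If the data under a branch all belong to the same class, that branch needs no further splitting. If the data in a subset do not all belong to the same class, the subset must be split again. Subsets are split with the same algorithm used on the original dataset, until all data of the same class end up in the same subset.
The pseudocode for the branch-creating function createBranch() is as follows: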
Check if every item in the dataset is in the same class:
    If so return the class label
    Else
        find the best feature to split the data
        split the dataset
        create a branch node
        for each split
            call createBranch and add the result to the branch node
        return branch node
Training set: $D=\{(x_1,y_1), (x_2,y_2), \ldots, (x_m, y_m)\}$
Attribute set: $A=\{a_1, a_2, \ldots, a_d\}$
def createDataSet():
    # Each row: [no surfacing, flippers, class label]
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    feature = ['no surfacing', 'flippers']
    return dataSet, feature
$a_{\star} = \arg\max\limits_{a \in A} Gain(D, a)$
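The guiding principle for splitting a dataset is to make disordered data more ordered. The change in information before and after a split is called information gain. We compute the information gain obtained by splitting the dataset on each feature, and the feature with the highest information gain is the best one to split on. The measure of information in a set is called entropy. Suppose the proportion of samples of class $k$ in the current sample set $D$ is $p_k \ (k=1,2,\ldots,\lvert y \rvert)$; then the information entropy of $D$ is: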
$Ent(D) = -\sum_{k=1}^{\lvert y \rvert} p_k \log_2 p_k$
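As a quick check of the entropy formula, here is a minimal sketch (Python 3) applied to the class labels of the five-sample toy dataset from createDataSet(). The helper name `entropy` is ours for illustration and is not part of the original code.

```python
from collections import Counter
from math import log2


def entropy(labels):
    """Shannon entropy (in bits) of a list of class labels."""
    total = len(labels)
    counts = Counter(labels)
    return -sum((n / total) * log2(n / total) for n in counts.values())


# Class labels of the toy dataset: two 'yes' and three 'no'.
labels = ['yes', 'yes', 'no', 'no', 'no']
toy_entropy = entropy(labels)
print(round(toy_entropy, 4))  # 0.971

# Two reference points: a pure set and an evenly mixed set.
pure_entropy = entropy(['no', 'no'])
even_entropy = entropy(['yes', 'no'])
```

A pure set has entropy 0 and an evenly mixed two-class set has entropy of exactly 1 bit, so the toy dataset at about 0.971 is close to maximally mixed.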
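Splitting $D$ on an attribute $a$ with $V$ possible values produces branch nodes $D_v$ $(v=1,\ldots,V)$, and we can compute the information entropy of each $D_v$. Because the branch nodes contain different numbers of samples, each branch is given the weight $\frac{\lvert D_v \rvert}{\lvert D \rvert}$, which yields the information gain: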
$Gain(D,a) = Ent(D) - \sum_{v=1}^{V} \frac{\lvert D_v \rvert}{\lvert D \rvert} Ent(D_v)$
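To make the weighting concrete, the following sketch (Python 3) evaluates the gain of both features of the toy dataset. The names `entropy` and `info_gain` are illustrative helpers of our own, not the functions used later in this post.

```python
from collections import Counter, defaultdict
from math import log2


def entropy(labels):
    """Shannon entropy (in bits) of a list of class labels."""
    total = len(labels)
    return -sum((n / total) * log2(n / total) for n in Counter(labels).values())


def info_gain(rows, feature_index):
    """Entropy of the whole set minus the size-weighted entropy of each subset."""
    labels = [row[-1] for row in rows]
    subsets = defaultdict(list)
    for row in rows:
        subsets[row[feature_index]].append(row[-1])  # group labels by feature value
    weighted = sum(len(s) / len(rows) * entropy(s) for s in subsets.values())
    return entropy(labels) - weighted


rows = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
gains = [info_gain(rows, i) for i in range(2)]
best = max(range(2), key=lambda i: gains[i])
print([round(g, 3) for g in gains], best)  # [0.42, 0.171] 0
```

Feature 0 ('no surfacing') has the larger gain, which matches the root of the tree built later in this post.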
# 1. Compute the information entropy
from math import log

def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]  # class label of the current sample
        if currentLabel not in labelCounts:  # first time this label is seen
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries  # class probability
        shannonEnt -= prob * log(prob, 2)  # accumulate the entropy
    return shannonEnt
# 2. Split the dataset. Arguments: dataset, index of the feature to split on, feature value of the child node
def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])  # extend adds the elements one by one
            retDataSet.append(reducedFeatVec)  # append adds the list as a single element
    return retDataSet
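The difference between extend and append is easy to get wrong, so here is a small sketch of both, together with a compact equivalent of the split. The function name `split_rows` is our own and only mirrors the behaviour described above.

```python
def split_rows(rows, axis, value):
    """Keep rows whose feature `axis` equals `value`, with that column removed."""
    return [row[:axis] + row[axis + 1:] for row in rows if row[axis] == value]


rows = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
kept = split_rows(rows, 0, 1)
print(kept)  # [[1, 'yes'], [1, 'yes'], [0, 'no']]

appended = [1, 2]
appended.append([3, 4])  # the whole list becomes one nested element
extended = [1, 2]
extended.extend([3, 4])  # the elements are added individually
print(appended, extended)  # [1, 2, [3, 4]] [1, 2, 3, 4]
```

Because the rows are rebuilt with slicing, the input dataset is left unmodified, which matters when the same dataset is split repeatedly on different features.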
# 3. Choose the best way to split the dataset
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1  # last column is the class label
    baseEntropy = calcShannonEnt(dataSet)  # entropy of the unsplit node
    bestInfoGain = 0.0  # best information gain found so far
    bestFeature = -1  # index of the best feature (-1 if no split improves on the base)
    for i in range(numFeatures):  # compute the information gain of every feature
        featList = [example[i] for example in dataSet]  # this feature's value in every sample
        uniqueVals = set(featList)  # distinct values of this feature
        newEntropy = 0.0
        for value in uniqueVals:  # one subset per value, weighted by its share of the samples
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
# 4. Decide the class of a leaf node by majority vote
import operator

def majorityCnt(classList):
    classCount = {}  # vote count per class
    for vote in classList:
        if vote not in classCount:
            classCount[vote] = 0
        classCount[vote] += 1
    # items() works on both Python 2 and 3 (iteritems() exists only on Python 2)
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
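The same majority vote can be written more compactly with the standard library. This is only an alternative sketch (Python 3); `majority_label` is a name we introduce here, not a function from the original code.

```python
from collections import Counter


def majority_label(class_list):
    """Return the most frequent class label in the list."""
    return Counter(class_list).most_common(1)[0][0]


winner = majority_label(['yes', 'no', 'no'])
print(winner)  # no
```

When two classes are tied, the result depends on how ties are ordered, so neither version should be relied on to break ties in a meaningful way.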
# 5. Build the decision tree: takes the dataset and the feature names, returns the model as a nested dict
def createTree(dataSet, features):
    classList = [example[-1] for example in dataSet]  # class labels of all current samples
    if classList.count(classList[0]) == len(classList):  # all samples share one class: stop and return it
        return classList[0]
    if len(dataSet[0]) == 1:  # all features used up: stop and return the majority class
        return majorityCnt(classList)
    bestFeatureIndex = chooseBestFeatureToSplit(dataSet)  # pick the best feature to split on
    bestFeature = features[bestFeatureIndex]
    myTree = {bestFeature: {}}  # key is the best feature, value is a dict of its branches
    featValues = [example[bestFeatureIndex] for example in dataSet]  # this feature's value in every sample
    uniqueVals = set(featValues)  # distinct values, one subset per value
    for value in uniqueVals:
        subFeatures = features[:]  # copy so the caller's features list is not modified
        del subFeatures[bestFeatureIndex]  # drop the feature that was just used
        # Recurse on the subset; a class label comes back when no further split is possible
        myTree[bestFeature][value] = createTree(splitDataSet(dataSet, bestFeatureIndex, value), subFeatures)
    return myTree
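Next, the program builds the tree. Here Python stores the tree in a dictionary. The dictionary variable holds all of the tree's information, which is important for drawing the tree diagram later. In Python, a list passed as a function argument refers to the same object as the caller's list, so modifying it in place inside the function also changes the original. To leave the original list of feature labels unchanged, the program copies it with subFeatures = features[:].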
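The effect of the slice copy can be seen in isolation with a short sketch. Both helper functions below are hypothetical and exist only to contrast in-place deletion with deletion on a copy.

```python
def drop_in_place(labels, index):
    del labels[index]  # mutates the list object the caller passed in
    return labels


def drop_from_copy(labels, index):
    copied = labels[:]  # shallow copy made by slicing
    del copied[index]
    return copied


features = ['no surfacing', 'flippers']
remaining = drop_from_copy(features, 0)
print(features, remaining)  # the original list still has both entries

aliased = ['no surfacing', 'flippers']
drop_in_place(aliased, 0)
print(aliased)  # the caller's list has lost its first entry
```

A slice makes a shallow copy, which is sufficient here because the list holds only strings.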
myData, feature = createDataSet()
myTree = createTree(myData, feature)
print(myTree)

Output:
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
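The variable myTree contains nested dictionaries that represent the structure of the tree. Reading from the left, the first key, no surfacing, is the first feature the data is split on, and its value is another dictionary. The keys of that second dictionary are the values of the no surfacing feature, one per branch, and each value is either a class label or another dictionary. If the value is a class label, that child is a leaf node. If the value is another dictionary, that child is a decision node. This pattern repeats to form the whole tree.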
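One way to get a feel for this nested format is to walk it and list every root-to-leaf path as a rule. The sketch below (Python 3) assumes exactly the layout described above; `tree_rules` is an illustrative helper of our own.

```python
def tree_rules(tree, path=()):
    """Yield (conditions, label) for every leaf of a nested-dict tree."""
    feature = next(iter(tree))  # the single key at this level is the split feature
    for value, child in tree[feature].items():
        step = path + ((feature, value),)
        if isinstance(child, dict):  # decision node: keep descending
            yield from tree_rules(child, step)
        else:  # leaf node: child is the class label
            yield step, child


my_tree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
rules = list(tree_rules(my_tree))
for conditions, label in rules:
    print(' and '.join('%s=%s' % c for c in conditions), '->', label)
```

For the tree above this yields three rules, one per leaf, and only the path no surfacing=1, flippers=1 ends in 'yes'.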
# coding:utf-8
import matplotlib.pyplot as plt

decisionNode = dict(boxstyle="sawtooth", fc='0.8')  # box style of a decision node
leafNode = dict(boxstyle='round4', fc='0.8')  # box style of a leaf node
arrow_args = dict(arrowstyle="<-")  # arrow style

# Draw a node. Arguments: node text, node center coordinates, parent center coordinates, node style
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)

def createPlot():
    fig = plt.figure(1, facecolor='white')  # create a new figure
    fig.clf()  # clear the drawing area
    createPlot.ax1 = plt.subplot(111, frameon=False)
    plotNode('decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
    plotNode('leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
    plt.show()

createPlot()
# 2.1. Count the leaf nodes
def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]  # list() is needed on Python 3, where keys() is a view
    secondDict = myTree[firstStr]
    # A child whose value is a dict is a decision node; otherwise it is a leaf
    for key in secondDict.keys():
        if isinstance(secondDict[key], dict):
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs

# 2.2. Compute the depth of the tree
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    # A child whose value is a dict is a decision node; otherwise it is a leaf
    for key in secondDict.keys():
        if isinstance(secondDict[key], dict):
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth
def retrieveTree(i):
    # Two prebuilt trees for testing the plotting helpers
    listOfTree = [
        {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
        {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}},
    ]
    return listOfTree[i]

myTree = retrieveTree(1)
print(getNumLeafs(myTree))   # prints 4
print(getTreeDepth(myTree))  # prints 3
# NOTE: plotMidText() is called below but is not defined in this section.
# This is an assumed minimal version that writes the branch label at the
# midpoint between the parent and the child node.
def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)

# 4. Draw a tree
def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)
    firstStr = list(myTree.keys())[0]
    # Center this node horizontally above the leaves of its subtree
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD  # move down one level
    for key in secondDict.keys():
        if isinstance(secondDict[key], dict):
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD  # move back up one level

# 5. Create the canvas
def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))  # width of the canvas in leaves
    plotTree.totalD = float(getTreeDepth(inTree))  # height of the canvas in levels
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
myTree = retrieveTree(0)
# 6. Classification function that uses the decision tree
def classify(inputTree, featLabels, testVec):
    firstStr = list(inputTree.keys())[0]  # feature tested at this node
    secondDict = inputTree[firstStr]  # branches under this node
    featIndex = featLabels.index(firstStr)  # position of that feature in the feature list
    # Follow the matching branch: return the label at a leaf, recurse at a decision node
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if isinstance(secondDict[key], dict):
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:
                classLabel = secondDict[key]
    return classLabel

myData, feature = createDataSet()  # get the training samples
print(feature)
myTree = createTree(myData, feature)  # build the decision tree model
print(myTree)
predict = classify(myTree, feature, [1, 0])  # predict
print(predict)
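Note that classify() only assigns classLabel when some branch matches, so a test vector containing a feature value that never appeared in training raises an error. One possible way to handle that case is sketched below; `classify_safe` and its `default` parameter are our own assumptions, not part of the original code.

```python
def classify_safe(tree, feat_labels, test_vec, default=None):
    """Descend a nested-dict tree; return `default` if a value has no branch."""
    feature = next(iter(tree))  # feature tested at this node
    branches = tree[feature]
    value = test_vec[feat_labels.index(feature)]
    if value not in branches:  # value never seen while building the tree
        return default
    child = branches[value]
    if isinstance(child, dict):  # decision node: keep descending
        return classify_safe(child, feat_labels, test_vec, default)
    return child  # leaf node: the class label


tree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
labels = ['no surfacing', 'flippers']
print(classify_safe(tree, labels, [1, 0]))  # no
print(classify_safe(tree, labels, [1, 1]))  # yes
print(classify_safe(tree, labels, [2, 0]))  # None
```

Returning a default is only one option. Falling back to the majority class of the training samples at that node would be another, but it requires storing extra information in the tree.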
# 7.1 Save the model
def storeTree(inputTree, filename):
    import pickle
    with open(filename, 'wb') as fw:  # pickle needs a binary file on Python 3
        pickle.dump(inputTree, fw)  # save

# 7.2 Load the model
def grabTree(filename):
    import pickle
    with open(filename, 'rb') as fr:
        return pickle.load(fr)
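A quick round trip shows that the tree survives being saved and loaded. This sketch (Python 3) uses its own copies of the two functions and writes to a temporary directory; the file name classifierStorage.pkl is an arbitrary choice for the example.

```python
import os
import pickle
import tempfile


def store_tree(tree, filename):
    with open(filename, 'wb') as fw:  # binary mode is required for pickle
        pickle.dump(tree, fw)


def grab_tree(filename):
    with open(filename, 'rb') as fr:
        return pickle.load(fr)


tree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
path = os.path.join(tempfile.mkdtemp(), 'classifierStorage.pkl')
store_tree(tree, path)
restored = grab_tree(path)
print(restored == tree)  # True
```

Pickle files should only be loaded from sources you trust, because unpickling can execute arbitrary code.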
Example: using decision trees to predict contact lens type
1. Collect: Text file provided.
2. Prepare: Parse tab-delimited lines.
3. Analyze: Quickly review data visually to make sure it was parsed properly. The final tree will be plotted with createPlot().
4. Train: Use createTree() from section 3.1.
5. Test: Write a function to descend the tree for a given instance.
6. Use: Persist the tree data structure so it can be recalled without building the tree; then use it in any application.
# 8. Use a decision tree to predict contact lens type
import treePlotter

with open('lenses.txt') as fr:
    # Each line is tab-delimited; the last field is the class label
    lenses = [inst.strip().split('\t') for inst in fr.readlines()]
lensesFeatures = ['age', 'prescript', 'astigmatic', 'tearRate']
lensesTree = createTree(lenses, lensesFeatures)
treePlotter.createPlot(lensesTree)