參照上一篇以ID3算法實現的決策樹(點擊上面鏈接直達),進一步實現CART決策樹。
其實只需要改動很小的一部分就可以了:把原先計算信息熵和信息增益的部分換成計算基尼指數,選擇最優屬性的時候,選擇基尼指數最小的屬性即可。
# Import modules
import pandas as pd
import numpy as np
from collections import Counter


# Data loading and preprocessing
def getData(filePath):
    """Read the Excel workbook at ``filePath`` and return it as a DataFrame."""
    data = pd.read_excel(filePath)
    return data


def dataDeal(data):
    """Convert the table into a list of row lists, dropping the first column.

    The first column is skipped here and in ``getLabels`` (presumably it is
    a sample ID -- confirm against the data file).  The last element of each
    returned row is treated as the class label by the rest of this module.
    """
    dataList = np.array(data).tolist()
    dataSet = [element[1:] for element in dataList]
    return dataSet


# Get the attribute names
def getLabels(data):
    """Return the column names of ``data`` without the first and last one."""
    labels = list(data.columns)[1:-1]
    return labels


# Get the class labels
def targetClass(dataSet):
    """Return the set of distinct class labels (last element of each row)."""
    classification = {element[-1] for element in dataSet}
    return classification


# Mark a branch node as a leaf, labelled with the most frequent class
def majorityRule(dataSet):
    """Return the class label that occurs most often in ``dataSet``.

    Raises:
        IndexError: if ``dataSet`` is empty.
    """
    mostKind = Counter(element[-1] for element in dataSet).most_common(1)
    majorityKind = mostKind[0][0]
    return majorityKind


# Compute the Gini value
def calculateGini(dataSet):
    """Return the Gini value ``1 - sum(p_k ** 2)`` of the class column.

    An empty ``dataSet`` yields ``0.0`` instead of dividing by zero.
    """
    total = len(dataSet)
    if total == 0:
        return 0.0
    classColumnCnt = Counter(element[-1] for element in dataSet)
    return 1 - sum((count / total) ** 2 for count in classColumnCnt.values())


# Build a sub data set
def makeAttributeData(dataSet, value, iColumn):
    """Return the rows whose column ``iColumn`` equals ``value``.

    The matched column is removed from every returned row, so the remaining
    columns shift left by one.  The input rows are not modified.
    """
    attributeData = [
        list(element[:iColumn]) + list(element[iColumn + 1:])
        for element in dataSet
        if element[iColumn] == value
    ]
    return attributeData


def _uniqueValues(dataSet, iColumn):
    """Return the distinct values of column ``iColumn`` in first-seen order."""
    # dict.fromkeys keeps insertion order, so branch order (and the order in
    # which floats are summed) does not depend on string hash randomisation.
    return list(dict.fromkeys(element[iColumn] for element in dataSet))


# Compute the Gini index
def GiniIndex(dataSet, iColumn):
    """Return the Gini index of splitting ``dataSet`` on column ``iColumn``.

    This is the sum over each distinct value of the column of
    ``len(subset) / len(dataSet) * calculateGini(subset)``.  An empty
    ``dataSet`` yields ``0.0``.
    """
    index = 0.0
    for value in _uniqueValues(dataSet, iColumn):
        attributeData = makeAttributeData(dataSet, value, iColumn)
        index = index + len(attributeData) / len(dataSet) * calculateGini(attributeData)
    return index


# Select the optimal attribute
def selectOptimalAttribute(dataSet, labels):
    """Return the index of the attribute column with the smallest Gini index.

    Only the first ``len(labels)`` columns are considered, which leaves out
    the trailing class column.  Ties go to the lowest column index.

    Raises:
        ValueError: if ``labels`` is empty.
    """
    giniPerColumn = [GiniIndex(dataSet, iColumn) for iColumn in range(len(labels))]
    sequence = giniPerColumn.index(min(giniPerColumn))
    return sequence


def _sameAttributeValues(dataSet):
    """Return True when every row has identical attribute values."""
    first = dataSet[0][:-1]
    return all(element[:-1] == first for element in dataSet)


# Build the decision tree
def createTree(dataSet, labels):
    """Recursively build a decision tree using the Gini index.

    Args:
        dataSet: non-empty list of rows; each row holds the attribute values
            followed by the class label as its last element.
        labels: attribute names, one per attribute column of ``dataSet``.
            The list is not modified.

    Returns:
        A class label when the node is a leaf, otherwise a nested dict of
        the form ``{attributeName: {attributeValue: subtree, ...}}``.

    Raises:
        ValueError: if ``dataSet`` is empty.
    """
    if not dataSet:
        raise ValueError("dataSet must contain at least one sample")

    classification = targetClass(dataSet)  # distinct classes in this node
    if len(classification) == 1:
        return next(iter(classification))

    # Stop only when no attribute is left to split on, or when no split can
    # separate the samples because they agree on every attribute.  Stopping
    # while one attribute remains would leave that attribute unused.
    if not labels or _sameAttributeValues(dataSet):
        return majorityRule(dataSet)  # most frequent class becomes the leaf

    sequence = selectOptimalAttribute(dataSet, labels)
    optimalAttribute = labels[sequence]
    # Build a new list rather than deleting in place, so the caller's
    # ``labels`` stays intact.
    remainingLabels = labels[:sequence] + labels[sequence + 1:]

    myTree = {optimalAttribute: {}}
    for value in _uniqueValues(dataSet, sequence):
        subLabels = remainingLabels[:]
        myTree[optimalAttribute][value] = createTree(
            makeAttributeData(dataSet, value, sequence), subLabels
        )
    return myTree


# Define the main function
def main():
    """Load ``watermelonData.xls``, build the decision tree and return it."""
    filePath = 'watermelonData.xls'
    data = getData(filePath)
    dataSet = dataDeal(data)
    labels = getLabels(data)
    myTree = createTree(dataSet, labels)
    return myTree


# Read the data file and convert it to lists (data containing Chinese
# characters is error-prone when read as CSV)
if __name__ == '__main__':
    myTree = main()
    print(myTree)
結果竟然是一樣的,深度懷疑做錯了。