(二)《機器學習》(周志華)第4章 決策樹 筆記 理論及實現——「西瓜樹」——CART決策樹

CART決策樹

(一)《機器學習》(周志華)第4章 決策樹 筆記 理論及實現——「西瓜樹」

參照上一篇ID3算法實現的決策樹(點擊上面連接直達),進一步實現CART決策樹。html

其實只須要改動很小的一部分就能夠了,把原先計算信息熵和信息增益的部分換作計算基尼指數,選擇最優屬性的時候,選擇最小的基尼指數便可。算法

#導入模塊
import pandas as pd
import numpy as np
from collections import Counter

#數據獲取與處理
def getData(filePath):
    data = pd.read_excel(filePath)
    return data

def dataDeal(data):
    dataList = np.array(data).tolist()
    dataSet = [element[1:] for element in dataList]
    return dataSet

#獲取屬性名稱
def getLabels(data):
    labels = list(data.columns)[1:-1]
    return labels

#獲取類別標記
def targetClass(dataSet):
    classification = set([element[-1] for element in dataSet])
    return classification
    
#將分支結點標記爲葉結點,選擇樣本數最多的類做爲類標記
def majorityRule(dataSet):
    mostKind = Counter([element[-1] for element in dataSet]).most_common(1)
    majorityKind = mostKind[0][0]
    return majorityKind

##計算基尼值
def calculateGini(dataSet):
    classColumnCnt = Counter([element[-1] for element in dataSet])
    gini = 0
    for symbol in classColumnCnt:
        p_k = classColumnCnt[symbol]/len(dataSet)
        gini = gini+p_k**2
    gini = 1-gini
    return gini

#子數據集構建
def makeAttributeData(dataSet,value,iColumn):
    attributeData = []
    for element in dataSet:
        if element[iColumn]==value:
            row = element[:iColumn]
            row.extend(element[iColumn+1:])
            attributeData.append(row)
    return attributeData

#計算基尼指數
def GiniIndex(dataSet,iColumn):
    index = 0.0
    attribute = set([element[iColumn] for element in dataSet])
    for value in attribute:
        attributeData = makeAttributeData(dataSet,value,iColumn)
        index = index+len(attributeData)/len(dataSet)*calculateGini(attributeData)
    return index

#選擇最優屬性                
def selectOptimalAttribute(dataSet,labels):
    bestGini = []
    for iColumn in range(0,len(labels)):#不計最後的類別列
        index = GiniIndex(dataSet,iColumn)
        bestGini.append(index)
    sequence = bestGini.index(min(bestGini))
    return sequence
    
#創建決策樹
def createTree(dataSet,labels):
    classification = targetClass(dataSet) #獲取類別種類(集合去重)
    if len(classification) == 1:
        return list(classification)[0]
    if len(labels) == 1:
        return majorityRule(dataSet)#返回樣本種類較多的類別
    sequence = selectOptimalAttribute(dataSet,labels)
    optimalAttribute = labels[sequence]
    del(labels[sequence])
    myTree = {optimalAttribute:{}}
    attribute = set([element[sequence] for element in dataSet])
    for value in attribute:
        subLabels = labels[:]
        myTree[optimalAttribute][value] =  \
                createTree(makeAttributeData(dataSet,value,sequence),subLabels)
    return myTree

#定義主函數
def main():
    filePath = 'watermelonData.xls'
    data = getData(filePath)
    dataSet = dataDeal(data)
    labels = getLabels(data)
    myTree = createTree(dataSet,labels)
    return myTree

#讀取數據文件並轉換爲列表(含有漢字的,使用CSV格式讀取容易出錯)
if __name__ == '__main__':
    myTree = main()
    print (myTree)

 結果居然是同樣的,深度懷疑作錯了。app

相關文章
相關標籤/搜索