決策樹算法以及matlab實現ID3算法

本文將詳細介紹ID3算法,其也是最經典的決策樹分類算法。node

一、ID3算法簡介及基本原理 
ID3算法基於信息熵來選擇最佳的測試屬性,它選擇當前樣本集中具備最大信息增益值的屬性做爲測試屬性;樣本集的劃分則依據測試屬性的取值進行,測試屬性有多少個不一樣的取值就將樣本集劃分爲多少個子樣本集,同時決策樹上相應於該樣本集的節點長出新的葉子節點。ID3算法根據信息論的理論,採用劃分後樣本集的不肯定性做爲衡量劃分好壞的標準,用信息增益值度量不肯定性:信息增益值越大,不肯定性越小。所以,ID3算法在每一個非葉節點選擇信息增益最大的屬性做爲測試屬性,這樣能夠獲得當前狀況下最純的劃分,從而獲得較小的決策樹。git

設S是s個數據樣本的集合。假定類別屬性具備m個不一樣的值:這裏寫圖片描述,設這裏寫圖片描述是類這裏寫圖片描述中的樣本數。對一個給定的樣本,它總的信息熵爲這裏寫圖片描述,其中,這裏寫圖片描述是任意樣本屬於這裏寫圖片描述的機率,通常能夠用這裏寫圖片描述估計。github

設一個屬性A具備k個不一樣的值這裏寫圖片描述,利用屬性A將集合S劃分爲k個子集這裏寫圖片描述,其中這裏寫圖片描述包含了集合S中屬性A取這裏寫圖片描述值的樣本。若選擇屬性A爲測試屬性,則這些子集就是從集合S的節點生長出來的新的葉節點。設這裏寫圖片描述是子集這裏寫圖片描述中類別爲這裏寫圖片描述的樣本數,則根據屬性A劃分樣本的信息熵爲這裏寫圖片描述 
其中,這裏寫圖片描述這裏寫圖片描述是子集這裏寫圖片描述中類別爲這裏寫圖片描述的樣本的機率。算法

最後,用屬性A劃分樣本集S後所得的信息增益(Gain)爲這裏寫圖片描述函數

顯然這裏寫圖片描述越小,Gain(A)的值就越大,說明選擇測試屬性A對於分類提供的信息越大,選擇A以後對分類的不肯定程度越小。屬性A的k個不一樣的值對應的樣本集S的k個子集或分支,經過遞歸調用上述過程(不包括已經選擇的屬性),生成其餘屬性做爲節點的子節點和分支來生成整個決策樹。ID3決策樹算法做爲一個典型的決策樹學習算法,其核心是在決策樹的各級節點上都用信息增益做爲判斷標準來進行屬性的選擇,使得在每一個非葉子節點上進行測試時,都能得到最大的類別分類增益,使分類後的數據集的熵最小。這樣的處理方法使得樹的平均深度較小,從而有效地提升了分類效率。學習

二、ID3算法的具體流程 
ID3算法的具體流程以下: 
1)對當前樣本集合,計算全部屬性的信息增益; 
2)選擇信息增益最大的屬性做爲測試屬性,把測試屬性取值相同的樣本劃爲同一個子樣本集; 
3)若子樣本集的類別屬性只含有單個屬性,則分支爲葉子節點,判斷其屬性值並標上相應的符號,而後返回調用處;不然對子樣本集遞歸調用本算法。測試

數據如圖所示編碼

序號 天氣 是否週末 是否有促銷 銷量 1 壞 是 是 高 2 壞 是 是 高 3 壞 是 是 高 4 壞 否 是 高 5 壞 是 是 高 6 壞 否 是 高 7 壞 是 否 高 8 好 是 是 高 9 好 是 否 高 10 好 是 是 高 11 好 是 是 高 12 好 是 是 高 13 好 是 是 高 14 壞 是 是 低 15 好 否 是 高 16 好 否 是 高 17 好 否 是 高 18 好 否 是 高 19 好 否 否 高 20 壞 否 否 低 21 壞 否 是 低 22 壞 否 是 低 23 壞 否 是 低 24 壞 否 否 低 25 壞 是 否 低 26 好 否 是 低 27 好 否 是 低 28 壞 否 否 低 29 壞 否 否 低 30 好 否 否 低 31 壞 是 否 低 32 好 否 是 低 33 好 否 否 低 34  好   否   否   低

採用ID3算法構建決策樹模型的具體步驟以下: 
1)根據公式這裏寫圖片描述,計算總的信息熵,其中數據中總記錄數爲34,而銷售數量爲「高」的數據有18,「低」的有16 
這裏寫圖片描述spa

2)根據公式這裏寫圖片描述這裏寫圖片描述,計算每一個測試屬性的信息熵。.net

對於天氣屬性,其屬性值有「好」和「壞」兩種。其中天氣爲「好」的條件下,銷售數量爲「高」的記錄爲11,銷售數量爲「低」的記錄爲6,可表示爲(11,6);天氣爲「壞」的條件下,銷售數量爲「高」的記錄爲7,銷售數量爲「低」的記錄爲10,可表示爲(7,10)。則天氣屬性的信息熵計算過程以下: 
這裏寫圖片描述 
這裏寫圖片描述 
這裏寫圖片描述

對因而否週末屬性,其屬性值有「是」和「否」兩種。其中是否週末屬性爲「是」的條件下,銷售數量爲「高」的記錄爲11,銷售數量爲「低」的記錄爲3,可表示爲(11,3);是否週末屬性爲「否」的條件下,銷售數量爲「高」的記錄爲7,銷售數量爲「低」的記錄爲13,可表示爲(7,13)。則節假日屬性的信息熵計算過程以下: 
這裏寫圖片描述 
這裏寫圖片描述 
這裏寫圖片描述

對因而否有促銷屬性,其屬性值有「是」和「否」兩種。其中是否有促銷屬性爲「是」的條件下,銷售數量爲「高」的記錄爲15,銷售數量爲「低」的記錄爲7,可表示爲(15,7);其中是否有促銷屬性爲「否」的條件下,銷售數量爲「高」的記錄爲3,銷售數量爲「低」的記錄爲9,可表示爲(3,9)。則是否有促銷屬性的信息熵計算過程以下: 
這裏寫圖片描述 
這裏寫圖片描述 
這裏寫圖片描述

根據公式這裏寫圖片描述,計算天氣、是否週末和是否有促銷屬性的信息增益值。 
這裏寫圖片描述 
這裏寫圖片描述 
這裏寫圖片描述

3)由計算結果能夠知道是否週末屬性的信息增益值最大,它的兩個屬性值「是」和「否」做爲該根節點的兩個分支。而後按照上面的步驟繼續對該根節點的兩個分支進行節點的劃分,針對每個分支節點繼續進行信息增益的計算,如此循環反覆,直到沒有新的節點分支,最終構成一棵決策樹。

因爲ID3決策樹算法採用了信息增益做爲選擇測試屬性的標準,會偏向於選擇取值較多的即所謂的高度分支屬性,而這類屬性並不必定是最優的屬性。同時ID3決策樹算法只能處理離散屬性,對於連續型的屬性,在分類前須要對其進行離散化。爲了解決傾向於選擇高度分支屬性的問題,人們採用信息增益率做爲選擇測試屬性的標準,這樣便獲得C4.5決策樹的算法。此外經常使用的決策樹算法還有CART算法、SLIQ算法、SPRINT算法和PUBLIC算法等等。

使用ID3算法創建決策樹的MATLAB代碼以下所示 
ID3_decision_tree.m

%% 使用ID3決策樹算法預測銷量高低 clear ; %% 數據預處理 disp('正在進行數據預處理...'); [matrix,attributes_label,attributes] = id3_preprocess(); %% 構造ID3決策樹,其中id3()爲自定義函數 disp('數據預處理完成,正在進行構造樹...'); tree = id3(matrix,attributes_label,attributes); %% 打印並畫決策樹 [nodeids,nodevalues] = print_tree(tree); tree_plot(nodeids,nodevalues); disp('ID3算法構建決策樹完成!');

id3_preprocess.m:

function [ matrix,attributes,activeAttributes ] = id3_preprocess( ) %% ID3算法數據預處理,把字符串轉換爲0,1編碼 % 輸出參數: % matrix: 轉換後的0,1矩陣; % attributes: 屬性和Label; % activeAttributes : 屬性向量,全1; %% 讀取數據 txt = {  '序號'    '天氣'    '是否週末'    '是否有促銷'    '銷量'
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            '' } attributes=txt(1,2:end); activeAttributes = ones(1,length(attributes)-1); data = txt(2:end,2:end); %% 針對每列數據進行轉換 [rows,cols] = size(data); matrix = zeros(rows,cols); for j=1:cols matrix(:,j) = cellfun(@trans2onezero,data(:,j)); end end function flag = trans2onezero(data) if strcmp(data,'') ||strcmp(data,'')... ||strcmp(data,'') flag =0; return ; end flag =1; end

id3.m:

function [ tree ] = id3( examples, attributes, activeAttributes ) %% ID3 算法 ,構建ID3決策樹 ...參考:https://github.com/gwheaton/ID3-Decision-Tree

% 輸入參數: % example: 輸入0、1矩陣; % attributes: 屬性值,含有Label; % activeAttributes: 活躍的屬性值;-1,1向量,1表示活躍; % 輸出參數: % tree:構建的決策樹; %% 提供的數據爲空,則報異常 if (isempty(examples)); error('必須提供數據!'); end % 常量 numberAttributes = length(activeAttributes); numberExamples = length(examples(:,1)); % 建立樹節點 tree = struct('value', 'null', 'left', 'null', 'right', 'null'); % 若是最後一列所有爲1,則返回「true」 lastColumnSum = sum(examples(:, numberAttributes + 1)); if (lastColumnSum == numberExamples); tree.value = 'true'; return end % 若是最後一列所有爲0,則返回「falseif (lastColumnSum == 0); tree.value = 'false'; return end % 若是活躍的屬性爲空,則返回label最多的屬性值 if (sum(activeAttributes) == 0); if (lastColumnSum >= numberExamples / 2); tree.value = 'true'; else tree.value = 'false'; end return end %% 計算當前屬性的熵 p1 = lastColumnSum / numberExamples; if (p1 == 0); p1_eq = 0; else p1_eq = -1*p1*log2(p1); end p0 = (numberExamples - lastColumnSum) / numberExamples; if (p0 == 0); p0_eq = 0; else p0_eq = -1*p0*log2(p0); end currentEntropy = p1_eq + p0_eq; %% 尋找最大增益 gains = -1*ones(1,numberAttributes); % 初始化增益 for i=1:numberAttributes; if (activeAttributes(i)) % 該屬性仍處於活躍狀態,對其更新 s0 = 0; s0_and_true = 0; s1 = 0; s1_and_true = 0; for j=1:numberExamples; if (examples(j,i)); s1 = s1 + 1; if (examples(j, numberAttributes + 1)); s1_and_true = s1_and_true + 1; end else s0 = s0 + 1; if (examples(j, numberAttributes + 1)); s0_and_true = s0_and_true + 1; end end end % 熵 S(v=1) if (~s1); p1 = 0; else p1 = (s1_and_true / s1); end if (p1 == 0); p1_eq = 0; else p1_eq = -1*(p1)*log2(p1); end if (~s1); p0 = 0; else p0 = ((s1 - s1_and_true) / s1); end if (p0 == 0); p0_eq = 0; else p0_eq = -1*(p0)*log2(p0); end entropy_s1 = p1_eq + p0_eq; % 熵 S(v=0) if (~s0); p1 = 0; else p1 = (s0_and_true / s0); end if (p1 == 0); p1_eq = 0; else p1_eq = -1*(p1)*log2(p1); end if (~s0); p0 = 0; else p0 = ((s0 - s0_and_true) / s0); end if (p0 == 0); p0_eq = 0; else p0_eq = -1*(p0)*log2(p0); end entropy_s0 = p1_eq + p0_eq; gains(i) = currentEntropy - ((s1/numberExamples)*entropy_s1) - ((s0/numberExamples)*entropy_s0); end end % 選出最大增益 [~, bestAttribute] = max(gains); % 設置相應值 tree.value = attributes{bestAttribute}; % 去活躍狀態 activeAttributes(bestAttribute) = 0; % 根據bestAttribute把數據進行分組 examples_0= examples(examples(:,bestAttribute)==0,:); examples_1= examples(examples(:,bestAttribute)==1,:); % 當 value = false or 0, 左分支 if (isempty(examples_0)); leaf = struct('value', 'null', 'left', 'null', 'right', 'null'); if (lastColumnSum >= numberExamples / 2); % for matrix examples leaf.value = 'true'; else leaf.value = 'false'; end tree.left = leaf; else
    % 遞歸 tree.left = id3(examples_0, attributes, activeAttributes); end % 當 value = true or 1, 右分支 if (isempty(examples_1)); leaf = struct('value', 'null', 'left', 'null', 'right', 'null'); if (lastColumnSum >= numberExamples / 2); leaf.value = 'true'; else leaf.value = 'false'; end tree.right = leaf; else
    % 遞歸 tree.right = id3(examples_1, attributes, activeAttributes); end % 返回 return end

print_tree.m:

function [nodeids_,nodevalue_] = print_tree(tree) %% 打印樹,返回樹的關係向量 global nodeid nodeids nodevalue; nodeids(1)=0; % 根節點的值爲0 nodeid=0; nodevalue={}; if isempty(tree) disp('空樹!'); return ; end queue = queue_push([],tree); while ~isempty(queue) % 隊列不爲空 [node,queue] = queue_pop(queue); % 出隊列 visit(node,queue_curr_size(queue)); if ~strcmp(node.left,'null') % 左子樹不爲空 queue = queue_push(queue,node.left); % 進隊 end if ~strcmp(node.right,'null') % 左子樹不爲空 queue = queue_push(queue,node.right); % 進隊 end end %% 返回 節點關係,用於treeplot畫圖 nodeids_=nodeids; nodevalue_=nodevalue; end function visit(node,length_) %% 訪問node 節點,並把其設置值爲nodeid的節點 global nodeid nodeids nodevalue; if isleaf(node) nodeid=nodeid+1; fprintf('葉子節點,node: %d\t,屬性值: %s\n', ... nodeid, node.value); nodevalue{1,nodeid}=node.value; else % 要麼是葉子節點,要麼不是 %if isleaf(node.left) && ~isleaf(node.right) % 左邊爲葉子節點,右邊不是 nodeid=nodeid+1; nodeids(nodeid+length_+1)=nodeid; nodeids(nodeid+length_+2)=nodeid; fprintf('node: %d\t屬性值: %s\t,左子樹爲節點:node%d,右子樹爲節點:node%d\n', ... nodeid, node.value,nodeid+length_+1,nodeid+length_+2); nodevalue{1,nodeid}=node.value; end end function flag = isleaf(node) %% 是不是葉子節點 if strcmp(node.left,'null') && strcmp(node.right,'null') % 左右都爲空 flag =1; else flag=0; end end

tree_plot.m

function tree_plot( p ,nodevalues) %% 參考treeplot函數 [x,y,h]=treelayout(p); f = find(p~=0); pp = p(f); X = [x(f); x(pp); NaN(size(f))]; Y = [y(f); y(pp); NaN(size(f))]; X = X(:); Y = Y(:); n = length(p); if n < 500, hold on ; plot (x, y, 'ro', X, Y, 'r-'); nodesize = length(x); for i=1:nodesize %            text(x(i)+0.01,y(i),['node' num2str(i)]); text(x(i)+0.01,y(i),nodevalues{1,i}); end hold off; else plot (X, Y, 'r-'); end; xlabel(['height = ' int2str(h)]); axis([0 1 0 1]); end

queue_push.m

function [ newqueue ] = queue_push( queue,item ) %% 進隊 % cols = size(queue); % newqueue =structs(1,cols+1); newqueue=[queue,item]; end

queue_pop.m

function [ item,newqueue ] = queue_pop( queue ) %% 訪問隊列 if isempty(queue) disp('隊列爲空,不能訪問!'); return; end item = queue(1); % 第一個元素彈出 newqueue=queue(2:end); % 日後移動一個元素位置 end

queue_curr_size.m

function [ length_ ] = queue_curr_size( queue ) %% 當前隊列長度 length_= length(queue); end

轉載自https://blog.csdn.net/lfdanding/article/details/50753239

相關文章
相關標籤/搜索