上一篇文章介紹了決策樹的剪枝概念和意義以及幾種常見的剪枝策略。因爲剪枝策略或方法能夠很是多,並且每一種在不一樣的應用場景下各有優劣,沒有絕對的好。本篇文章繼續討論決策樹的剪枝。node
如今咱們已經知道剪枝須要一個判斷依據來決定對當前節點是否須要剪枝,能夠定義一個損失函數(loss function)或者代價函數(cost function)來實現。假設樹T的葉節點個數爲|T|,t是某一葉節點,該節點覆蓋Nt個樣本,其中分類爲k的樣本點Ntk個,Ht(T)爲葉節點t上的經驗熵,這裏不妨再囉嗦幾句,根據前面有關決策樹生成的介紹可知,信息熵是表徵系統的混亂程度,熵越大越混亂,也就是越難判斷樣本分類。定義損失函數爲,算法
(1)函數
其中經驗熵爲,post
(2)學習
若是Ntk爲0,則跳過這個分類。 spa
將(1)中右端第一部分記爲,code
則(1)變成對象
(3)blog
(3)式中,C(T)表示模型對訓練數據的預測偏差,即,偏差使用混亂程度來表徵,|T|表示模型複雜度,上一篇文章中講到下降樹的複雜度也是剪枝的緣由,參數α 爲控制因子,α 較小時,能夠容許必定程度複雜的樹,α 較大時,促使選擇簡單的樹,不然損失函數會很大,α=1時,第二項就是模型的複雜度——葉節點個數。遞歸
能夠看出,爲了下降損失函數,要求咱們儘可能下降模型的複雜度和系統的信息熵。前面講決策樹生成的時候,考慮了信息增益(比)來對訓練數據進行擬合,這裏損失函數考慮了減少模型複雜度,決策樹生成學習局部的模型,決策樹剪枝學習總體的模型。
輸入:決策樹T,參數α
輸出:剪枝後的樹Tα
步驟:
統計學習方法,李航
代碼片斷以下:
爲了剪枝判斷方便些,在節點類裏面增長了幾個輔助字段
public class Node { /// <summary> /// 節點惟一id /// </summary> public int id; /// <summary> /// 用於劃分的屬性名,葉節點爲null /// </summary> public string Attr { get; set; } /// <summary> /// 節點分類,只有葉節點有分類值,內部節點爲null /// </summary> public string Class { get; set; } /// <summary> /// 根據屬性的取值劃分子空間,葉節點爲null /// key爲屬性值,value爲對應的子樹的根結點,表示子空間 /// </summary> public Dictionary<string, Node> Children { get; set; } /// <summary> /// 父節點,根節點的父節點爲null /// </summary> public Node parent { get; set; } /// <summary> /// 對應父節點中Children的key值,父節點的劃分屬性對應的值 /// </summary> public string attrVal; /// <summary> /// 深度,根節點深度爲0 /// </summary> public int deep; /// <summary> /// 每一個分類的樣本數量 /// </summary> public double[] classCount; /// <summary> /// 節點覆蓋的總樣本數 = classCount.Sum() /// </summary> public double count; }
決策樹也增長了幾個字段
public class DTree { /// <summary> /// 全部的分類值 /// </summary> private string[] _classes; private int _maxDeep; /// <summary> /// 最大深度,根節點深度爲0 /// </summary> public int MaxDeep { get { return _maxDeep; } } ... // 其餘字段和成員方法 }
決策樹的構造就不給出來了,主要是生成時注意節點對象所覆蓋的樣本點數量,樣本各分類數量,以及節點id等。
而後決策樹中剪枝的方法以下
public class DTree {
... // 其餘字段和成員函數 /// <summary> /// 剪枝 /// </summary> public void Prune() { var tuple = GetPrecNodes(_maxDeep); var leaves = GetInitLeaves(); var deep = _maxDeep; // 遞歸深度 var unPrunedCount = 0; // 某輪未被剪枝的數量 while(deep > 0) { var nodes = GetPrecNodes(deep); foreach (var node in nodes) { // 考察內部節點 if (node.Children != null && node.Children.Count > 0) { // 判斷是否須要剪枝 var preLoss = GetLoss(leaves); var fakeLeaves = GetPrunedLeaves(leaves, node); var postLoss = GetLoss(fakeLeaves); if (postLoss < preLoss) { // 須要剪枝,則進行剪枝 node.parent.Children[node.attrVal] = fakeLeaves[fakeLeaves.Count - 1]; leaves = fakeLeaves; // 更新葉節點 } else { unPrunedCount++; } } } if(deep == _maxDeep) // 當前深度與最大深度保持同步,則須要檢查是否須要修改最大深度 { if(unPrunedCount == 0) // 本輪被考察節點所有被剪枝,則修改最大深度 { _maxDeep--; } } deep--; } } /// <summary> /// 獲取剪枝後的葉節點列表 /// </summary> /// <param name="leaves">剪枝前葉節點列表</param> /// <param name="node">被剪枝的節點</param> /// <returns></returns> private List<Node> GetPrunedLeaves(List<Node> leaves, Node node) { var dict = node.Children.ToDictionary(c => c.Value.id, c => c.Value); var list = leaves.Where(l => !dict.ContainsKey(l.id)).ToList(); // 添加剪枝後的新葉節點 var leaf = new Node() { id = node.id }; leaf.parent = node.parent; leaf.deep = node.deep; leaf.Attr = node.Attr; leaf.count = node.count; leaf.classCount = node.classCount; int maxIdx = 0; double maxCount = node.classCount[0]; for(int i = 0; i < node.classCount.Length; i++) { if(maxCount < node.classCount[i]) { maxIdx = i; maxCount = node.classCount[i]; } } leaf.Class = _classes[maxIdx]; list.Add(leaf); return list; } /// <summary> /// 獲取損失函數 /// </summary> /// <param name="leaves"></param> /// <returns></returns> private double GetLoss(List<Node> leaves, double alpha = 1) { double sum = 0; foreach(var leaf in leaves) { double entropy = 0; foreach(var c in leaf.classCount) { entropy -= c / leaf.count * Math.Log(c / leaf.count, 2); } sum += entropy * leaf.count; } return sum + leaves.Count * alpha; } /// <summary> /// 獲取指定深度的前驅節點列表,即,節點深度爲指定深度減1的節點列表 /// </summary> /// <returns></returns> private List<Node> GetPrecNodes(int deep) { var list = new List<Node>(); // 結果列表 // var dest = deep - 1; // bfs 遍歷便可 var queue = new Queue<Node>(); queue.Enqueue(_root); while(queue.Count > 0) { var node = queue.Dequeue(); if (node.deep == dest) list.Add(node); else if(node.deep < dest) { if (node.Children != null) { foreach (var n in node.Children) { queue.Enqueue(n.Value); } } } //if (node.Children == null || node.Children.Count == 0) // leaves.Add(node); } return list; } /// <summary> /// 獲取初始的葉節點列表 /// </summary> /// <returns></returns> private List<Node> GetInitLeaves() { // bfs 遍歷便可 var queue = new Queue<Node>(); queue.Enqueue(_root); var leaves = new List<Node>(); while (queue.Count > 0) { var node = queue.Dequeue(); if (node.Children == null || node.Children.Count == 0) leaves.Add(node); } return leaves; } }
(代碼僅幫助理解剪枝策略,不保證能正確運行)