RankLib Source Code Analysis: LambdaMART

RankLib is an excellent open-source library for Learning to Rank, with implementations of MART, RankNet, RankBoost, LambdaMART, Random Forests, and more. LambdaMART, published by Microsoft, is one of the most widely used Learning to Rank models in the IR industry. This article walks through RankLib's implementation of LambdaMART, to help connect the method described in the paper with working code. It is based on RankLib version 2.3.

For the basic theory behind LambdaMART, see the earlier post: http://www.cnblogs.com/bentuwuying/p/6690836.html. The key point is that LambdaMART is built on MART, and MART is an ensemble of regression trees. So we first look at how RankLib implements a regression tree, and how that tree is fitted given training data with labels.
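
As a reminder of the form being fitted (standard gradient-boosting notation, not something taken from the RankLib sources): after $M$ boosting rounds the model is

$$F_M(x) = \sum_{m=1}^{M} \eta \, t_m(x),$$

where each $t_m$ is a regression tree fitted to the current gradients and $\eta$ is the shrinkage (RankLib's learningRate, applied both when a tree is added to the Ensemble and when the cached model scores are updated).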

1. regression tree

The steps for fitting a regression tree to the given training data are summarized below:

RegressionTree
    nodes # caps the number of leaf nodes of a tree
    minLeafSupport # controls how far splitting goes: a node holding fewer than 2*minLeafSupport training samples is not split further
    root # root node
    leaves # list of leaf nodes
    constructor RegressionTree(int nLeaves, DataPoint[] trainingSamples, double[] labels, FeatureHistogram hist, int minLeafSupport)
        initializes the member variables
    fit # fits the regression tree to the training data
        create a queue of nodes to split; nodes are processed in queue order (the queue is kept sorted by deviance, so splitting is best-first rather than strictly level by level)
        initialize the root node of the regression tree
        root.split # split the root node
            hist.findBestSplit # calls the split method of the FeatureHistogram held by this Split (from the node's precomputed feature histogram, find the best split point, split, compute the feature histograms of the left and right children, and initialize both children)
                if the deviance is 0, the split fails
                choose usedFeatures (the indexes of the features considered when splitting) according to samplingRate
                call the internal findBestSplit method
                    on this node, over usedFeatures, pick the split feature and threshold from the node's feature histogram
                    S = sumLeft * sumLeft / countLeft + sumRight * sumRight / countRight
                    over all candidate splits (feature/threshold pairs), take the largest S: it corresponds to the smallest squared error, i.e. the best split (see the derivation after this outline)
                if S = -1, the split fails
                assign each training sample on this node to the left or right child according to the best split
                initialize the feature histograms of the two children
                    construct # typically builds the left child's feature histogram after the parent splits (when constructing from the parent, the thresholds array is unchanged, but the sum and count arrays must be rebuilt)
                    construct # typically builds the right child's feature histogram after the parent splits
                compute the squared error of this node and of both children
                sp.set # calls back into the Split object that owns this FeatureHistogram
                    after the node has been split, records the split's featureID, threshold and deviance
                    only internal nodes are split (and call this method), so only internal nodes have featureID != -1; leaves never call it, hence featureID = -1
                initialize the left child (from the indexes of the training samples routed left, its feature histogram, its squared error, and the sum of its labels) and attach it as this node's left child
                initialize the right child (likewise, for the samples routed right) and attach it as this node's right child
        insert # push the two children into the queue for the loop below
            the queue is kept sorted by squared error, largest first
        loop: split nodes in queue order, pushing the two children of each successful split back into the queue
        set the tree's leaves member from root's leaves() method (a recursive traversal)
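
Why the largest S corresponds to the smallest squared error: a one-line derivation in the notation of the code. For a candidate split sending the sample set $L$ left and $R$ right,

$$\sum_{i \in L}(y_i - \bar y_L)^2 + \sum_{i \in R}(y_i - \bar y_R)^2 = \underbrace{\sum_i y_i^2}_{\texttt{sqSumResponse}} - \underbrace{\left(\frac{(\sum_{i \in L} y_i)^2}{|L|} + \frac{(\sum_{i \in R} y_i)^2}{|R|}\right)}_{S}.$$

Since sqSumResponse is fixed for the node, maximizing S minimizes the post-split squared error. The same identity, applied to a single node, is what the code computes as var = sqSumResponse - sumResponse * sumResponse / n.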

 

Below are the class files involved in fitting a regression tree; detailed comments have been added to the key parts.

 

1. FeatureHistogram

package ciir.umass.edu.learning.tree;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import ciir.umass.edu.learning.DataPoint;
import ciir.umass.edu.utilities.MyThreadPool;
import ciir.umass.edu.utilities.WorkerThread;
/**
 * @author vdang
 */
//Feature histogram class: keeps per-feature statistics of the training samples and selects the best feature and split point at each split
public class FeatureHistogram {
    //Holds a candidate split's featureIdx and thresholdIdx, plus the score S = sumLeft*sumLeft/countLeft + sumRight*sumRight/countRight used to judge the best split
    class Config {
        int featureIdx = -1;
        int thresholdIdx = -1;
        double S = -1;
    }
    
    //Parameter
    public static float samplingRate = 1; //sampling rate over the features considered at a split, so that not every feature has to be examined
    
    //Variables
    public int[] features = null; //feature array; each element is a feature id (fid)
    public float[][] thresholds = null; //2-D array; the first dimension is the feature, indexed by position in @features (not by feature id); the second dimension holds the candidate thresholds for splitting on that feature: the distinct values of the feature over all training samples, sorted ascending
    public double[][] sum = null; //2-D array; first dimension as in @thresholds; sum[i][j] is the sum of the labels of all training samples whose value on this feature is <= thresholds[i][j]; same shape as @thresholds
    public double sumResponse = 0; //sum of the labels of all training samples
    public double sqSumResponse = 0; //sum of the squared labels of all training samples
    public int[][] count = null; //2-D array; first dimension as in @thresholds; count[i][j] is the number of training samples whose value on this feature is <= thresholds[i][j]; same shape as @thresholds
    public int[][] sampleToThresholdMap = null; //2-D array; first dimension as in @thresholds; sampleToThresholdMap[i][k] is the column index, in row i of @thresholds, that matches sample k's value on this feature
    
    //whether to re-use its parents @sum and @count instead of cleaning up the parent and re-allocate for the children.
    //@sum and @count of any intermediate tree node (except for root) can be re-used.  
    private boolean reuseParent = false;
    
    public FeatureHistogram()
    {
        
    }

    //FeatureHistogram construct (1-1): typically builds the feature histogram of the whole tree / the root node
    //@samples: the training samples
    //@labels: the labels of the training samples
    //@sampleSortedIdx: sorted list of samples by each feature, so the best split point can be found quickly; needs initializing only once, see init() in LambdaMART.java
    //@features: the feature set of the training data
    //@thresholds: a table of candidate thresholds (split points) for each feature; we will select the best tree split from these candidates later on; initialized in init() of LambdaMART.java; the last column of each row is an appended Float.MAX_VALUE
    public void construct(DataPoint[] samples, double[] labels, int[][] sampleSortedIdx, int[] features, float[][] thresholds)
    {
        this.features = features;
        this.thresholds = thresholds;
        
        sumResponse = 0;
        sqSumResponse = 0;
        
        sum = new double[features.length][];
        count = new int[features.length][];
        sampleToThresholdMap = new int[features.length][];
        
        //decide whether to use multiple threads
        MyThreadPool p = MyThreadPool.getInstance();
        if(p.size() == 1)
            construct(samples, labels, sampleSortedIdx, thresholds, 0, features.length-1);
        else
            p.execute(new Worker(this, samples, labels, sampleSortedIdx, thresholds), features.length);            
    }
    //FeatureHistogram construct (1-2), called by (1-1)
    protected void construct(DataPoint[] samples, double[] labels, int[][] sampleSortedIdx, float[][] thresholds, int start, int end)
    {
        for(int i=start;i<=end;i++) //for each feature
        {
            int fid = features[i]; //get the feature id
            //get the list of samples associated with this node (sorted in ascending order with respect to the current feature)
            int[] idx = sampleSortedIdx[i]; //indexes of the training samples, sorted ascending by this feature's value
            
            double sumLeft = 0; //running total that feeds sumLabel
            float[] threshold = thresholds[i];
            double[] sumLabel = new double[threshold.length]; //one row of the @sum array above
            int[] c = new int[threshold.length]; //one row of the @count array above
            int[] stMap = new int[samples.length]; //one row of the @sampleToThresholdMap array above
            
            int last = -1;
            for(int t=0;t<threshold.length;t++) //for each candidate split threshold
            {
                int j=last+1;
                //find the first sample that exceeds the current threshold
                for(;j<idx.length;j++)
                {
                    int k = idx[j]; //index of this DataPoint in @samples
                    if(samples[k].getFeatureValue(fid) > threshold[t])
                        break;
                    sumLeft += labels[k];
                    if(i == 0)
                    {
                        sumResponse += labels[k];
                        sqSumResponse += labels[k] * labels[k];
                    }
                    stMap[k] = t;
                }
                last = j-1;    
                sumLabel[t] = sumLeft;
                c[t] = last+1;
            }
            sampleToThresholdMap[i] = stMap;
            sum[i] = sumLabel;
            count[i] = c;
        }
    }
    
    //update (1-1), update the histogram with these training labels (the feature histogram will be used to find the best tree split)
    protected void update(double[] labels)
    {
        sumResponse = 0;
        sqSumResponse = 0;
        
        //decide whether to use multiple threads
        MyThreadPool p = MyThreadPool.getInstance();
        if(p.size() == 1)
            update(labels, 0, features.length-1);
        else
            p.execute(new Worker(this, labels), features.length);
    }

    //update (1-2), called by (1-1)
    protected void update(double[] labels, int start, int end)
    {
        for(int f=start;f<=end;f++)
            Arrays.fill(sum[f], 0);
        for(int k=0;k<labels.length;k++)
        {
            for(int f=start;f<=end;f++)
            {
                int t = sampleToThresholdMap[f][k];
                sum[f][t] += labels[k];
                if(f == 0)
                {
                    sumResponse += labels[k];
                    sqSumResponse += labels[k]*labels[k];
                }
                //count doesn't change, so no need to re-compute
            }
        }
        for(int f=start;f<=end;f++)
        {            
            for(int t=1;t<thresholds[f].length;t++)
                sum[f][t] += sum[f][t-1];
        }
    }
    
    //FeatureHistogram construct (2-1): typically builds the feature histogram of the left child produced by a parent split
    //when constructing from the parent, the thresholds array is unchanged, but the sum and count arrays must be rebuilt
    //@soi: indexes of the training samples this child uses
    public void construct(FeatureHistogram parent, int[] soi, double[] labels)
    {
        this.features = parent.features;
        this.thresholds = parent.thresholds;
        sumResponse = 0;
        sqSumResponse = 0;
        sum = new double[features.length][];
        count = new int[features.length][];
        sampleToThresholdMap = parent.sampleToThresholdMap;
        
        //decide whether to use multiple threads
        MyThreadPool p = MyThreadPool.getInstance();
        if(p.size() == 1)
            construct(parent, soi, labels, 0, features.length-1);
        else
            p.execute(new Worker(this, parent, soi, labels), features.length);    
    }

    //FeatureHistogram construct (2-2), called by (2-1)
    protected void construct(FeatureHistogram parent, int[] soi, double[] labels, int start, int end)
    {
        //init
        for(int i=start;i<=end;i++)
        {            
            float[] threshold = thresholds[i];
            sum[i] = new double[threshold.length];
            count[i] = new int[threshold.length];
            Arrays.fill(sum[i], 0);
            Arrays.fill(count[i], 0);
        }
        
        //update
        for(int i=0;i<soi.length;i++)
        {
            int k = soi[i];
            for(int f=start;f<=end;f++)
            {
                int t = sampleToThresholdMap[f][k];
                sum[f][t] += labels[k];
                count[f][t] ++;
                if(f == 0)
                {
                    sumResponse += labels[k];
                    sqSumResponse += labels[k]*labels[k];
                }
            }
        }
        
        for(int f=start;f<=end;f++)
        {            
            for(int t=1;t<thresholds[f].length;t++)
            {
                sum[f][t] += sum[f][t-1];
                count[f][t] += count[f][t-1];
            }
        }
    }
    
    //FeatureHistogram construct (3-1): typically builds the feature histogram of the right child produced by a parent split
    public void construct(FeatureHistogram parent, FeatureHistogram leftSibling, boolean reuseParent)
    {
        this.reuseParent = reuseParent;
        this.features = parent.features;
        this.thresholds = parent.thresholds;
        sumResponse = parent.sumResponse - leftSibling.sumResponse;
        sqSumResponse = parent.sqSumResponse - leftSibling.sqSumResponse;
        
        if(reuseParent)
        {
            sum = parent.sum;
            count = parent.count;
        }
        else
        {
            sum = new double[features.length][];
            count = new int[features.length][];
        }
        sampleToThresholdMap = parent.sampleToThresholdMap;

        //decide whether to use multiple threads
        MyThreadPool p = MyThreadPool.getInstance();
        if(p.size() == 1)
            construct(parent, leftSibling, 0, features.length-1);
        else
            p.execute(new Worker(this, parent, leftSibling), features.length);
    }

    //FeatureHistogram construct (3-2), called by (3-1)
    protected void construct(FeatureHistogram parent, FeatureHistogram leftSibling, int start, int end)
    {
        for(int f=start;f<=end;f++)
        {
            float[] threshold = thresholds[f];
            if(!reuseParent)
            {
                sum[f] = new double[threshold.length];
                count[f] = new int[threshold.length];
            }
            for(int t=0;t<threshold.length;t++)
            {
                sum[f][t] = parent.sum[f][t] - leftSibling.sum[f][t];
                count[f][t] = parent.count[f][t] - leftSibling.count[f][t];
            }
        }
    }
    
    //findBestSplit (1-2), called by (1-1): on this node, over @usedFeatures, choose the split feature and threshold from the node's feature histogram
    protected Config findBestSplit(int[] usedFeatures, int minLeafSupport, int start, int end)
    {
        Config cfg = new Config();
        int totalCount = count[start][count[start].length-1];
        for(int f=start;f<=end;f++)
        {
            int i = usedFeatures[f];
            float[] threshold = thresholds[i];
            
            for(int t=0;t<threshold.length;t++)
            {
                int countLeft = count[i][t];
                int countRight = totalCount - countLeft;
                if(countLeft < minLeafSupport || countRight < minLeafSupport)
                    continue;
                
                double sumLeft = sum[i][t];
                double sumRight = sumResponse - sumLeft;
                
                double S = sumLeft * sumLeft / countLeft + sumRight * sumRight / countRight;
                //take the largest S: it corresponds to the smallest squared error, i.e. the best split
                if(cfg.S < S)
                {
                    cfg.S = S;
                    cfg.featureIdx = i;
                    cfg.thresholdIdx = t;
                }
            }
        }        
        return cfg;
    }
    
    //findBestSplit (1-1): from the node's precomputed feature histogram, find the best split point, split, compute the feature histograms of the left and right children, and initialize both children
    public boolean findBestSplit(Split sp, double[] labels, int minLeafSupport)
    {
        if(sp.getDeviance() >= 0.0 && sp.getDeviance() <= 0.0)//equals 0
            return false;//no need to split
        
        int[] usedFeatures = null;//index of the features to be used for tree splitting
        if(samplingRate < 1)//need to do sub sampling (feature sampling)
        {
            int size = (int)(samplingRate * features.length);
            usedFeatures = new int[size];
            //put all features into a pool
            List<Integer> fpool = new ArrayList<Integer>();
            for(int i=0;i<features.length;i++)
                fpool.add(i);
            //do sampling, without replacement
            Random r = new Random();
            for(int i=0;i<size;i++)
            {
                int sel = r.nextInt(fpool.size());
                usedFeatures[i] = fpool.get(sel);
                fpool.remove(sel);
            }
        }
        else//no sub-sampling, all features will be used
        {
            usedFeatures = new int[features.length];
            for(int i=0;i<features.length;i++)
                usedFeatures[i] = i;
        }
        
        //find the best split
        Config best = new Config();
        //decide whether to use multiple threads
        MyThreadPool p = MyThreadPool.getInstance();
        if(p.size() == 1)
            best = findBestSplit(usedFeatures, minLeafSupport, 0, usedFeatures.length-1);
        else
        {
            WorkerThread[] workers = p.execute(new Worker(this, usedFeatures, minLeafSupport), usedFeatures.length);
            for(int i=0;i<workers.length;i++)
            {
                Worker wk = (Worker)workers[i];
                if(best.S < wk.cfg.S)
                    best = wk.cfg;
            }        
        }
        
        if(best.S == -1)//unsplitable, for some reason...
            return false;
        
        //if(minS >= sp.getDeviance())
            //return null;
        
        double[] sumLabel = sum[best.featureIdx];
        int[] sampleCount = count[best.featureIdx];
        
        double s = sumLabel[sumLabel.length-1];
        int c = sampleCount[sumLabel.length-1];
        
        double sumLeft = sumLabel[best.thresholdIdx];
        int countLeft = sampleCount[best.thresholdIdx];
        
        double sumRight = s - sumLeft;
        int countRight = c - countLeft;
        
        int[] left = new int[countLeft];
        int[] right = new int[countRight];
        int l = 0;
        int r = 0;
        int k = 0;
        int[] idx = sp.getSamples();
        //assign each training sample on this node to the left or right child according to the best split
        for(int j=0;j<idx.length;j++)
        {
            k = idx[j];
            if(sampleToThresholdMap[best.featureIdx][k] <= best.thresholdIdx)//go to the left
                left[l++] = k;
            else//go to the right
                right[r++] = k;
        }
        
        //initialize the feature histograms of the two children after the split
        FeatureHistogram lh = new FeatureHistogram();
        lh.construct(sp.hist, left, labels); //left child's feature histogram
        FeatureHistogram rh = new FeatureHistogram();
        rh.construct(sp.hist, lh, !sp.isRoot()); //right child's feature histogram
        double var = sqSumResponse - sumResponse * sumResponse / idx.length; //squared error of this node
        double varLeft = lh.sqSumResponse - lh.sumResponse * lh.sumResponse / left.length; //squared error of the left child
        double varRight = rh.sqSumResponse - rh.sumResponse * rh.sumResponse / right.length; //squared error of the right child
        
        sp.set(features[best.featureIdx], thresholds[best.featureIdx][best.thresholdIdx], var);
        sp.setLeft(new Split(left, lh, varLeft, sumLeft));
        sp.setRight(new Split(right, rh, varRight, sumRight));
        
        sp.clearSamples(); //release this node's sortedSampleIDs, samples, hist, etc.
        
        return true;
    }    
    class Worker extends WorkerThread {
        FeatureHistogram fh = null;
        int type = -1;
        
        //find best split (type == 0)
        int[] usedFeatures = null;
        int minLeafSup = -1;
        Config cfg = null;
        
        //update (type = 1)
        double[] labels = null;
        
        //construct (type = 2)
        FeatureHistogram parent = null;
        int[] soi = null;
        
        //construct (type = 3)
        FeatureHistogram leftSibling = null;
        
        //construct (type = 4)
        DataPoint[] samples;
        int[][] sampleSortedIdx;
        float[][] thresholds;
        
        public Worker()
        {
        }
        public Worker(FeatureHistogram fh, int[] usedFeatures, int minLeafSup)
        {
            type = 0;
            this.fh = fh;
            this.usedFeatures = usedFeatures;
            this.minLeafSup = minLeafSup;
        }
        public Worker(FeatureHistogram fh, double[] labels)
        {
            type = 1;
            this.fh = fh;
            this.labels = labels;
        }
        public Worker(FeatureHistogram fh, FeatureHistogram parent, int[] soi, double[] labels)
        {
            type = 2;
            this.fh = fh;
            this.parent = parent;
            this.soi = soi;
            this.labels = labels;
        }
        public Worker(FeatureHistogram fh, FeatureHistogram parent, FeatureHistogram leftSibling)
        {
            type = 3;
            this.fh = fh;
            this.parent = parent;
            this.leftSibling = leftSibling;
        }
        public Worker(FeatureHistogram fh, DataPoint[] samples, double[] labels, int[][] sampleSortedIdx, float[][] thresholds)
        {
            type = 4;
            this.fh = fh;
            this.samples = samples;
            this.labels = labels;
            this.sampleSortedIdx = sampleSortedIdx;
            this.thresholds = thresholds;            
        }
        public void run()
        {
            if(type == 0)
                cfg = fh.findBestSplit(usedFeatures, minLeafSup, start, end);
            else if(type == 1)
                fh.update(labels, start, end);
            else if(type == 2)
                fh.construct(parent, soi, labels, start, end);
            else if(type == 3)
                fh.construct(parent, leftSibling, start, end);
            else if(type == 4)
                fh.construct(samples, labels, sampleSortedIdx, thresholds, start, end);
        }        
        public WorkerThread clone()
        {
            Worker wk = new Worker();
            wk.fh = fh;
            wk.type = type;
            
            //find best split (type == 0)
            wk.usedFeatures = usedFeatures;
            wk.minLeafSup = minLeafSup;
            //wk.cfg = cfg;
            
            //update (type = 1)
            wk.labels = labels;
            
            //construct (type = 2)
            wk.parent = parent;
            wk.soi = soi;
            
            //construct (type = 3)
            wk.leftSibling = leftSibling;
            
            //construct (type = 4)
            wk.samples = samples;
            wk.sampleSortedIdx = sampleSortedIdx;
            wk.thresholds = thresholds;            
            
            return wk;
        }
    }
}
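
A tiny worked example of the arrays above (hypothetical numbers): suppose a node holds four samples whose values on feature i are (0.2, 0.5, 0.5, 0.9), with labels (1, 0, 2, 1). Then thresholds[i] = [0.2, 0.5, 0.9, Float.MAX_VALUE], count[i] = [1, 3, 4, 4], sum[i] = [1.0, 3.0, 4.0, 4.0], and sampleToThresholdMap[i] = [0, 1, 1, 2]. Because count and sum are cumulative, findBestSplit never has to rescan the samples: for a split at thresholds[i][t] it reads countLeft = count[i][t] and sumLeft = sum[i][t], and obtains the right-hand side by subtracting from the node totals.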

 

2. Split

package ciir.umass.edu.learning.tree;
import java.util.ArrayList;
import java.util.List;
import ciir.umass.edu.learning.DataPoint;
/**
 * 
 * @author vdang
 *
 */
//Tree node class, used for:
// 1) split decisions during training (via the FeatureHistogram class);
// 2) storing the node's split rule (featureID, threshold) and its output (avgLabel, deviance, etc.)
public class Split {
    //Key attributes of a split (tree node)
    //the node's split rule (featureID, threshold) and its output (avgLabel, deviance, etc.)
    private int featureID = -1;
    private float threshold = 0F;
    private double avgLabel = 0.0F;
    
    //Intermediate variables (ONLY used during learning)
    //*DO NOT* attempt to access them once the training is done
    private boolean isRoot = false;
    private double sumLabel = 0.0;
    private double sqSumLabel = 0.0;
    private Split left = null;
    private Split right = null;
    private double deviance = 0F;//mean squared error "S"
    private int[][] sortedSampleIDs = null;
    public int[] samples = null;//during training, indexes of the training samples on this node
    public FeatureHistogram hist = null;//during training, the feature histogram of this node's training samples
    
    public Split()
    {
        
    }
    public Split(int featureID, float threshold, double deviance)
    {
        this.featureID = featureID;
        this.threshold = threshold;
        this.deviance = deviance;
    }
    public Split(int[][] sortedSampleIDs, double deviance, double sumLabel, double sqSumLabel)
    {
        this.sortedSampleIDs = sortedSampleIDs;
        this.deviance = deviance;
        this.sumLabel = sumLabel;
        this.sqSumLabel = sqSumLabel;
        avgLabel = sumLabel/sortedSampleIDs[0].length;
    }
    public Split(int[] samples, FeatureHistogram hist, double deviance, double sumLabel)
    {
        this.samples = samples;
        this.hist = hist;
        this.deviance = deviance;
        this.sumLabel = sumLabel;
        avgLabel = sumLabel/samples.length;
    }
    
    //typically called right after this node has been split, to record the split's featureID, threshold and deviance.
    //only internal nodes are split (and call this method), so only internal nodes have featureID != -1; leaves never call it, hence featureID = -1
    public void set(int featureID, float threshold, double deviance)
    {
        this.featureID = featureID;
        this.threshold = threshold;
        this.deviance = deviance;
    }
    public void setLeft(Split s)
    {
        left = s;
    }
    public void setRight(Split s)
    {
        right = s;
    }
    public void setOutput(float output)
    {
        avgLabel = output;
    }
    
    public Split getLeft()
    {
        return left;
    }
    public Split getRight()
    {
        return right;
    }
    public double getDeviance()
    {
        return deviance;
    }
    public double getOutput()
    {
        return avgLabel;
    }
    
    //returns the list of all leaf nodes under this node (usually the root)
    //recursive: a leaf (featureID = -1) is added to the list; otherwise leaves(list) recurses into both children
    public List<Split> leaves()
    {
        List<Split> list = new ArrayList<Split>();
        leaves(list);
        return list;        
    }
    private void leaves(List<Split> leaves)
    {
        if(featureID == -1)
            leaves.add(this);
        else
        {
            left.leaves(leaves);
            right.leaves(leaves);
        }
    }
    
    //returns the output (avgLabel) of the leaf a DataPoint finally falls into when routed down from this node (usually the root), following the split rule at each level
    public double eval(DataPoint dp)
    {
        Split n = this;
        while(n.featureID != -1)
        {
            if(dp.getFeatureValue(n.featureID) <= n.threshold)
                n = n.left;
            else
                n = n.right;
        }
        return n.avgLabel;
    }
    
    public String toString()
    {
        return toString("");
    }
    public String toString(String indent)
    {
        String strOutput = indent + "<split>" + "\n";
        strOutput += getString(indent + "\t");
        strOutput += indent + "</split>" + "\n";
        return strOutput;
    }
    public String getString(String indent)
    {
        String strOutput = "";
        if(featureID == -1)
        {
            strOutput += indent + "<output> " + avgLabel + " </output>" + "\n";
        }
        else
        {
            strOutput += indent + "<feature> " + featureID + " </feature>" + "\n";
            strOutput += indent + "<threshold> " + threshold + " </threshold>" + "\n";
            strOutput += indent + "<split pos=\"left\">" + "\n";
            strOutput += left.getString(indent + "\t");
            strOutput += indent + "</split>" + "\n";
            strOutput += indent + "<split pos=\"right\">" + "\n";
            strOutput += right.getString(indent + "\t");
            strOutput += indent + "</split>" + "\n";
        }
        return strOutput;
    }
    //Internal functions(ONLY used during learning)
    //*DO NOT* attempt to call them once the training is done
    //*important*: during training, splits this node by calling findBestSplit on the node's feature histogram
    public boolean split(double[] trainingLabels, int minLeafSupport)
    {
        return hist.findBestSplit(this, trainingLabels, minLeafSupport);
    }
    public int[] getSamples()
    {
        if(sortedSampleIDs != null)
            return sortedSampleIDs[0];
        return samples;
    }
    public int[][] getSampleSortedIndex()
    {
        return sortedSampleIDs;
    }
    public double getSumLabel()
    {
        return sumLabel;
    }
    public double getSqSumLabel()
    {
        return sqSumLabel;
    }
    public void clearSamples()
    {
        sortedSampleIDs = null;
        samples = null;
        hist = null;
    }
    public void setRoot(boolean isRoot)
    {
        this.isRoot = isRoot;
    }
    public boolean isRoot()
    {
        return isRoot;
    }
}

 

3. RegressionTree

package ciir.umass.edu.learning.tree;
import java.util.ArrayList;
import java.util.List;
import ciir.umass.edu.learning.DataPoint;
/**
 * @author vdang
 */
//Regression tree class
public class RegressionTree {
    
    //Parameters
    protected int nodes = 10;//-1 for unlimited number of nodes (the size of the tree will then be controlled *ONLY* by minLeafSupport)
    protected int minLeafSupport = 1; //controls how far splitting goes: a node holding fewer than 2*minLeafSupport training samples is not split further
    
    //Member variables and functions 
    protected Split root = null; //root node
    protected List<Split> leaves = null; //list of leaf nodes
    
    protected DataPoint[] trainingSamples = null;
    protected double[] trainingLabels = null;
    protected int[] features = null;
    protected float[][] thresholds = null; //2-D array; the first dimension is the feature, indexed by position in @features (not by feature id); the second dimension holds the candidate thresholds: the distinct values of this feature over all training samples, sorted ascending
    protected int[] index = null;
    protected FeatureHistogram hist = null;
    
    public RegressionTree(Split root)
    {
        this.root = root;
        leaves = root.leaves();
    }
    public RegressionTree(int nLeaves, DataPoint[] trainingSamples, double[] labels, FeatureHistogram hist, int minLeafSupport)
    {
        this.nodes = nLeaves;
        this.trainingSamples = trainingSamples;
        this.trainingLabels = labels;
        this.hist = hist;
        this.minLeafSupport = minLeafSupport;
        index = new int[trainingSamples.length];
        for(int i=0;i<trainingSamples.length;i++)
            index[i] = i;
    }
    
    /**
     * Fit the tree from the specified training data
     */
    public void fit()
    {
        List<Split> queue = new ArrayList<Split>(); //queue of nodes to split (kept sorted by deviance, largest first; see insert())
        root = new Split(index, hist, Float.MAX_VALUE, 0); //root node of the regression tree
        root.setRoot(true);
        root.split(trainingLabels, minLeafSupport); //split the root once, producing two children
        insert(queue, root.getLeft()); //push the left child into the queue for the loop below
        insert(queue, root.getRight()); //push the right child into the queue for the loop below
        //loop: split nodes in queue order, pushing the two children of each successful split back into the queue
        int taken = 0;
        while( (nodes == -1 || taken + queue.size() < nodes) && queue.size() > 0)
        {
            Split leaf = queue.get(0);
            queue.remove(0);
            
            if(leaf.getSamples().length < 2 * minLeafSupport)
            {
                taken++;
                continue;
            }
            
            if(!leaf.split(trainingLabels, minLeafSupport))//unsplitable (i.e. variance(s)==0; or after-split variance is higher than before); each dequeued node is split once, producing two children
                taken++;
            else
            {
                insert(queue, leaf.getLeft()); //push the left child into the queue
                insert(queue, leaf.getRight()); //push the right child into the queue
            }            
        }
        leaves = root.leaves();
    }
    
    /**
     * Get the tree output for the input sample
     * @param dp
     * @return
     */
    public double eval(DataPoint dp)
    {
        return root.eval(dp);
    }
    /**
     * Retrieve all leaf nodes in the tree
     * @return
     */
    public List<Split> leaves()
    {
        return leaves;
    }
    /**
     * Clear samples associated with each leaf (when they are no longer necessary) in order to save memory
     */
    public void clearSamples()
    {
        trainingSamples = null;
        trainingLabels = null;
        features = null;
        thresholds = null;
        index = null;
        hist = null;
        for(int i=0;i<leaves.size();i++)
            leaves.get(i).clearSamples();
    }
    
    /**
     * Generate the string representation of the tree
     */
    public String toString()
    {
        if(root != null)
            return root.toString();
        return "";
    }
    public String toString(String indent)
    {
        if(root != null)
            return root.toString(indent);
        return "";
    }
    
    public double variance()
    {
        double var = 0;
        for(int i=0;i<leaves.size();i++)
            var += leaves.get(i).getDeviance();
        return var;
    }
    protected void insert(List<Split> ls, Split s)
    {
        int i=0;
        while(i < ls.size())
        {
            if(ls.get(i).getDeviance() > s.getDeviance()) //keep the queue sorted by squared error, largest first
                i++;
            else
                break;
        }
        ls.add(i, s);
    }
}
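
To make the pieces above concrete, here is a minimal, hypothetical sketch of fitting a single tree outside of LambdaMART, under these assumptions: samples and labels are already in memory (e.g. obtained by flattening the RankLists read by RankLib's FeatureManager), features holds the feature ids to use, and RankLib's thread pool needs initializing the way Evaluator normally does it. It rebuilds the sortedIdx/thresholds tables the same way LambdaMART.init() does when every distinct value is kept as a candidate threshold:

import java.util.ArrayList;
import java.util.List;
import ciir.umass.edu.learning.DataPoint;
import ciir.umass.edu.learning.tree.FeatureHistogram;
import ciir.umass.edu.learning.tree.RegressionTree;
import ciir.umass.edu.utilities.MergeSorter;
import ciir.umass.edu.utilities.MyThreadPool;

public class FitOneTree {
    //Fit one regression tree to (samples, labels); all three arguments are assumed to be given.
    public static RegressionTree fit(DataPoint[] samples, double[] labels, int[] features) {
        MyThreadPool.init(1); //assumption: the pool must exist, since FeatureHistogram.construct() consults it
        //sort sample indexes by each feature (what LambdaMART.init() stores in sortedIdx)
        int[][] sortedIdx = new int[features.length][];
        for (int f = 0; f < features.length; f++) {
            double[] values = new double[samples.length];
            for (int i = 0; i < samples.length; i++)
                values[i] = samples[i].getFeatureValue(features[f]);
            sortedIdx[f] = MergeSorter.sort(values, true); //ascending
        }
        //candidate thresholds: the distinct values per feature, plus Float.MAX_VALUE as a catch-all
        float[][] thresholds = new float[features.length][];
        for (int f = 0; f < features.length; f++) {
            List<Float> distinct = new ArrayList<Float>();
            for (int i = 0; i < samples.length; i++) {
                float v = samples[sortedIdx[f][i]].getFeatureValue(features[f]);
                if (distinct.isEmpty() || v > distinct.get(distinct.size() - 1))
                    distinct.add(v); //values arrive sorted, so this keeps one copy of each
            }
            thresholds[f] = new float[distinct.size() + 1];
            for (int i = 0; i < distinct.size(); i++)
                thresholds[f][i] = distinct.get(i);
            thresholds[f][distinct.size()] = Float.MAX_VALUE;
        }
        //histogram for the root, then a tree with 10 leaves and minLeafSupport = 1
        FeatureHistogram hist = new FeatureHistogram();
        hist.construct(samples, labels, sortedIdx, features, thresholds);
        RegressionTree rt = new RegressionTree(10, samples, labels, hist, 1);
        rt.fit();
        return rt; //rt.eval(dp) now returns the leaf output for any DataPoint
    }
}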

 

2. LambdaMART

The LambdaMART training process is summarized below:

LambdaMART
    init
        initialize the training data: martSamples, modelScores, pseudoResponses, weights
        sort the samples by each feature so the best split point can be found quickly when building trees: sortedIdx
        initialize the 2-D array thresholds (first dimension: feature, indexed by position in features, not feature id; second dimension: the distinct values of this feature over all training samples, sorted ascending, used as the candidate split values for this feature)
        hist.construct # from the training data and the thresholds array, build a FeatureHistogram over the full data set, used for splitting at the root
            initializes:
                sum # 2-D array; first dimension: feature, indexed by position in features, not feature id; second dimension: label sums, sum[i][j] being the sum of the labels of all training samples whose value on this feature is <= thresholds[i][j]; same shape as thresholds
                count # 2-D array; first dimension as above; count[i][j] is the number of training samples whose value on this feature is <= thresholds[i][j]; same shape as thresholds
                sampleToThresholdMap # 2-D array; first dimension as above; second dimension: for each training sample, the column index in the matching row of thresholds that corresponds to the sample's value on this feature
                sumResponse # sum of the labels of all training samples
                sqSumResponse # sum of the squared labels of all training samples
    learn
        create an Ensemble object
        start the gradient-boosting loop, building the regression trees one after another:
            computePseudoResponses # compute the pseudo responses (the gradients, i.e. the lambdas) each instance must fit in this round
                computed from LambdaMART's gradient formula (spelled out after this outline)
            hist.update # update the feature histogram with this round's pseudo responses (lambdas); only the per-instance labels changed, everything else (e.g. the features) did not
            initialize a regression tree (from the training data and the feature histogram)
            rt.fit # fit the regression tree to the training data with this round's pseudo responses (lambdas)
            add the tree fitted in this round to the ensemble
            updateTreeOutput # update the outputs of the leaves of the tree fitted in this round
            compute each training instance's prediction under the model after this round (the new tree is now part of the ensemble): modelScores
            computeModelScoreOnTraining # compute the new model's overall ranking metric (e.g. NDCG) on the training data
            compute each validation instance's prediction under the model after this round: modelScoresOnValidation
            computeModelScoreOnValidation # compute the new model's overall ranking metric (e.g. NDCG) on the validation data
            track the best validation metric over all models so far, bestScoreOnValidationData, and the index of the best model, bestModelOnValidation
            if the validation metric has not improved for a number of consecutive rounds, stop iterating
        roll back to the model that was best on the validation data
        compute the best model's ranking metric on the training and validation data
 

Below is the code of the LambdaMART training process; detailed comments have been added to the key parts.

 

1. LambdaMART

package ciir.umass.edu.learning.tree;
import ciir.umass.edu.learning.DataPoint;
import ciir.umass.edu.learning.RankList;
import ciir.umass.edu.learning.Ranker;
import ciir.umass.edu.metric.MetricScorer;
import ciir.umass.edu.utilities.MergeSorter;
import ciir.umass.edu.utilities.MyThreadPool;
import ciir.umass.edu.utilities.RankLibError;
import ciir.umass.edu.utilities.SimpleMath;
import java.io.BufferedReader;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
 * @author vdang
 *
 *  This class implements LambdaMART.
 *  Q. Wu, C.J.C. Burges, K. Svore and J. Gao. Adapting Boosting for Information Retrieval Measures. 
 *  Journal of Information Retrieval, 2007.
 */
public class LambdaMART extends Ranker {
    //Parameters
    public static int nTrees = 1000;//the number of trees
    public static float learningRate = 0.1F;//or shrinkage
    public static int nThreshold = 256;
    public static int nRoundToStopEarly = 100;//If no performance gain on the *VALIDATION* data is observed in #rounds, stop the training process right away. 
    public static int nTreeLeaves = 10;
    public static int minLeafSupport = 1;
    
    //for debugging
    public static int gcCycle = 100;
    
    //Local variables
    protected float[][] thresholds = null;
    protected Ensemble ensemble = null;
    protected double[] modelScores = null;//on training data
    
    protected double[][] modelScoresOnValidation = null;
    protected int bestModelOnValidation = Integer.MAX_VALUE-2;
    
    //Training instances prepared for MART
    protected DataPoint[] martSamples = null;//Need initializing only once
    protected int[][] sortedIdx = null;//sorted list of samples in @martSamples by each feature -- Need initializing only once 
    protected FeatureHistogram hist = null;
    protected double[] pseudoResponses = null;//different for each iteration
    protected double[] weights = null;//different for each iteration
    
    public LambdaMART()
    {        
    }
    public LambdaMART(List<RankList> samples, int[] features, MetricScorer scorer)
    {
        super(samples, features, scorer);
    }
    
    public void init()
    {
        PRINT("Initializing... ");        
        //initialize samples for MART
        int dpCount = 0;
        for(int i=0;i<samples.size();i++)
        {
            RankList rl = samples.get(i);
            dpCount += rl.size();
        }
        int current = 0;
        martSamples = new DataPoint[dpCount];
        modelScores = new double[dpCount];
        pseudoResponses = new double[dpCount];
        weights = new double[dpCount];
        for(int i=0;i<samples.size();i++)
        {
            RankList rl = samples.get(i);
            for(int j=0;j<rl.size();j++)
            {
                martSamples[current+j] = rl.get(j);
                modelScores[current+j] = 0.0F;
                pseudoResponses[current+j] = 0.0F;
                weights[current+j] = 0;
            }
            current += rl.size();
        }            
        
        //sort (MART) samples by each feature so that we can quickly retrieve a sorted list of samples by any feature later on.
        sortedIdx = new int[features.length][];
        MyThreadPool p = MyThreadPool.getInstance();
        if(p.size() == 1)//single-thread
            sortSamplesByFeature(0, features.length-1);
        else//multi-thread
        {
            int[] partition = p.partition(features.length);
            for(int i=0;i<partition.length-1;i++)
                p.execute(new SortWorker(this, partition[i], partition[i+1]-1));
            p.await();
        }
        
        //Create a table of candidate thresholds (for each feature). Later on, we will select the best tree split from these candidates
        thresholds = new float[features.length][];
        for(int f=0;f<features.length;f++)
        {
            //For this feature, keep track of the list of unique values and the max/min 
            List<Float> values = new ArrayList<Float>();
            float fmax = Float.NEGATIVE_INFINITY;
            float fmin = Float.MAX_VALUE;
            for(int i=0;i<martSamples.length;i++)
            {
                int k = sortedIdx[f][i];//get samples sorted with respect to this feature
                float fv = martSamples[k].getFeatureValue(features[f]);
                values.add(fv);
                if(fmax < fv)
                    fmax = fv;
                if(fmin > fv)
                    fmin = fv;
                //skip all samples with the same feature value
                int j=i+1;
                while(j < martSamples.length)
                {
                    if(martSamples[sortedIdx[f][j]].getFeatureValue(features[f]) > fv)
                        break;
                    j++;
                }
                i = j-1;//[i, j] gives the range of samples with the same feature value
            }
            
            if(values.size() <= nThreshold || nThreshold == -1)
            {
                thresholds[f] = new float[values.size()+1];
                for(int i=0;i<values.size();i++)
                    thresholds[f][i] = values.get(i);
                thresholds[f][values.size()] = Float.MAX_VALUE;
            }
            else
            {
                float step = (Math.abs(fmax - fmin))/nThreshold;
                thresholds[f] = new float[nThreshold+1];
                thresholds[f][0] = fmin;
                for(int j=1;j<nThreshold;j++)
                    thresholds[f][j] = thresholds[f][j-1] + step;
                thresholds[f][nThreshold] = Float.MAX_VALUE;
            }
        }
        
        if(validationSamples != null)
        {
            modelScoresOnValidation = new double[validationSamples.size()][];
            for(int i=0;i<validationSamples.size();i++)
            {
                modelScoresOnValidation[i] = new double[validationSamples.get(i).size()];
                Arrays.fill(modelScoresOnValidation[i], 0);
            }
        }
        
        //compute the feature histogram (this is used to speed up the procedure of finding the best tree split later on)
        hist = new FeatureHistogram();
        hist.construct(martSamples, pseudoResponses, sortedIdx, features, thresholds);
        //we no longer need the sorted indexes of samples
        sortedIdx = null;
        
        System.gc();
        PRINTLN("[Done]");
    }
    public void learn()
    {
        ensemble = new Ensemble();
        
        PRINTLN("---------------------------------");
        PRINTLN("Training starts...");
        PRINTLN("---------------------------------");
        PRINTLN(new int[]{7, 9, 9}, new String[]{"#iter", scorer.name()+"-T", scorer.name()+"-V"});
        PRINTLN("---------------------------------");        
        
        //Start the gradient boosting process
        for(int m=0; m<nTrees; m++)
        {
            PRINT(new int[]{7}, new String[]{(m+1)+""});
            
            //Compute lambdas (which act as the "pseudo responses")
            //Create training instances for MART:
            //  - Each document is a training sample
            //    - The lambda for this document serves as its training label
            computePseudoResponses();
            
            //update the histogram with these training labels (the feature histogram will be used to find the best tree split)
            hist.update(pseudoResponses);
        
            //Fit a regression tree            
            RegressionTree rt = new RegressionTree(nTreeLeaves, martSamples, pseudoResponses, hist, minLeafSupport);
            rt.fit();
            
            //Add this tree to the ensemble (our model)
            ensemble.add(rt, learningRate);
            //update the outputs of the tree (with gamma computed using the Newton-Raphson method) 
            updateTreeOutput(rt);
            
            //Update the model's outputs on all training samples
            List<Split> leaves = rt.leaves();
            for(int i=0;i<leaves.size();i++)
            {
                Split s = leaves.get(i);
                int[] idx = s.getSamples();
                for(int j=0;j<idx.length;j++)
                    modelScores[idx[j]] += learningRate * s.getOutput();
            }
            //clear references to data that is no longer used
            rt.clearSamples();
            
            //beg the garbage collector to work...
            if(m % gcCycle == 0)
                System.gc();//this call is expensive. We shouldn't do it too often.
            //Evaluate the current model
            scoreOnTrainingData = computeModelScoreOnTraining();
            //**** NOTE ****
            //The above function to evaluate the current model on the training data is equivalent to a single call:
            //
            //        scoreOnTrainingData = scorer.score(rank(samples);
            //
            //However, this function is more efficient since it uses the cached outputs of the model (as opposed to re-evaluating the model 
            //on the entire training set).
            
            PRINT(new int[]{9}, new String[]{SimpleMath.round(scoreOnTrainingData, 4) + ""});            
            
            //Evaluate the current model on the validation data (if available)
            if(validationSamples != null)
            {
                //Update the model's scores on all validation samples
                for(int i=0;i<modelScoresOnValidation.length;i++)
                    for(int j=0;j<modelScoresOnValidation[i].length;j++)
                        modelScoresOnValidation[i][j] += learningRate * rt.eval(validationSamples.get(i).get(j));
                
                //again, equivalent to scoreOnValidation=scorer.score(rank(validationSamples)), but more efficient since we use the cached models' outputs
                double score = computeModelScoreOnValidation();
                
                PRINT(new int[]{9}, new String[]{SimpleMath.round(score, 4) + ""});
                if(score > bestScoreOnValidationData)
                {
                    bestScoreOnValidationData = score;
                    bestModelOnValidation = ensemble.treeCount()-1;
                }
            }
            
            PRINTLN("");
            
            //Should we stop early?
            if(m - bestModelOnValidation > nRoundToStopEarly)
                break;
        }
        
        //Rollback to the best model observed on the validation data
        while(ensemble.treeCount() > bestModelOnValidation+1)
            ensemble.remove(ensemble.treeCount()-1);
        
        //Finishing up
        scoreOnTrainingData = scorer.score(rank(samples));
        PRINTLN("---------------------------------");
        PRINTLN("Finished successfully.");
        PRINTLN(scorer.name() + " on training data: " + SimpleMath.round(scoreOnTrainingData, 4));
        if(validationSamples != null)
        {
            bestScoreOnValidationData = scorer.score(rank(validationSamples));
            PRINTLN(scorer.name() + " on validation data: " + SimpleMath.round(bestScoreOnValidationData, 4));
        }
        PRINTLN("---------------------------------");
    }
    public double eval(DataPoint dp)
    {
        return ensemble.eval(dp);
    }    
    public Ranker createNew()
    {
        return new LambdaMART();
    }
    public String toString()
    {
        return ensemble.toString();
    }
    public String model()
    {
        String output = "## " + name() + "\n";
        output += "## No. of trees = " + nTrees + "\n";
        output += "## No. of leaves = " + nTreeLeaves + "\n";
        output += "## No. of threshold candidates = " + nThreshold + "\n";
        output += "## Learning rate = " + learningRate + "\n";
        output += "## Stop early = " + nRoundToStopEarly + "\n";
        output += "\n";
        output += toString();
        return output;
    }
    @Override
    public void loadFromString(String fullText)
    {
        try {
            String content = "";
            //String model = "";
            StringBuffer model = new StringBuffer();
            BufferedReader in = new BufferedReader(new StringReader(fullText));
            while((content = in.readLine()) != null)
            {
                content = content.trim();
                if(content.length() == 0)
                    continue;
                if(content.indexOf("##")==0)
                    continue;
                //actual model component
                //model += content;
                model.append(content);
            }
            in.close();
            //load the ensemble
            ensemble = new Ensemble(model.toString());
            features = ensemble.getFeatures();
        }
        catch(Exception ex)
        {
            throw RankLibError.create("Error in LambdaMART::load(): ", ex);
        }
    }
    public void printParameters()
    {
        PRINTLN("No. of trees: " + nTrees);
        PRINTLN("No. of leaves: " + nTreeLeaves);
        PRINTLN("No. of threshold candidates: " + nThreshold);
        PRINTLN("Min leaf support: " + minLeafSupport);
        PRINTLN("Learning rate: " + learningRate);
        PRINTLN("Stop early: " + nRoundToStopEarly + " rounds without performance gain on validation data");        
    }    
    public String name()
    {
        return "LambdaMART";
    }
    public Ensemble getEnsemble()
    {
        return ensemble;
    }
    
    protected void computePseudoResponses()
    {
        Arrays.fill(pseudoResponses, 0F);
        Arrays.fill(weights, 0);
        MyThreadPool p = MyThreadPool.getInstance();
        if(p.size() == 1)//single-thread
            computePseudoResponses(0, samples.size()-1, 0);
        else //multi-threading
        {
            List<LambdaComputationWorker> workers = new ArrayList<LambdaMART.LambdaComputationWorker>();
            //divide the entire dataset into chunks of equal size for each worker thread
            int[] partition = p.partition(samples.size());
            int current = 0;
            for(int i=0;i<partition.length-1;i++)
            {
                //execute the worker
                LambdaComputationWorker wk = new LambdaComputationWorker(this, partition[i], partition[i+1]-1, current); 
                workers.add(wk);//keep it so we can get back results from it later on
                p.execute(wk);
                
                if(i < partition.length-2)
                    for(int j=partition[i]; j<=partition[i+1]-1;j++)
                        current += samples.get(j).size();
            }
            
            //wait for all workers to complete before we move on to the next stage
            p.await();
        }
    }
    protected void computePseudoResponses(int start, int end, int current)
    {
        int cutoff = scorer.getK();
        //compute the lambda for each document (a.k.a "pseudo response")
        for(int i=start;i<=end;i++)
        {
            RankList orig = samples.get(i);            
            int[] idx = MergeSorter.sort(modelScores, current, current+orig.size()-1, false);
            RankList rl = new RankList(orig, idx, current);
            double[][] changes = scorer.swapChange(rl);
            //NOTE: j, k are indices in the sorted (by modelScore) list, not the original
            // ==> need to map back with idx[j] and idx[k] 
            for(int j=0;j<rl.size();j++)
            {
                DataPoint p1 = rl.get(j);
                int mj = idx[j];
                for(int k=0;k<rl.size();k++)
                {
                    if(j > cutoff && k > cutoff)//swapping this pair won't result in any change in the target measure since both documents are below the cut-off point
                        break;
                    DataPoint p2 = rl.get(k);
                    int mk = idx[k];
                    if(p1.getLabel() > p2.getLabel())
                    {
                        double deltaNDCG = Math.abs(changes[j][k]);
                        if(deltaNDCG > 0)
                        {
                            double rho = 1.0 / (1 + Math.exp(modelScores[mj] - modelScores[mk]));
                            double lambda = rho * deltaNDCG;
                            pseudoResponses[mj] += lambda;
                            pseudoResponses[mk] -= lambda;
                            double delta = rho * (1.0 - rho) * deltaNDCG;
                            weights[mj] += delta;
                            weights[mk] += delta;
                        }
                    }
                }
            }
            current += orig.size();
        }
    }
    protected void updateTreeOutput(RegressionTree rt)
    {
        List<Split> leaves = rt.leaves();
        for(int i=0;i<leaves.size();i++)
        {
            float s1 = 0F;
            float s2 = 0F;
            Split s = leaves.get(i);
            int[] idx = s.getSamples();
            for(int j=0;j<idx.length;j++)
            {
                int k = idx[j];
                s1 += pseudoResponses[k];
                s2 += weights[k];
            }
            if(s2 == 0)
                s.setOutput(0);
            else
                s.setOutput(s1/s2);
        }
    }
    protected int[] sortSamplesByFeature(DataPoint[] samples, int fid)
    {
        double[] score = new double[samples.length];
        for(int i=0;i<samples.length;i++)
            score[i] = samples[i].getFeatureValue(fid);
        int[] idx = MergeSorter.sort(score, true); 
        return idx;
    }
    /**
     * This function is equivalent to the inherited function rank(...), but it uses the cached model's outputs instead of computing them from scratch.
     * @param rankListIndex
     * @param current
     * @return
     */
    protected RankList rank(int rankListIndex, int current)
    {
        RankList orig = samples.get(rankListIndex);    
        double[] scores = new double[orig.size()];
        for(int i=0;i<scores.length;i++)
            scores[i] = modelScores[current+i];
        int[] idx = MergeSorter.sort(scores, false);
        return new RankList(orig, idx);
    }
    protected float computeModelScoreOnTraining() 
    {
        /*float s = 0;
        int current = 0;    
        MyThreadPool p = MyThreadPool.getInstance();
        if(p.size() == 1)//single-thread
            s = computeModelScoreOnTraining(0, samples.size()-1, current);
        else
        {
            List<Worker> workers = new ArrayList<Worker>();
            //divide the entire dataset into chunks of equal size for each worker thread
            int[] partition = p.partition(samples.size());
            for(int i=0;i<partition.length-1;i++)
            {
                //execute the worker
                Worker wk = new Worker(this, partition[i], partition[i+1]-1, current);
                workers.add(wk);//keep it so we can get back results from it later on
                p.execute(wk);
                
                if(i < partition.length-2)
                    for(int j=partition[i]; j<=partition[i+1]-1;j++)
                        current += samples.get(j).size();
            }        
            //wait for all workers to complete before we move on to the next stage
            p.await();
            for(int i=0;i<workers.size();i++)
                s += workers.get(i).score;
        }*/
        float s = computeModelScoreOnTraining(0, samples.size()-1, 0);
        s = s / samples.size();
        return s;
    }
    protected float computeModelScoreOnTraining(int start, int end, int current) 
    {
        float s = 0;
        int c = current;
        for(int i=start;i<=end;i++)
        {
            s += scorer.score(rank(i, c));
            c += samples.get(i).size();
        }
        return s;
    }
    protected float computeModelScoreOnValidation() 
    {
        /*float score = 0;
        MyThreadPool p = MyThreadPool.getInstance();
        if(p.size() == 1)//single-thread
            score = computeModelScoreOnValidation(0, validationSamples.size()-1);
        else
        {
            List<Worker> workers = new ArrayList<Worker>();
            //divide the entire dataset into chunks of equal size for each worker thread
            int[] partition = p.partition(validationSamples.size());
            for(int i=0;i<partition.length-1;i++)
            {
                //execute the worker
                Worker wk = new Worker(this, partition[i], partition[i+1]-1);
                workers.add(wk);//keep it so we can get back results from it later on
                p.execute(wk);
            }        
            //wait for all workers to complete before we move on to the next stage
            p.await();
            for(int i=0;i<workers.size();i++)
                score += workers.get(i).score;
        }*/
        float score = computeModelScoreOnValidation(0, validationSamples.size()-1);
        return score/validationSamples.size();
    }
    protected float computeModelScoreOnValidation(int start, int end) 
    {
        float score = 0;
        for(int i=start;i<=end;i++)
        {
            int[] idx = MergeSorter.sort(modelScoresOnValidation[i], false);
            score += scorer.score(new RankList(validationSamples.get(i), idx));
        }
        return score;
    }
    
    protected void sortSamplesByFeature(int fStart, int fEnd)
    {
        for(int i=fStart;i<=fEnd; i++)
            sortedIdx[i] = sortSamplesByFeature(martSamples, features[i]);
    }
    //For multi-threading processing
    class SortWorker implements Runnable {
        LambdaMART ranker = null;
        int start = -1;
        int end = -1;
        SortWorker(LambdaMART ranker, int start, int end)
        {
            this.ranker = ranker;
            this.start = start;
            this.end = end;
        }        
        public void run()
        {
            ranker.sortSamplesByFeature(start, end);
        }
    }
    class LambdaComputationWorker implements Runnable {
        LambdaMART ranker = null;
        int rlStart = -1;
        int rlEnd = -1;
        int martStart = -1;
        LambdaComputationWorker(LambdaMART ranker, int rlStart, int rlEnd, int martStart)
        {
            this.ranker = ranker;
            this.rlStart = rlStart;
            this.rlEnd = rlEnd;
            this.martStart = martStart;
        }        
        public void run()
        {
            ranker.computePseudoResponses(rlStart, rlEnd, martStart);
        }
    }
    class Worker implements Runnable {
        LambdaMART ranker = null;
        int rlStart = -1;
        int rlEnd = -1;
        int martStart = -1;
        int type = -1;
        
        //compute score on validation
        float score = 0;
        
        Worker(LambdaMART ranker, int rlStart, int rlEnd)
        {
            type = 3;
            this.ranker = ranker;
            this.rlStart = rlStart;
            this.rlEnd = rlEnd;
        }
        Worker(LambdaMART ranker, int rlStart, int rlEnd, int martStart)
        {
            type = 4;
            this.ranker = ranker;
            this.rlStart = rlStart;
            this.rlEnd = rlEnd;
            this.martStart = martStart;
        }
        public void run()
        {
            if(type == 4)
                score = ranker.computeModelScoreOnTraining(rlStart, rlEnd, martStart);
            else if(type == 3)
                score = ranker.computeModelScoreOnValidation(rlStart, rlEnd);
        }
    }
}
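
As a usage note: this implementation is normally driven from RankLib's command line, where -ranker 6 selects LambdaMART and the tree flags map onto the parameters above (nTrees, nTreeLeaves, learningRate). A typical invocation, with placeholder file names and a jar name that depends on the build:

java -jar RankLib.jar -train train.txt -validate vali.txt -test test.txt -ranker 6 -metric2t NDCG@10 -tree 1000 -leaf 10 -shrinkage 0.1 -save lambdamart.txt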

 

Copyright notice:

   This post belongs to bentuwuying (笨兔勿應) and is published at http://www.cnblogs.com/bentuwuying. If you repost it, please credit the source. Using this post for commercial purposes without the author's consent will be pursued under the law.
