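RankLib is a high-quality open-source library of Learning to Rank algorithms. It implements MART, RankNet, RankBoost, LambdaMART, Random Forests, and other models. LambdaMART, proposed by Microsoft researchers, is one of the Learning to Rank models most widely used in the IR industry. This article walks through the concrete implementation of LambdaMART in RankLib, to help make the method described in the paper easier to understand. It is based on RankLib version 2.3.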
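The basic principles of LambdaMART are covered in my earlier post: http://www.cnblogs.com/bentuwuying/p/6690836.html. The key point is that LambdaMART is built on MART, and MART is an ensemble of regression trees. So we first look at how RankLib implements a regression tree, and how a regression tree is fitted to training data with labels.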
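To make "MART is an ensemble of regression trees" concrete, here is a minimal sketch of how such an ensemble scores a feature vector: every tree's output is scaled by the learning rate (shrinkage) and the results are summed. This mirrors what the LambdaMART code later in this article does with `ensemble.add(rt, learningRate)` and `modelScores[...] += learningRate * s.getOutput()`. The names `ScoringTree`, `Stump` and `MartEnsemble` are hypothetical and are not RankLib's API, and each "tree" is reduced to a single split to keep the example short.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical minimal tree interface (not RankLib's API).
interface ScoringTree {
    double eval(double[] features);
}

// A depth-1 regression tree: one split on one feature, two leaf outputs.
final class Stump implements ScoringTree {
    private final int featureIdx;
    private final double threshold;
    private final double leftOutput;
    private final double rightOutput;

    Stump(int featureIdx, double threshold, double leftOutput, double rightOutput) {
        this.featureIdx = featureIdx;
        this.threshold = threshold;
        this.leftOutput = leftOutput;
        this.rightOutput = rightOutput;
    }

    @Override
    public double eval(double[] features) {
        // Same convention as RankLib's Split.eval: value <= threshold goes left.
        return features[featureIdx] <= threshold ? leftOutput : rightOutput;
    }
}

// Additive model: F(x) = sum over trees of learningRate * tree(x).
final class MartEnsemble {
    private final List<ScoringTree> trees = new ArrayList<>();
    private final List<Double> rates = new ArrayList<>();

    void add(ScoringTree tree, double learningRate) {
        trees.add(tree);
        rates.add(learningRate);
    }

    double eval(double[] features) {
        double score = 0.0;
        for (int i = 0; i < trees.size(); i++) {
            score += rates.get(i) * trees.get(i).eval(features);
        }
        return score;
    }

    int size() {
        return trees.size();
    }
}
```

For example, with a learning rate of 0.1, a first stump that outputs 1.0 and a second that outputs 0.5 for the same input give a total score of 0.15. Small learning rates make each new tree correct only part of the remaining error, which is why MART usually needs many trees.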
The steps for fitting a regression tree to the given training data are summarized below:
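RegressionTree
    nodes  # maximum number of leaves in one tree (-1 = unlimited)
    minLeafSupport  # controls splitting: a node holding fewer than 2*minLeafSupport training samples is not split further
    root  # root node
    leaves  # list of leaf nodes
    constructor RegressionTree(int nLeaves, DataPoint[] trainingSamples, double[] labels, FeatureHistogram hist, int minLeafSupport)
        initializes the member variables
    fit  # fit a regression tree to the training data
        create a queue of nodes waiting to be split
        initialize the root node root of the regression tree
        root.split  # split the root node
            hist.findBestSplit  # calls the split method of the FeatureHistogram object held by the Split object (starting from the node's precomputed feature histogram: find the best split point, split, compute the feature histograms of the left and right children, and initialize both children)
                check the deviance; if it is 0, the split fails
                choose usedFeatures (indices of the features considered for this split) according to samplingRate
                call the internal findBestSplit method
                    at this node, over usedFeatures, choose the split feature and threshold from the node's feature histogram
                    S = sumLeft * sumLeft / countLeft + sumRight * sumRight / countRight
                    over every candidate split (feature/threshold pair), take the largest S; it corresponds to the smallest squared error and is the best split
                check whether the split succeeded; if S = -1, it failed
                assign every training sample at this node to the left or right child according to the best split
                build the feature histograms of the left and right children
                    construct  # typically builds the left child's histogram after the parent splits (when built from the parent, the thresholds array is reused, but the sum and count arrays are rebuilt)
                    construct  # typically builds the right child's histogram (parent minus left sibling)
                compute the squared error of this node and of the left and right children
                sp.set  # method of the Split object that owns this FeatureHistogram
                    called after the node has been split; records the featureID, threshold, and deviance of the split
                    only internal nodes are split (and call this method), so only internal nodes have featureID != -1; leaves never call it and keep featureID = -1
                initialize the left child (from the indices of the training samples sent left, the left child's feature histogram, its squared error, and the sum of its labels) and store it in the current node's left field
                initialize the right child the same way and store it in the current node's right field
        insert  # insert the left and right children into the queue for the loop below
            the queue is kept sorted by squared error (deviance), largest first
        loop: take the node at the front of the queue and split it, then insert the two children of each successful split back into the queue; because of the ordering above, the node with the largest squared error is split first (best-first rather than strictly level order)
        set the tree's leaves field with the root node's leaves method (a recursive traversal)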
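Why does the largest S give the smallest squared error? For a set of n labels with sum Y and sum of squares Q, the sum of squared errors around the mean is Q - Y*Y/n, which is exactly how `var`, `varLeft` and `varRight` are computed in `findBestSplit`. After a split into a left part (sum Y_L, count n_L) and a right part (Y_R, n_R), the remaining error is SSE_L + SSE_R = (Q_L - Y_L*Y_L/n_L) + (Q_R - Y_R*Y_R/n_R) = Q - S, since Q_L + Q_R = Q. Q does not depend on where we cut, so maximizing S minimizes the error left after the split. The sketch below is a hypothetical helper `SplitScore`, not RankLib code; it applies the same criterion to one feature whose values are already sorted, with a running prefix sum playing the role of the `sum` and `count` arrays.

```java
// Hypothetical single-feature illustration of the S split criterion (not RankLib code).
final class SplitScore {
    // sortedValues must be ascending; labels[i] belongs to sortedValues[i].
    // Returns {bestIndex, bestS}: the left child holds samples 0..bestIndex.
    // bestIndex is -1 (and bestS is -1) when no valid split exists.
    static double[] bestSplit(double[] sortedValues, double[] labels, int minLeafSupport) {
        int n = sortedValues.length;
        double total = 0.0;
        for (double y : labels) {
            total += y;
        }
        double sumLeft = 0.0;
        double bestS = -1.0;
        int bestIdx = -1;
        for (int i = 0; i < n - 1; i++) {
            sumLeft += labels[i];
            // Only cut between distinct values, like RankLib's de-duplicated thresholds.
            if (sortedValues[i] == sortedValues[i + 1]) {
                continue;
            }
            int countLeft = i + 1;
            int countRight = n - countLeft;
            if (countLeft < minLeafSupport || countRight < minLeafSupport) {
                continue;
            }
            double sumRight = total - sumLeft;
            double s = sumLeft * sumLeft / countLeft + sumRight * sumRight / countRight;
            if (bestS < s) {
                bestS = s;
                bestIdx = i;
            }
        }
        return new double[] {bestIdx, bestS};
    }

    // Sum of squared errors of labels[from..to) around their mean: Q - Y*Y/n.
    static double sse(double[] labels, int from, int to) {
        double sum = 0.0;
        double sq = 0.0;
        for (int i = from; i < to; i++) {
            sum += labels[i];
            sq += labels[i] * labels[i];
        }
        return sq - sum * sum / (to - from);
    }
}
```

With values {1, 2, 3, 4} and labels {0, 0, 1, 1}, the three candidate cuts give S = 4/3, 2 and 4/3, so the cut after the second value wins; Q = 2, so the error left after that split is Q - S = 0, which matches computing SSE_L + SSE_R directly.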
Below is the code of the classes involved in fitting a regression tree; the key parts are annotated with detailed comments.
1. FeatureHistogram
package ciir.umass.edu.learning.tree;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;

import ciir.umass.edu.learning.DataPoint;
import ciir.umass.edu.utilities.MyThreadPool;
import ciir.umass.edu.utilities.WorkerThread;

/**
 * @author vdang
 */
// Feature histogram class: computes per-feature histogram statistics over the training samples
// and selects the best feature and split point at each split
public class FeatureHistogram {
    // Stores the featureIdx and thresholdIdx of a split, plus the score used to judge the best split:
    // S = sumLeft*sumLeft/countLeft + sumRight*sumRight/countRight
    class Config {
        int featureIdx = -1;
        int thresholdIdx = -1;
        double S = -1;
    }

    //Parameter
    public static float samplingRate = 1; // sampling rate: consider only a random subset of the features at each split instead of all of them

    //Variables
    public int[] features = null; // feature array; each element is a feature id (fid)
    public float[][] thresholds = null; // 2-D array. 1st dim: feature, indexed by position in features (not by feature id). 2nd dim: the distinct values of all training samples on this feature, sorted ascending; these are the candidate split thresholds on this feature
    public double[][] sum = null; // same shape as thresholds. sum[i][j] = sum of the labels of all training samples whose value on feature i is <= thresholds[i][j]
    public double sumResponse = 0; // sum of the labels of all training samples
    public double sqSumResponse = 0; // sum of the squared labels of all training samples
    public int[][] count = null; // same shape as thresholds. count[i][j] = number of training samples whose value on feature i is <= thresholds[i][j]
    public int[][] sampleToThresholdMap = null; // 1st dim: feature (position in features). sampleToThresholdMap[i][k] = column index in thresholds[i] of the bin that sample k falls into on feature i

    //whether to re-use its parents @sum and @count instead of cleaning up the parent and re-allocate for the children.
    //@sum and @count of any intermediate tree node (except for root) can be re-used.
    private boolean reuseParent = false;

    public FeatureHistogram()
    {
    }

    // Constructor (1-1): usually used for the whole tree / the root node; computes that node's feature histogram
    // @samples: training data
    // @labels: labels of the training data
    // @sampleSortedIdx: sample indices sorted by each feature, so the best split point can be found quickly when growing trees; need initializing only once (see init() in LambdaMART.java)
    // @features: feature set of the training data
    // @thresholds: table of candidate thresholds (split points) for each feature; we will select the best tree split from these candidates later on.
    //              Initialized in init() of LambdaMART.java; the last column of every row is appended there and equals Float.MAX_VALUE
    public void construct(DataPoint[] samples, double[] labels, int[][] sampleSortedIdx, int[] features, float[][] thresholds)
    {
        this.features = features;
        this.thresholds = thresholds;

        sumResponse = 0;
        sqSumResponse = 0;

        sum = new double[features.length][];
        count = new int[features.length][];
        sampleToThresholdMap = new int[features.length][];

        // decide whether to compute with multiple threads
        MyThreadPool p = MyThreadPool.getInstance();
        if(p.size() == 1)
            construct(samples, labels, sampleSortedIdx, thresholds, 0, features.length-1);
        else
            p.execute(new Worker(this, samples, labels, sampleSortedIdx, thresholds), features.length);
    }

    // Constructor (1-2), called by (1-1)
    protected void construct(DataPoint[] samples, double[] labels, int[][] sampleSortedIdx, float[][] thresholds, int start, int end)
    {
        for(int i=start;i<=end;i++) // for each feature
        {
            int fid = features[i]; // feature id
            //get the list of samples associated with this node (sorted in ascending order with respect to the current feature)
            int[] idx = sampleSortedIdx[i]; // indices of the training samples, sorted ascending by their value on this feature

            double sumLeft = 0; // running sum of labels, used to fill sumLabel
            float[] threshold = thresholds[i];
            double[] sumLabel = new double[threshold.length]; // one row of sum
            int[] c = new int[threshold.length]; // one row of count
            int[] stMap = new int[samples.length]; // one row of sampleToThresholdMap

            int last = -1;
            for(int t=0;t<threshold.length;t++) // for each candidate threshold
            {
                int j=last+1;
                //find the first sample that exceeds the current threshold
                for(;j<idx.length;j++)
                {
                    int k = idx[j]; // index of this DataPoint in samples
                    if(samples[k].getFeatureValue(fid) > threshold[t])
                        break;
                    sumLeft += labels[k];
                    if(i == 0)
                    {
                        sumResponse += labels[k];
                        sqSumResponse += labels[k] * labels[k];
                    }
                    stMap[k] = t;
                }
                last = j-1;
                sumLabel[t] = sumLeft;
                c[t] = last+1;
            }
            sampleToThresholdMap[i] = stMap;
            sum[i] = sumLabel;
            count[i] = c;
        }
    }

    //update(1-1), update the histogram with these training labels (the feature histogram will be used to find the best tree split)
    protected void update(double[] labels)
    {
        sumResponse = 0;
        sqSumResponse = 0;

        // decide whether to compute with multiple threads
        MyThreadPool p = MyThreadPool.getInstance();
        if(p.size() == 1)
            update(labels, 0, features.length-1);
        else
            p.execute(new Worker(this, labels), features.length);
    }

    //update(1-2), called by (1-1)
    protected void update(double[] labels, int start, int end)
    {
        for(int f=start;f<=end;f++)
            Arrays.fill(sum[f], 0);
        for(int k=0;k<labels.length;k++)
        {
            for(int f=start;f<=end;f++)
            {
                int t = sampleToThresholdMap[f][k];
                sum[f][t] += labels[k];
                if(f == 0)
                {
                    sumResponse += labels[k];
                    sqSumResponse += labels[k]*labels[k];
                }
                //count doesn't change, so no need to re-compute
            }
        }
        for(int f=start;f<=end;f++)
        {
            for(int t=1;t<thresholds[f].length;t++)
                sum[f][t] += sum[f][t-1];
        }
    }

    // Constructor (2-1): usually builds the feature histogram of the left child produced when a parent splits.
    // When built from the parent, the thresholds array is reused, but the sum and count arrays must be rebuilt.
    //@soi: indices of the training samples that belong to this node
    public void construct(FeatureHistogram parent, int[] soi, double[] labels)
    {
        this.features = parent.features;
        this.thresholds = parent.thresholds;
        sumResponse = 0;
        sqSumResponse = 0;
        sum = new double[features.length][];
        count = new int[features.length][];
        sampleToThresholdMap = parent.sampleToThresholdMap;

        // decide whether to compute with multiple threads
        MyThreadPool p = MyThreadPool.getInstance();
        if(p.size() == 1)
            construct(parent, soi, labels, 0, features.length-1);
        else
            p.execute(new Worker(this, parent, soi, labels), features.length);
    }

    // Constructor (2-2), called by (2-1)
    protected void construct(FeatureHistogram parent, int[] soi, double[] labels, int start, int end)
    {
        //init
        for(int i=start;i<=end;i++)
        {
            float[] threshold = thresholds[i];
            sum[i] = new double[threshold.length];
            count[i] = new int[threshold.length];
            Arrays.fill(sum[i], 0);
            Arrays.fill(count[i], 0);
        }

        //update
        for(int i=0;i<soi.length;i++)
        {
            int k = soi[i];
            for(int f=start;f<=end;f++)
            {
                int t = sampleToThresholdMap[f][k];
                sum[f][t] += labels[k];
                count[f][t] ++;
                if(f == 0)
                {
                    sumResponse += labels[k];
                    sqSumResponse += labels[k]*labels[k];
                }
            }
        }

        for(int f=start;f<=end;f++)
        {
            for(int t=1;t<thresholds[f].length;t++)
            {
                sum[f][t] += sum[f][t-1];
                count[f][t] += count[f][t-1];
            }
        }
    }

    // Constructor (3-1): usually builds the feature histogram of the right child produced when a parent splits
    // (computed as parent minus left sibling, so the right child's samples never have to be scanned)
    public void construct(FeatureHistogram parent, FeatureHistogram leftSibling, boolean reuseParent)
    {
        this.reuseParent = reuseParent;
        this.features = parent.features;
        this.thresholds = parent.thresholds;
        sumResponse = parent.sumResponse - leftSibling.sumResponse;
        sqSumResponse = parent.sqSumResponse - leftSibling.sqSumResponse;

        if(reuseParent)
        {
            sum = parent.sum;
            count = parent.count;
        }
        else
        {
            sum = new double[features.length][];
            count = new int[features.length][];
        }
        sampleToThresholdMap = parent.sampleToThresholdMap;

        // decide whether to compute with multiple threads
        MyThreadPool p = MyThreadPool.getInstance();
        if(p.size() == 1)
            construct(parent, leftSibling, 0, features.length-1);
        else
            p.execute(new Worker(this, parent, leftSibling), features.length);
    }

    // Constructor (3-2), called by (3-1)
    protected void construct(FeatureHistogram parent, FeatureHistogram leftSibling, int start, int end)
    {
        for(int f=start;f<=end;f++)
        {
            float[] threshold = thresholds[f];
            if(!reuseParent)
            {
                sum[f] = new double[threshold.length];
                count[f] = new int[threshold.length];
            }
            for(int t=0;t<threshold.length;t++)
            {
                sum[f][t] = parent.sum[f][t] - leftSibling.sum[f][t];
                count[f][t] = parent.count[f][t] - leftSibling.count[f][t];
            }
        }
    }

    // findBestSplit (1-2), called by (1-1). At one node, over usedFeatures, choose the split feature and threshold
    // from the node's feature histogram
    protected Config findBestSplit(int[] usedFeatures, int minLeafSupport, int start, int end)
    {
        Config cfg = new Config();
        int totalCount = count[start][count[start].length-1];
        for(int f=start;f<=end;f++)
        {
            int i = usedFeatures[f];
            float[] threshold = thresholds[i];

            for(int t=0;t<threshold.length;t++)
            {
                int countLeft = count[i][t];
                int countRight = totalCount - countLeft;
                if(countLeft < minLeafSupport || countRight < minLeafSupport)
                    continue;

                double sumLeft = sum[i][t];
                double sumRight = sumResponse - sumLeft;

                double S = sumLeft * sumLeft / countLeft + sumRight * sumRight / countRight;
                // keep the largest S: it corresponds to the smallest squared error, i.e. the best split
                if(cfg.S < S)
                {
                    cfg.S = S;
                    cfg.featureIdx = i;
                    cfg.thresholdIdx = t;
                }
            }
        }
        return cfg;
    }

    // findBestSplit (1-1): starting from this node's precomputed feature histogram, find the best split point, split,
    // then compute the feature histograms of the left and right children and initialize both children
    public boolean findBestSplit(Split sp, double[] labels, int minLeafSupport)
    {
        if(sp.getDeviance() >= 0.0 && sp.getDeviance() <= 0.0)//equals 0
            return false;//no need to split

        int[] usedFeatures = null;//index of the features to be used for tree splitting
        if(samplingRate < 1)//need to do sub sampling (feature sampling)
        {
            int size = (int)(samplingRate * features.length);
            usedFeatures = new int[size];
            //put all features into a pool
            List<Integer> fpool = new ArrayList<Integer>();
            for(int i=0;i<features.length;i++)
                fpool.add(i);
            //do sampling, without replacement
            Random r = new Random();
            for(int i=0;i<size;i++)
            {
                int sel = r.nextInt(fpool.size());
                usedFeatures[i] = fpool.get(sel);
                fpool.remove(sel);
            }
        }
        else//no sub-sampling, all features will be used
        {
            usedFeatures = new int[features.length];
            for(int i=0;i<features.length;i++)
                usedFeatures[i] = i;
        }

        //find the best split
        Config best = new Config();
        // decide whether to compute with multiple threads
        MyThreadPool p = MyThreadPool.getInstance();
        if(p.size() == 1)
            best = findBestSplit(usedFeatures, minLeafSupport, 0, usedFeatures.length-1);
        else
        {
            WorkerThread[] workers = p.execute(new Worker(this, usedFeatures, minLeafSupport), usedFeatures.length);
            for(int i=0;i<workers.length;i++)
            {
                Worker wk = (Worker)workers[i];
                if(best.S < wk.cfg.S)
                    best = wk.cfg;
            }
        }

        if(best.S == -1)//unsplitable, for some reason...
            return false;

        //if(minS >= sp.getDeviance())
            //return null;

        double[] sumLabel = sum[best.featureIdx];
        int[] sampleCount = count[best.featureIdx];

        double s = sumLabel[sumLabel.length-1];
        int c = sampleCount[sumLabel.length-1];

        double sumLeft = sumLabel[best.thresholdIdx];
        int countLeft = sampleCount[best.thresholdIdx];

        double sumRight = s - sumLeft;
        int countRight = c - countLeft;

        int[] left = new int[countLeft];
        int[] right = new int[countRight];
        int l = 0;
        int r = 0;
        int k = 0;
        int[] idx = sp.getSamples();
        // send every training sample at this node to the left or right child according to the best split
        for(int j=0;j<idx.length;j++)
        {
            k = idx[j];
            if(sampleToThresholdMap[best.featureIdx][k] <= best.thresholdIdx)//go to the left
                left[l++] = k;
            else//go to the right
                right[r++] = k;
        }

        // build the feature histograms of the two children
        FeatureHistogram lh = new FeatureHistogram();
        lh.construct(sp.hist, left, labels); // left child's histogram
        FeatureHistogram rh = new FeatureHistogram();
        rh.construct(sp.hist, lh, !sp.isRoot()); // right child's histogram (parent minus left sibling)
        double var = sqSumResponse - sumResponse * sumResponse / idx.length; // squared error (deviance) of this node
        double varLeft = lh.sqSumResponse - lh.sumResponse * lh.sumResponse / left.length; // squared error of the left child
        double varRight = rh.sqSumResponse - rh.sumResponse * rh.sumResponse / right.length; // squared error of the right child

        sp.set(features[best.featureIdx], thresholds[best.featureIdx][best.thresholdIdx], var);
        sp.setLeft(new Split(left, lh, varLeft, sumLeft));
        sp.setRight(new Split(right, rh, varRight, sumRight));

        sp.clearSamples(); // release this node's sortedSampleIDs, samples, hist, etc.

        return true;
    }

    class Worker extends WorkerThread {
        FeatureHistogram fh = null;
        int type = -1;

        //find best split (type == 0)
        int[] usedFeatures = null;
        int minLeafSup = -1;
        Config cfg = null;

        //update (type = 1)
        double[] labels = null;

        //construct (type = 2)
        FeatureHistogram parent = null;
        int[] soi = null;

        //construct (type = 3)
        FeatureHistogram leftSibling = null;

        //construct (type = 4)
        DataPoint[] samples;
        int[][] sampleSortedIdx;
        float[][] thresholds;

        public Worker()
        {
        }
        public Worker(FeatureHistogram fh, int[] usedFeatures, int minLeafSup)
        {
            type = 0;
            this.fh = fh;
            this.usedFeatures = usedFeatures;
            this.minLeafSup = minLeafSup;
        }
        public Worker(FeatureHistogram fh, double[] labels)
        {
            type = 1;
            this.fh = fh;
            this.labels = labels;
        }
        public Worker(FeatureHistogram fh, FeatureHistogram parent, int[] soi, double[] labels)
        {
            type = 2;
            this.fh = fh;
            this.parent = parent;
            this.soi = soi;
            this.labels = labels;
        }
        public Worker(FeatureHistogram fh, FeatureHistogram parent, FeatureHistogram leftSibling)
        {
            type = 3;
            this.fh = fh;
            this.parent = parent;
            this.leftSibling = leftSibling;
        }
        public Worker(FeatureHistogram fh, DataPoint[] samples, double[] labels, int[][] sampleSortedIdx, float[][] thresholds)
        {
            type = 4;
            this.fh = fh;
            this.samples = samples;
            this.labels = labels;
            this.sampleSortedIdx = sampleSortedIdx;
            this.thresholds = thresholds;
        }
        public void run()
        {
            if(type == 0)
                cfg = fh.findBestSplit(usedFeatures, minLeafSup, start, end);
            else if(type == 1)
                fh.update(labels, start, end);
            else if(type == 2)
                fh.construct(parent, soi, labels, start, end);
            else if(type == 3)
                fh.construct(parent, leftSibling, start, end);
            else if(type == 4)
                fh.construct(samples, labels, sampleSortedIdx, thresholds, start, end);
        }
        public WorkerThread clone()
        {
            Worker wk = new Worker();
            wk.fh = fh;
            wk.type = type;

            //find best split (type == 0)
            wk.usedFeatures = usedFeatures;
            wk.minLeafSup = minLeafSup;
            //wk.cfg = cfg;

            //update (type = 1)
            wk.labels = labels;

            //construct (type = 2)
            wk.parent = parent;
            wk.soi = soi;

            //construct (type = 3)
            wk.leftSibling = leftSibling;

            //construct (type = 4)
            wk.samples = samples;
            wk.sampleSortedIdx = sampleSortedIdx;
            wk.thresholds = thresholds;

            return wk;
        }
    }
}
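The two tricks that make `FeatureHistogram` cheap are the cumulative arrays (`sum[f][t]` and `count[f][t]` cover every sample whose bin is <= t, so the left-side totals of any threshold are a single lookup) and the right-child shortcut in constructor (3-2), which derives the right child's arrays as parent minus left sibling without scanning the right child's samples. The single-feature sketch below uses a hypothetical class `CumulativeHistogram`, with a plain `binOfSample` array standing in for one row of `sampleToThresholdMap`; it is an illustration of the idea, not RankLib code, and it checks that the subtraction agrees with building the right child directly.

```java
import java.util.Arrays;

// Hypothetical single-feature version of FeatureHistogram's sum/count rows (not RankLib code).
final class CumulativeHistogram {
    final double[] sum; // sum[t]   = label sum of the samples whose bin is <= t
    final int[] count;  // count[t] = number of samples whose bin is <= t

    private CumulativeHistogram(double[] sum, int[] count) {
        this.sum = sum;
        this.count = count;
    }

    // Like construct (2-2): per-bin totals for a subset of samples, then a prefix sum.
    static CumulativeHistogram build(int[] binOfSample, double[] labels, int[] subset, int nBins) {
        double[] sum = new double[nBins];
        int[] count = new int[nBins];
        for (int k : subset) {
            int t = binOfSample[k];
            sum[t] += labels[k];
            count[t]++;
        }
        for (int t = 1; t < nBins; t++) {
            sum[t] += sum[t - 1];
            count[t] += count[t - 1];
        }
        return new CumulativeHistogram(sum, count);
    }

    // Like construct (3-2): right child = parent - left sibling, bin by bin.
    static CumulativeHistogram subtract(CumulativeHistogram parent, CumulativeHistogram left) {
        int nBins = parent.sum.length;
        double[] sum = new double[nBins];
        int[] count = new int[nBins];
        for (int t = 0; t < nBins; t++) {
            sum[t] = parent.sum[t] - left.sum[t];
            count[t] = parent.count[t] - left.count[t];
        }
        return new CumulativeHistogram(sum, count);
    }

    @Override
    public String toString() {
        return "sum=" + Arrays.toString(sum) + " count=" + Arrays.toString(count);
    }
}
```

The subtraction costs O(number of bins) per feature, independent of how many samples reach the right child, which is why RankLib builds only the left child from samples. The `!sp.isRoot()` argument additionally lets non-root nodes overwrite their parent's arrays in place, since a parent's histogram is no longer needed once both children exist.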
2. Split
package ciir.umass.edu.learning.tree;

import java.util.ArrayList;
import java.util.List;

import ciir.umass.edu.learning.DataPoint;

/**
 *
 * @author vdang
 *
 */
// Tree node class, used for:
// 1) deciding splits during training (via the FeatureHistogram class);
// 2) storing the node's split rule (featureID, threshold) and its output (avgLabel, deviance, etc.)
public class Split {
    //Key attributes of a split (tree node)
    //the split rule (featureID, threshold) and the node output (avgLabel)
    private int featureID = -1;
    private float threshold = 0F;
    private double avgLabel = 0.0F;

    //Intermediate variables (ONLY used during learning)
    //*DO NOT* attempt to access them once the training is done
    private boolean isRoot = false;
    private double sumLabel = 0.0;
    private double sqSumLabel = 0.0;
    private Split left = null;
    private Split right = null;
    private double deviance = 0F;//mean squared error "S"
    private int[][] sortedSampleIDs = null;
    public int[] samples = null;// during training: indices of the training samples at this node
    public FeatureHistogram hist = null;// during training: feature histogram of the training samples at this node

    public Split()
    {
    }
    public Split(int featureID, float threshold, double deviance)
    {
        this.featureID = featureID;
        this.threshold = threshold;
        this.deviance = deviance;
    }
    public Split(int[][] sortedSampleIDs, double deviance, double sumLabel, double sqSumLabel)
    {
        this.sortedSampleIDs = sortedSampleIDs;
        this.deviance = deviance;
        this.sumLabel = sumLabel;
        this.sqSumLabel = sqSumLabel;
        avgLabel = sumLabel/sortedSampleIDs[0].length;
    }
    public Split(int[] samples, FeatureHistogram hist, double deviance, double sumLabel)
    {
        this.samples = samples;
        this.hist = hist;
        this.deviance = deviance;
        this.sumLabel = sumLabel;
        avgLabel = sumLabel/samples.length;
    }

    // Called after this node has been split, to record the featureID, threshold, and deviance of the split.
    // Only internal nodes are split (and call this method), so only internal nodes have featureID != -1;
    // leaves never call it and keep featureID = -1.
    public void set(int featureID, float threshold, double deviance)
    {
        this.featureID = featureID;
        this.threshold = threshold;
        this.deviance = deviance;
    }
    public void setLeft(Split s)
    {
        left = s;
    }
    public void setRight(Split s)
    {
        right = s;
    }
    public void setOutput(float output)
    {
        avgLabel = output;
    }

    public Split getLeft()
    {
        return left;
    }
    public Split getRight()
    {
        return right;
    }
    public double getDeviance()
    {
        return deviance;
    }
    public double getOutput()
    {
        return avgLabel;
    }

    // Returns the list of all leaves under this node (usually the root).
    // Recursive: a leaf (featureID == -1) is added to the list; otherwise leaves(list) is called on both children.
    public List<Split> leaves()
    {
        List<Split> list = new ArrayList<Split>();
        leaves(list);
        return list;
    }
    private void leaves(List<Split> leaves)
    {
        if(featureID == -1)
            leaves.add(this);
        else
        {
            left.leaves(leaves);
            right.leaves(leaves);
        }
    }

    // Returns the output (avgLabel) of the leaf that a DataPoint reaches from this node (usually the root),
    // following the split rule at each level
    public double eval(DataPoint dp)
    {
        Split n = this;
        while(n.featureID != -1)
        {
            if(dp.getFeatureValue(n.featureID) <= n.threshold)
                n = n.left;
            else
                n = n.right;
        }
        return n.avgLabel;
    }

    public String toString()
    {
        return toString("");
    }
    public String toString(String indent)
    {
        String strOutput = indent + "<split>" + "\n";
        strOutput += getString(indent + "\t");
        strOutput += indent + "</split>" + "\n";
        return strOutput;
    }
    public String getString(String indent)
    {
        String strOutput = "";
        if(featureID == -1)
        {
            strOutput += indent + "<output> " + avgLabel + " </output>" + "\n";
        }
        else
        {
            strOutput += indent + "<feature> " + featureID + " </feature>" + "\n";
            strOutput += indent + "<threshold> " + threshold + " </threshold>" + "\n";
            strOutput += indent + "<split pos=\"left\">" + "\n";
            strOutput += left.getString(indent + "\t");
            strOutput += indent + "</split>" + "\n";
            strOutput += indent + "<split pos=\"right\">" + "\n";
            strOutput += right.getString(indent + "\t");
            strOutput += indent + "</split>" + "\n";
        }
        return strOutput;
    }
    //Internal functions(ONLY used during learning)
    //*DO NOT* attempt to call them once the training is done
    // *Important*: during training, splits this node by calling findBestSplit on the node's feature histogram
    public boolean split(double[] trainingLabels, int minLeafSupport)
    {
        return hist.findBestSplit(this, trainingLabels, minLeafSupport);
    }
    public int[] getSamples()
    {
        if(sortedSampleIDs != null)
            return sortedSampleIDs[0];
        return samples;
    }
    public int[][] getSampleSortedIndex()
    {
        return sortedSampleIDs;
    }
    public double getSumLabel()
    {
        return sumLabel;
    }
    public double getSqSumLabel()
    {
        return sqSumLabel;
    }
    public void clearSamples()
    {
        sortedSampleIDs = null;
        samples = null;
        hist = null;
    }
    public void setRoot(boolean isRoot)
    {
        this.isRoot = isRoot;
    }
    public boolean isRoot()
    {
        return isRoot;
    }
}
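Once training is done, a `Split` tree is used only through `eval` (walk down until `featureID == -1`, going left when the value is <= the threshold) and `leaves` (recursive left-to-right collection). The sketch below reproduces just that prediction-time behaviour with a hypothetical stripped-down `TreeNode` class; unlike RankLib's `DataPoint.getFeatureValue(fid)`, it assumes features are plain 0-based array positions.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical prediction-time view of a Split tree (not RankLib code).
final class TreeNode {
    int featureId = -1; // -1 marks a leaf, as in Split
    float threshold;
    double output;      // leaf output (avgLabel in Split)
    TreeNode left;
    TreeNode right;

    static TreeNode leaf(double output) {
        TreeNode n = new TreeNode();
        n.output = output;
        return n;
    }

    static TreeNode internal(int featureId, float threshold, TreeNode left, TreeNode right) {
        TreeNode n = new TreeNode();
        n.featureId = featureId;
        n.threshold = threshold;
        n.left = left;
        n.right = right;
        return n;
    }

    // Iterative descent as in Split.eval: value <= threshold goes left.
    double eval(float[] x) {
        TreeNode n = this;
        while (n.featureId != -1) {
            n = x[n.featureId] <= n.threshold ? n.left : n.right;
        }
        return n.output;
    }

    // Recursive left-to-right leaf collection as in Split.leaves.
    List<TreeNode> leaves() {
        List<TreeNode> out = new ArrayList<>();
        collect(out);
        return out;
    }

    private void collect(List<TreeNode> out) {
        if (featureId == -1) {
            out.add(this);
        } else {
            left.collect(out);
            right.collect(out);
        }
    }
}
```

Because the comparison is `<=`, a sample whose value equals a threshold goes left; this matches how training assigns samples (`sampleToThresholdMap[...][k] <= best.thresholdIdx`), so training and prediction route boundary values the same way.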
3. RegressionTree
package ciir.umass.edu.learning.tree;

import java.util.ArrayList;
import java.util.List;

import ciir.umass.edu.learning.DataPoint;

/**
 * @author vdang
 */
// Regression tree class
public class RegressionTree {

    //Parameters
    protected int nodes = 10;//-1 for unlimited number of nodes (the size of the tree will then be controlled *ONLY* by minLeafSupport)
    protected int minLeafSupport = 1; // controls splitting: a node holding fewer than 2*minLeafSupport training samples is not split further

    //Member variables and functions
    protected Split root = null; // root node
    protected List<Split> leaves = null; // list of leaf nodes

    protected DataPoint[] trainingSamples = null;
    protected double[] trainingLabels = null;
    protected int[] features = null;
    protected float[][] thresholds = null; // 1st dim: feature (position in features, not feature id); 2nd dim: the distinct values of the training data on that feature, sorted ascending, used as candidate split thresholds
    protected int[] index = null;
    protected FeatureHistogram hist = null;

    public RegressionTree(Split root)
    {
        this.root = root;
        leaves = root.leaves();
    }
    public RegressionTree(int nLeaves, DataPoint[] trainingSamples, double[] labels, FeatureHistogram hist, int minLeafSupport)
    {
        this.nodes = nLeaves;
        this.trainingSamples = trainingSamples;
        this.trainingLabels = labels;
        this.hist = hist;
        this.minLeafSupport = minLeafSupport;
        index = new int[trainingSamples.length];
        for(int i=0;i<trainingSamples.length;i++)
            index[i] = i;
    }

    /**
     * Fit the tree from the specified training data
     */
    public void fit()
    {
        List<Split> queue = new ArrayList<Split>(); // nodes waiting to be split; insert() keeps it sorted by deviance, largest first
        root = new Split(index, hist, Float.MAX_VALUE, 0); // root node of the regression tree
        root.setRoot(true);
        root.split(trainingLabels, minLeafSupport); // split the root once, producing 2 children
        insert(queue, root.getLeft()); // queue the left child for the loop below
        insert(queue, root.getRight()); // queue the right child for the loop below
        // Loop: take the node at the front of the queue and split it; the two children of every
        // successful split are inserted back into the queue
        int taken = 0;
        while( (nodes == -1 || taken + queue.size() < nodes) && queue.size() > 0)
        {
            Split leaf = queue.get(0);
            queue.remove(0);

            if(leaf.getSamples().length < 2 * minLeafSupport)
            {
                taken++;
                continue;
            }

            // try to split the node once, producing 2 children
            if(!leaf.split(trainingLabels, minLeafSupport))//unsplitable (i.e. variance(s)==0; or after-split variance is higher than before)
                taken++;
            else
            {
                insert(queue, leaf.getLeft()); // queue the left child
                insert(queue, leaf.getRight()); // queue the right child
            }
        }
        leaves = root.leaves();
    }

    /**
     * Get the tree output for the input sample
     * @param dp
     * @return
     */
    public double eval(DataPoint dp)
    {
        return root.eval(dp);
    }
    /**
     * Retrieve all leave nodes in the tree
     * @return
     */
    public List<Split> leaves()
    {
        return leaves;
    }
    /**
     * Clear samples associated with each leaves (when they are no longer necessary) in order to save memory
     */
    public void clearSamples()
    {
        trainingSamples = null;
        trainingLabels = null;
        features = null;
        thresholds = null;
        index = null;
        hist = null;
        for(int i=0;i<leaves.size();i++)
            leaves.get(i).clearSamples();
    }

    /**
     * Generate the string representation of the tree
     */
    public String toString()
    {
        if(root != null)
            return root.toString();
        return "";
    }
    public String toString(String indent)
    {
        if(root != null)
            return root.toString(indent);
        return "";
    }

    public double variance()
    {
        double var = 0;
        for(int i=0;i<leaves.size();i++)
            var += leaves.get(i).getDeviance();
        return var;
    }
    protected void insert(List<Split> ls, Split s)
    {
        int i=0;
        while(i < ls.size())
        {
            if(ls.get(i).getDeviance() > s.getDeviance()) // keep the queue sorted by deviance, largest first
                i++;
            else
                break;
        }
        ls.add(i, s);
    }
}
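Two details of `fit` are easy to miss. First, `taken + queue.size()` is always the current number of leaves (`taken` counts leaves that will never be split again), so the loop stops as soon as the tree reaches `nodes` leaves. Second, `insert` keeps the queue sorted by deviance, so growth is best-first: the node with the largest squared error is split next. The following hypothetical `BestFirstGrower` re-creates that loop on a single, pre-sorted feature using the same S criterion; it is a sketch for experimenting with the leaf budget, not RankLib code.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical 1-D re-creation of the growth loop in RegressionTree.fit (not RankLib code).
final class BestFirstGrower {
    // A node covers samples [from, to) of the sorted arrays.
    static final class Node {
        final int from;
        final int to;
        final double deviance; // sum of squared errors around the node's mean label
        Node left;
        Node right;

        Node(int from, int to, double deviance) {
            this.from = from;
            this.to = to;
            this.deviance = deviance;
        }
    }

    private final double[] x; // feature values, sorted ascending
    private final double[] y; // labels aligned with x
    private final int minLeafSupport;

    BestFirstGrower(double[] sortedX, double[] y, int minLeafSupport) {
        this.x = sortedX;
        this.y = y;
        this.minLeafSupport = minLeafSupport;
    }

    private double sse(int from, int to) {
        double sum = 0.0;
        double sq = 0.0;
        for (int i = from; i < to; i++) {
            sum += y[i];
            sq += y[i] * y[i];
        }
        return sq - sum * sum / (to - from);
    }

    // One split with the S criterion; false when the node cannot be split.
    private boolean split(Node node) {
        if (node.deviance == 0.0) return false; // pure node, like the deviance check in findBestSplit
        double total = 0.0;
        for (int i = node.from; i < node.to; i++) total += y[i];
        double sumLeft = 0.0;
        double bestS = -1.0;
        int bestCut = -1;
        for (int i = node.from; i < node.to - 1; i++) {
            sumLeft += y[i];
            if (x[i] == x[i + 1]) continue; // no cut between equal feature values
            int nl = i + 1 - node.from;
            int nr = node.to - (i + 1);
            if (nl < minLeafSupport || nr < minLeafSupport) continue;
            double sumRight = total - sumLeft;
            double s = sumLeft * sumLeft / nl + sumRight * sumRight / nr;
            if (bestS < s) {
                bestS = s;
                bestCut = i + 1;
            }
        }
        if (bestCut < 0) return false;
        node.left = new Node(node.from, bestCut, sse(node.from, bestCut));
        node.right = new Node(bestCut, node.to, sse(bestCut, node.to));
        return true;
    }

    // Same rule as RegressionTree.insert: sorted by deviance, largest first.
    private static void insert(List<Node> queue, Node s) {
        int i = 0;
        while (i < queue.size() && queue.get(i).deviance > s.deviance) i++;
        queue.add(i, s);
    }

    // Grows a tree with at most maxLeaves leaves (-1 = unlimited).
    Node fit(int maxLeaves) {
        Node root = new Node(0, x.length, sse(0, x.length));
        if (!split(root)) return root; // RankLib's fit assumes the root split succeeds
        List<Node> queue = new ArrayList<>();
        insert(queue, root.left);
        insert(queue, root.right);
        int taken = 0; // leaves that will never be split again
        while ((maxLeaves == -1 || taken + queue.size() < maxLeaves) && !queue.isEmpty()) {
            Node leaf = queue.remove(0);
            if (leaf.to - leaf.from < 2 * minLeafSupport || !split(leaf)) {
                taken++;
                continue;
            }
            insert(queue, leaf.left);
            insert(queue, leaf.right);
        }
        return root;
    }

    static int countLeaves(Node n) {
        return n.left == null ? 1 : countLeaves(n.left) + countLeaves(n.right);
    }
}
```

With x = {1, 2, 3, 4, 5, 6} and y = {0, 0, 5, 5, 9, 9}, the first cut isolates {0, 0} (deviance 0) and leaves {5, 5, 9, 9} (deviance 16) at the front of the queue, so a budget of 3 leaves spends its only extra split on that high-error node.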
The LambdaMART training process is summarized below:
LambdaMart
    init
        initialize the training data: martSamples, modelScores, pseudoResponses, weights
        sort the samples by each feature so the best split point can be found quickly when growing trees: sortedIdx
        initialize the 2-D array thresholds (1st dim: feature, indexed by position in features, not by feature id; 2nd dim: the distinct values of all training samples on this feature, sorted ascending, used as candidate split thresholds on this feature)
        hist.construct  # build a FeatureHistogram from the training data and the thresholds array; it holds the feature histogram of the whole data set and is used to split the root node
            initializes:
                sum  # same shape as thresholds; sum[i][j] = sum of the labels of all training samples whose value on feature i is <= thresholds[i][j]
                count  # same shape as thresholds; count[i][j] = number of training samples whose value on feature i is <= thresholds[i][j]
                sampleToThresholdMap  # sampleToThresholdMap[i][k] = column index in thresholds[i] of the bin that sample k falls into on feature i
                sumResponse  # sum of the labels of all training samples
                sqSumResponse  # sum of the squared labels of all training samples
    learn
        initialize an Ensemble object, ensemble
        run gradient boosting, i.e. build the regression trees one after another:
            computePseudoResponses  # compute the pseudo response (gradient, i.e. lambda) that each instance must fit in this iteration
                computed with the LambdaMART gradient formula
            hist.update  # update the feature histogram with this iteration's pseudo responses; only each instance's label has changed, everything else (e.g. the feature values, and therefore count) has not
            initialize a regression tree from the training data and the feature histogram
            rt.fit  # fit the regression tree to the training data and this iteration's pseudo responses
            add the regression tree fitted in this iteration to ensemble
            updateTreeOutput  # update the output of each leaf of this iteration's regression tree
            compute the predicted score of each training instance now that the new tree is part of the ensemble: modelScores
            computeModelScoreOnTraining  # compute the ranking metric (e.g. NDCG) of the current model on the whole training data
            compute the predicted score of each validation instance now that the new tree is part of the ensemble: modelScoresOnValidation
            computeModelScoreOnValidation  # compute the ranking metric (e.g. NDCG) of the current model on the whole validation data
            update the best validation score so far, bestScoreOnValidationData, and the index of the corresponding model, bestModelOnValidation
            if the validation score has not improved for nRoundToStopEarly consecutive iterations, stop
        roll back to the model that scored best on the validation set
        compute the ranking metric of that best model on the training data and the validation data
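The body of `computePseudoResponses` and `updateTreeOutput` is not part of the code shown in this article, so the following is only a sketch of the standard LambdaMART formulas (with sigma = 1) under stated assumptions, not RankLib's exact implementation. For every pair (i, j) in one query with label_i > label_j, rho = 1 / (1 + exp(s_i - s_j)); the lambda rho * |deltaNDCG| is added to i and subtracted from j, and the weight rho * (1 - rho) * |deltaNDCG| is added to both. A leaf's Newton-Raphson output (the "gamma" mentioned in the `updateTreeOutput` comment) is then the sum of its lambdas divided by the sum of its weights. The class `LambdaSketch` is hypothetical, and it uses full-list NDCG with gain 2^label - 1, whereas RankLib's scorer may apply a cutoff such as NDCG@10.

```java
import java.util.Arrays;
import java.util.Comparator;

// Hypothetical single-query sketch of LambdaMART lambdas and Newton leaf outputs (not RankLib code).
final class LambdaSketch {
    // Fills lambdas[] and weights[] for one query from relevance labels and current model scores.
    static void computeLambdas(double[] labels, double[] scores, double[] lambdas, double[] weights) {
        int n = labels.length;
        // Current 0-based rank of each document, sorting by score in descending order.
        Integer[] order = new Integer[n];
        for (int i = 0; i < n; i++) order[i] = i;
        Arrays.sort(order, Comparator.comparingDouble((Integer i) -> -scores[i]));
        int[] rank = new int[n];
        for (int r = 0; r < n; r++) rank[order[r]] = r;

        Arrays.fill(lambdas, 0.0);
        Arrays.fill(weights, 0.0);
        double idcg = idealDcg(labels);
        if (idcg == 0.0) return; // no relevant documents: nothing to learn from this query

        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                if (labels[i] <= labels[j]) continue; // only pairs where i should rank above j
                // |change in NDCG| if documents i and j swapped positions
                double deltaNdcg = Math.abs((gain(labels[i]) - gain(labels[j]))
                        * (discount(rank[i]) - discount(rank[j]))) / idcg;
                double rho = 1.0 / (1.0 + Math.exp(scores[i] - scores[j]));
                double lambda = rho * deltaNdcg;
                lambdas[i] += lambda; // push i up
                lambdas[j] -= lambda; // push j down
                double w = rho * (1.0 - rho) * deltaNdcg; // second-order term for the Newton step
                weights[i] += w;
                weights[j] += w;
            }
        }
    }

    static double gain(double label) {
        return Math.pow(2.0, label) - 1.0;
    }

    // 1 / log2(rank + 2) for a 0-based rank.
    static double discount(int rank0) {
        return 1.0 / (Math.log(rank0 + 2) / Math.log(2));
    }

    static double idealDcg(double[] labels) {
        double[] sorted = labels.clone();
        Arrays.sort(sorted); // ascending, so read it backwards
        double dcg = 0.0;
        for (int r = 0; r < sorted.length; r++) {
            dcg += gain(sorted[sorted.length - 1 - r]) * discount(r);
        }
        return dcg;
    }

    // Newton-Raphson leaf value: sum of lambdas over sum of weights of the samples in the leaf.
    static double leafOutput(double[] lambdas, double[] weights, int[] samplesInLeaf) {
        double num = 0.0;
        double den = 0.0;
        for (int k : samplesInLeaf) {
            num += lambdas[k];
            den += weights[k];
        }
        return den == 0.0 ? 0.0 : num / den;
    }
}
```

For a query with labels {1, 0} and both scores 0, rho = 0.5 and |deltaNDCG| = 1 - 1/log2(3), so the relevant document receives lambda 0.5 * (1 - 1/log2(3)) and the other receives the negative of that; a leaf containing only the relevant document gets output lambda / weight = 1 / (1 - rho) = 2.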
Below is the code of the LambdaMART training process; the key parts are annotated with detailed comments.
1. LambdaMART
1 package ciir.umass.edu.learning.tree; 2 import ciir.umass.edu.learning.DataPoint; 3 import ciir.umass.edu.learning.RankList; 4 import ciir.umass.edu.learning.Ranker; 5 import ciir.umass.edu.metric.MetricScorer; 6 import ciir.umass.edu.utilities.MergeSorter; 7 import ciir.umass.edu.utilities.MyThreadPool; 8 import ciir.umass.edu.utilities.RankLibError; 9 import ciir.umass.edu.utilities.SimpleMath; 10 import java.io.BufferedReader; 11 import java.io.StringReader; 12 import java.util.ArrayList; 13 import java.util.Arrays; 14 import java.util.List; 15 /** 16 * @author vdang 17 * 18 * This class implements LambdaMART. 19 * Q. Wu, C.J.C. Burges, K. Svore and J. Gao. Adapting Boosting for Information Retrieval Measures. 20 * Journal of Information Retrieval, 2007. 21 */ 22 public class LambdaMART extends Ranker { 23 //Parameters 24 public static int nTrees = 1000;//the number of trees 25 public static float learningRate = 0.1F;//or shrinkage 26 public static int nThreshold = 256; 27 public static int nRoundToStopEarly = 100;//If no performance gain on the *VALIDATION* data is observed in #rounds, stop the training process right away. 
28 public static int nTreeLeaves = 10; 29 public static int minLeafSupport = 1; 30 31 //for debugging 32 public static int gcCycle = 100; 33 34 //Local variables 35 protected float[][] thresholds = null; 36 protected Ensemble ensemble = null; 37 protected double[] modelScores = null;//on training data 38 39 protected double[][] modelScoresOnValidation = null; 40 protected int bestModelOnValidation = Integer.MAX_VALUE-2; 41 42 //Training instances prepared for MART 43 protected DataPoint[] martSamples = null;//Need initializing only once 44 protected int[][] sortedIdx = null;//sorted list of samples in @martSamples by each feature -- Need initializing only once 45 protected FeatureHistogram hist = null; 46 protected double[] pseudoResponses = null;//different for each iteration 47 protected double[] weights = null;//different for each iteration 48 49 public LambdaMART() 50 { 51 } 52 public LambdaMART(List<RankList> samples, int[] features, MetricScorer scorer) 53 { 54 super(samples, features, scorer); 55 } 56 57 public void init() 58 { 59 PRINT("Initializing... "); 60 //initialize samples for MART 61 int dpCount = 0; 62 for(int i=0;i<samples.size();i++) 63 { 64 RankList rl = samples.get(i); 65 dpCount += rl.size(); 66 } 67 int current = 0; 68 martSamples = new DataPoint[dpCount]; 69 modelScores = new double[dpCount]; 70 pseudoResponses = new double[dpCount]; 71 weights = new double[dpCount]; 72 for(int i=0;i<samples.size();i++) 73 { 74 RankList rl = samples.get(i); 75 for(int j=0;j<rl.size();j++) 76 { 77 martSamples[current+j] = rl.get(j); 78 modelScores[current+j] = 0.0F; 79 pseudoResponses[current+j] = 0.0F; 80 weights[current+j] = 0; 81 } 82 current += rl.size(); 83 } 84 85 //sort (MART) samples by each feature so that we can quickly retrieve a sorted list of samples by any feature later on. 
86 // 將樣本根據特徵排序,方便作樹的分裂時快速找出最優分裂點 87 sortedIdx = new int[features.length][]; 88 MyThreadPool p = MyThreadPool.getInstance(); 89 if(p.size() == 1)//single-thread 90 sortSamplesByFeature(0, features.length-1); 91 else//multi-thread 92 { 93 int[] partition = p.partition(features.length); 94 for(int i=0;i<partition.length-1;i++) 95 p.execute(new SortWorker(this, partition[i], partition[i+1]-1)); 96 p.await(); 97 } 98 99 //Create a table of candidate thresholds (for each feature). Later on, we will select the best tree split from these candidates // 建立存放候選閾值(分裂點)的表 100 thresholds = new float[features.length][]; 101 for(int f=0;f<features.length;f++) 102 { 103 //For this feature, keep track of the list of unique values and the max/min 104 List<Float> values = new ArrayList<Float>(); 105 float fmax = Float.NEGATIVE_INFINITY; 106 float fmin = Float.MAX_VALUE; 107 for(int i=0;i<martSamples.length;i++) 108 { 109 int k = sortedIdx[f][i];//get samples sorted with respect to this feature 110 float fv = martSamples[k].getFeatureValue(features[f]); 111 values.add(fv); 112 if(fmax < fv) 113 fmax = fv; 114 if(fmin > fv) 115 fmin = fv; 116 //skip all samples with the same feature value 117 int j=i+1; 118 while(j < martSamples.length) 119 { 120 if(martSamples[sortedIdx[f][j]].getFeatureValue(features[f]) > fv) 121 break; 122 j++; 123 } 124 i = j-1;//[i, j] gives the range of samples with the same feature value 125 } 126 127 if(values.size() <= nThreshold || nThreshold == -1) 128 { 129 thresholds[f] = new float[values.size()+1]; 130 for(int i=0;i<values.size();i++) 131 thresholds[f][i] = values.get(i); 132 thresholds[f][values.size()] = Float.MAX_VALUE; 133 } 134 else 135 { 136 float step = (Math.abs(fmax - fmin))/nThreshold; 137 thresholds[f] = new float[nThreshold+1]; 138 thresholds[f][0] = fmin; 139 for(int j=1;j<nThreshold;j++) 140 thresholds[f][j] = thresholds[f][j-1] + step; 141 thresholds[f][nThreshold] = Float.MAX_VALUE; 142 } 143 } 144 145 if(validationSamples != null) 146 { 
            modelScoresOnValidation = new double[validationSamples.size()][];
            for(int i=0;i<validationSamples.size();i++)
            {
                modelScoresOnValidation[i] = new double[validationSamples.get(i).size()];
                Arrays.fill(modelScoresOnValidation[i], 0);
            }
        }

        //compute the feature histogram (this is used to speed up the procedure of finding the best tree split later on)
        hist = new FeatureHistogram();
        hist.construct(martSamples, pseudoResponses, sortedIdx, features, thresholds);
        //we no longer need the sorted indexes of samples
        sortedIdx = null;

        System.gc();
        PRINTLN("[Done]");
    }
    public void learn()
    {
        ensemble = new Ensemble();

        PRINTLN("---------------------------------");
        PRINTLN("Training starts...");
        PRINTLN("---------------------------------");
        PRINTLN(new int[]{7, 9, 9}, new String[]{"#iter", scorer.name()+"-T", scorer.name()+"-V"});
        PRINTLN("---------------------------------");

        //Start the gradient boosting process
        for(int m=0; m<nTrees; m++)
        {
            PRINT(new int[]{7}, new String[]{(m+1)+""});

            //Compute lambdas (which act as the "pseudo responses")
            //Create training instances for MART:
            //  - Each document is a training sample
            //  - The lambda for this document serves as its training label
            computePseudoResponses();

            //update the histogram with these training labels (the feature histogram will be used to find the best tree split)
            hist.update(pseudoResponses);

            //Fit a regression tree to the lambdas
            RegressionTree rt = new RegressionTree(nTreeLeaves, martSamples, pseudoResponses, hist, minLeafSupport);
            rt.fit();

            //Add this tree to the ensemble (our model)
            ensemble.add(rt, learningRate);
            //update the outputs of the tree (with gamma computed using the Newton-Raphson method)
            updateTreeOutput(rt);

            //Update the model's outputs on all training samples
            List<Split> leaves = rt.leaves();
            for(int i=0;i<leaves.size();i++)
            {
                Split s = leaves.get(i);
                int[] idx = s.getSamples();
                for(int j=0;j<idx.length;j++)
                    modelScores[idx[j]] += learningRate * s.getOutput();
            }
            //clear references to data that is no longer used
            rt.clearSamples();

            //beg the garbage collector to work...
            if(m % gcCycle == 0)
                System.gc();//this call is expensive. We shouldn't do it too often.
            //Evaluate the current model
            scoreOnTrainingData = computeModelScoreOnTraining();
            //**** NOTE ****
            //The above function to evaluate the current model on the training data is equivalent to a single call:
            //
            //        scoreOnTrainingData = scorer.score(rank(samples));
            //
            //However, this function is more efficient since it uses the cached outputs of the model (as opposed to re-evaluating the model
            //on the entire training set).

            PRINT(new int[]{9}, new String[]{SimpleMath.round(scoreOnTrainingData, 4) + ""});

            //Evaluate the current model on the validation data (if available)
            if(validationSamples != null)
            {
                //Update the model's scores on all validation samples
                for(int i=0;i<modelScoresOnValidation.length;i++)
                    for(int j=0;j<modelScoresOnValidation[i].length;j++)
                        modelScoresOnValidation[i][j] += learningRate * rt.eval(validationSamples.get(i).get(j));

                //again, equivalent to scoreOnValidation=scorer.score(rank(validationSamples)), but more efficient since we use the cached models' outputs
                double score = computeModelScoreOnValidation();

                PRINT(new int[]{9}, new String[]{SimpleMath.round(score, 4) + ""});
                if(score > bestScoreOnValidationData)
                {
                    bestScoreOnValidationData = score;
                    bestModelOnValidation = ensemble.treeCount()-1;
                }
            }

            PRINTLN("");

            //Should we stop early?
            if(m - bestModelOnValidation > nRoundToStopEarly)
                break;
        }

        //Rollback to the best model observed on the validation data
        while(ensemble.treeCount() > bestModelOnValidation+1)
            ensemble.remove(ensemble.treeCount()-1);

        //Finishing up
        scoreOnTrainingData = scorer.score(rank(samples));
        PRINTLN("---------------------------------");
        PRINTLN("Finished successfully.");
        PRINTLN(scorer.name() + " on training data: " + SimpleMath.round(scoreOnTrainingData, 4));
        if(validationSamples != null)
        {
            bestScoreOnValidationData = scorer.score(rank(validationSamples));
            PRINTLN(scorer.name() + " on validation data: " + SimpleMath.round(bestScoreOnValidationData, 4));
        }
        PRINTLN("---------------------------------");
    }
    public double eval(DataPoint dp)
    {
        return ensemble.eval(dp);
    }
    public Ranker createNew()
    {
        return new LambdaMART();
    }
    public String toString()
    {
        return ensemble.toString();
    }
    public String model()
    {
        String output = "## " + name() + "\n";
        output += "## No. of trees = " + nTrees + "\n";
        output += "## No. of leaves = " + nTreeLeaves + "\n";
        output += "## No. of threshold candidates = " + nThreshold + "\n";
        output += "## Learning rate = " + learningRate + "\n";
        output += "## Stop early = " + nRoundToStopEarly + "\n";
        output += "\n";
        output += toString();
        return output;
    }
    @Override
    public void loadFromString(String fullText)
    {
        try {
            String content = "";
            //String model = "";
            StringBuffer model = new StringBuffer();
            BufferedReader in = new BufferedReader(new StringReader(fullText));
            while((content = in.readLine()) != null)
            {
                content = content.trim();
                if(content.length() == 0)
                    continue;
                if(content.indexOf("##")==0)
                    continue;
                //actual model component
                //model += content;
                model.append(content);
            }
            in.close();
            //load the ensemble
            ensemble = new Ensemble(model.toString());
            features = ensemble.getFeatures();
        }
        catch(Exception ex)
        {
            throw RankLibError.create("Error in LambdaMART::load(): ", ex);
        }
    }
    public void printParameters()
    {
        PRINTLN("No. of trees: " + nTrees);
        PRINTLN("No. of leaves: " + nTreeLeaves);
        PRINTLN("No. of threshold candidates: " + nThreshold);
        PRINTLN("Min leaf support: " + minLeafSupport);
        PRINTLN("Learning rate: " + learningRate);
        PRINTLN("Stop early: " + nRoundToStopEarly + " rounds without performance gain on validation data");
    }
    public String name()
    {
        return "LambdaMART";
    }
    public Ensemble getEnsemble()
    {
        return ensemble;
    }

    protected void computePseudoResponses()
    {
        Arrays.fill(pseudoResponses, 0F);
        Arrays.fill(weights, 0);
        MyThreadPool p = MyThreadPool.getInstance();
        if(p.size() == 1)//single-thread
            computePseudoResponses(0, samples.size()-1, 0);
        else //multi-threading
        {
            List<LambdaComputationWorker> workers = new ArrayList<LambdaMART.LambdaComputationWorker>();
            //divide the entire dataset into chunks of equal size for each worker thread
            int[] partition = p.partition(samples.size());
            int current = 0;
            for(int i=0;i<partition.length-1;i++)
            {
                //execute the worker
                LambdaComputationWorker wk = new LambdaComputationWorker(this, partition[i], partition[i+1]-1, current);
                workers.add(wk);//keep it so we can get back results from it later on
                p.execute(wk);

                if(i < partition.length-2)
                    for(int j=partition[i]; j<=partition[i+1]-1;j++)
                        current += samples.get(j).size();
            }

            //wait for all workers to complete before we move on to the next stage
            p.await();
        }
    }
    protected void computePseudoResponses(int start, int end, int current)
    {
        int cutoff = scorer.getK();
        //compute the lambda for each document (a.k.a "pseudo response")
        for(int i=start;i<=end;i++)
        {
            RankList orig = samples.get(i);
            int[] idx = MergeSorter.sort(modelScores, current, current+orig.size()-1, false);
            RankList rl = new RankList(orig, idx, current);
            double[][] changes = scorer.swapChange(rl);
            //NOTE: j, k are indices in the sorted (by modelScore) list, not the original
            //  ==> need to map back with idx[j] and idx[k]
            for(int j=0;j<rl.size();j++)
            {
                DataPoint p1 = rl.get(j);
                int mj = idx[j];
                for(int k=0;k<rl.size();k++)
                {
                    if(j > cutoff && k > cutoff)//swapping this pair won't result in any change in target measures since they're below the cut-off point
                        break;
                    DataPoint p2 = rl.get(k);
                    int mk = idx[k];
                    if(p1.getLabel() > p2.getLabel())
                    {
                        double deltaNDCG = Math.abs(changes[j][k]);
                        if(deltaNDCG > 0)
                        {
                            //logistic term with sigma = 1: rho = 1 / (1 + exp(s_j - s_k))
                            double rho = 1.0 / (1 + Math.exp(modelScores[mj] - modelScores[mk]));
                            double lambda = rho * deltaNDCG;
                            pseudoResponses[mj] += lambda;
                            pseudoResponses[mk] -= lambda;
                            //second-order term, used later for the Newton-Raphson leaf output
                            double delta = rho * (1.0 - rho) * deltaNDCG;
                            weights[mj] += delta;
                            weights[mk] += delta;
                        }
                    }
                }
            }
            current += orig.size();
        }
    }
    protected void updateTreeOutput(RegressionTree rt)
    {
        List<Split> leaves = rt.leaves();
        for(int i=0;i<leaves.size();i++)
        {
            //Newton-Raphson step: leaf output = (sum of lambdas) / (sum of weights) over the samples in this leaf
            float s1 = 0F;
            float s2 = 0F;
            Split s = leaves.get(i);
            int[] idx = s.getSamples();
            for(int j=0;j<idx.length;j++)
            {
                int k = idx[j];
                s1 += pseudoResponses[k];
                s2 += weights[k];
            }
            if(s2 == 0)
                s.setOutput(0);
            else
                s.setOutput(s1/s2);
        }
    }
    protected int[] sortSamplesByFeature(DataPoint[] samples, int fid)
    {
        double[] score = new double[samples.length];
        for(int i=0;i<samples.length;i++)
            score[i] = samples[i].getFeatureValue(fid);
        int[] idx = MergeSorter.sort(score, true);
        return idx;
    }
    /**
     * This function is equivalent to the inherited function rank(...), but it uses the cached model's outputs instead of computing them from scratch.
     * @param rankListIndex index of the query (rank list) in samples
     * @param current offset of this query's first document in modelScores
     * @return the rank list sorted by the cached model scores (descending)
     */
    protected RankList rank(int rankListIndex, int current)
    {
        RankList orig = samples.get(rankListIndex);
        double[] scores = new double[orig.size()];
        for(int i=0;i<scores.length;i++)
            scores[i] = modelScores[current+i];
        int[] idx = MergeSorter.sort(scores, false);
        return new RankList(orig, idx);
    }
    protected float computeModelScoreOnTraining()
    {
        /*float s = 0;
        int current = 0;
        MyThreadPool p = MyThreadPool.getInstance();
        if(p.size() == 1)//single-thread
            s = computeModelScoreOnTraining(0, samples.size()-1, current);
        else
        {
            List<Worker> workers = new ArrayList<Worker>();
            //divide the entire dataset into chunks of equal size for each worker thread
            int[] partition = p.partition(samples.size());
            for(int i=0;i<partition.length-1;i++)
            {
                //execute the worker
                Worker wk = new Worker(this, partition[i], partition[i+1]-1, current);
                workers.add(wk);//keep it so we can get back results from it later on
                p.execute(wk);

                if(i < partition.length-2)
                    for(int j=partition[i]; j<=partition[i+1]-1;j++)
                        current += samples.get(j).size();
            }
            //wait for all workers to complete before we move on to the next stage
            p.await();
            for(int i=0;i<workers.size();i++)
                s += workers.get(i).score;
        }*/
        float s = computeModelScoreOnTraining(0, samples.size()-1, 0);
        s = s / samples.size();
        return s;
    }
    protected float computeModelScoreOnTraining(int start, int end, int current)
    {
        float s = 0;
        int c = current;
        for(int i=start;i<=end;i++)
        {
            s += scorer.score(rank(i, c));
            c += samples.get(i).size();
        }
        return s;
    }
    protected float computeModelScoreOnValidation()
    {
        /*float score = 0;
        MyThreadPool p = MyThreadPool.getInstance();
        if(p.size() == 1)//single-thread
            score = computeModelScoreOnValidation(0, validationSamples.size()-1);
        else
        {
            List<Worker> workers = new ArrayList<Worker>();
            //divide the entire dataset into chunks of equal size for each worker thread
            int[] partition = p.partition(validationSamples.size());
            for(int i=0;i<partition.length-1;i++)
            {
                //execute the worker
                Worker wk = new Worker(this, partition[i], partition[i+1]-1);
                workers.add(wk);//keep it so we can get back results from it later on
                p.execute(wk);
            }
            //wait for all workers to complete before we move on to the next stage
            p.await();
            for(int i=0;i<workers.size();i++)
                score += workers.get(i).score;
        }*/
        float score = computeModelScoreOnValidation(0, validationSamples.size()-1);
        return score/validationSamples.size();
    }
    protected float computeModelScoreOnValidation(int start, int end)
    {
        float score = 0;
        for(int i=start;i<=end;i++)
        {
            int[] idx = MergeSorter.sort(modelScoresOnValidation[i], false);
            score += scorer.score(new RankList(validationSamples.get(i), idx));
        }
        return score;
    }

    protected void sortSamplesByFeature(int fStart, int fEnd)
    {
        for(int i=fStart;i<=fEnd; i++)
            sortedIdx[i] = sortSamplesByFeature(martSamples, features[i]);
    }
    //For multi-threading processing
    class SortWorker implements Runnable {
        LambdaMART ranker = null;
        int start = -1;
        int end = -1;
        SortWorker(LambdaMART ranker, int start, int end)
        {
            this.ranker = ranker;
            this.start = start;
            this.end = end;
        }
        public void run()
        {
            ranker.sortSamplesByFeature(start, end);
        }
    }
    class LambdaComputationWorker implements Runnable {
        LambdaMART ranker = null;
        int rlStart = -1;
        int rlEnd = -1;
        int martStart = -1;
        LambdaComputationWorker(LambdaMART ranker, int rlStart, int rlEnd, int martStart)
        {
            this.ranker = ranker;
            this.rlStart = rlStart;
            this.rlEnd = rlEnd;
            this.martStart = martStart;
        }
        public void run()
        {
            ranker.computePseudoResponses(rlStart, rlEnd, martStart);
        }
    }
    class Worker implements Runnable {
        LambdaMART ranker = null;
        int rlStart = -1;
        int rlEnd = -1;
        int martStart = -1;
        int type = -1;

        //compute score on validation
        float score = 0;

        Worker(LambdaMART ranker, int rlStart, int rlEnd)
        {
            type = 3;
            this.ranker = ranker;
            this.rlStart = rlStart;
            this.rlEnd = rlEnd;
        }
        Worker(LambdaMART ranker, int rlStart, int rlEnd, int martStart)
        {
            type = 4;
            this.ranker = ranker;
            this.rlStart = rlStart;
            this.rlEnd = rlEnd;
            this.martStart = martStart;
        }
        public void run()
        {
            if(type == 4)
                score = ranker.computeModelScoreOnTraining(rlStart, rlEnd, martStart);
            else if(type == 3)
                score = ranker.computeModelScoreOnValidation(rlStart, rlEnd);
        }
    }
}
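The heart of each boosting round is computePseudoResponses (the lambdas and their weights) followed by updateTreeOutput (the Newton-Raphson leaf values). The minimal standalone sketch below isolates that logic for a single query. It is hypothetical: the class LambdaSketch and its methods are not part of RankLib. It assumes sigma = 1 (matching the rho formula in computePseudoResponses), NDCG with gain 2^label - 1 and discount 1/log2(position + 2) for 0-based positions, and no top-k cutoff. For every pair in which document i is more relevant than document j, it adds rho * |deltaNDCG| to i's lambda, subtracts the same amount from j's, and adds rho * (1 - rho) * |deltaNDCG| to both weights. A leaf's output is then sum(lambda) / sum(weight).

```java
import java.util.Arrays;

/**
 * Hypothetical simplified sketch (not RankLib code) of LambdaMART's lambda and
 * Newton-step computation for ONE query. Assumptions: sigma = 1,
 * NDCG gain = 2^label - 1, discount = 1 / log2(pos + 2) with 0-based pos, no cutoff.
 */
public class LambdaSketch {

    static double discount(int pos) {
        return 1.0 / (Math.log(pos + 2) / Math.log(2));
    }

    static double gain(int label) {
        return Math.pow(2, label) - 1;
    }

    // Ideal DCG: documents ordered by label, highest first
    static double idealDcg(int[] labels) {
        int[] sorted = labels.clone();
        Arrays.sort(sorted); // ascending, so read it backwards
        double dcg = 0;
        for (int i = 0; i < sorted.length; i++)
            dcg += gain(sorted[sorted.length - 1 - i]) * discount(i);
        return dcg;
    }

    /** Accumulates lambdas and weights (both arrays must start at zero). */
    static void computeLambdas(int[] labels, double[] scores, double[] lambdas, double[] weights) {
        int n = labels.length;
        // rank[d] = position of document d when sorted by current score, descending
        Integer[] order = new Integer[n];
        for (int d = 0; d < n; d++) order[d] = d;
        Arrays.sort(order, (a, b) -> Double.compare(scores[b], scores[a]));
        int[] rank = new int[n];
        for (int pos = 0; pos < n; pos++) rank[order[pos]] = pos;

        double idcg = idealDcg(labels);
        if (idcg == 0) return; // no relevant documents: every lambda stays 0
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                if (labels[i] <= labels[j]) continue; // only pairs where i is more relevant
                // |change in NDCG| if i and j swapped positions
                double deltaNdcg = Math.abs((gain(labels[i]) - gain(labels[j]))
                        * (discount(rank[i]) - discount(rank[j]))) / idcg;
                if (deltaNdcg == 0) continue;
                double rho = 1.0 / (1 + Math.exp(scores[i] - scores[j]));
                double lambda = rho * deltaNdcg;
                lambdas[i] += lambda; // push the more relevant document up
                lambdas[j] -= lambda; // push the less relevant document down
                double w = rho * (1.0 - rho) * deltaNdcg;
                weights[i] += w;
                weights[j] += w;
            }
        }
    }

    /** Newton-Raphson leaf value: sum of lambdas / sum of weights over the leaf's samples. */
    static double leafOutput(int[] leafSamples, double[] lambdas, double[] weights) {
        double s1 = 0, s2 = 0;
        for (int k : leafSamples) {
            s1 += lambdas[k];
            s2 += weights[k];
        }
        return s2 == 0 ? 0 : s1 / s2;
    }

    public static void main(String[] args) {
        int[] labels = {0, 2, 1};
        double[] scores = {0.0, 0.0, 0.0}; // all scores equal, e.g. before any tree is added
        double[] lambdas = new double[3];
        double[] weights = new double[3];
        computeLambdas(labels, scores, lambdas, weights);
        System.out.println("lambdas = " + Arrays.toString(lambdas));
        System.out.println("weights = " + Arrays.toString(weights));
        // Suppose a tree placed documents 1 and 2 in the same leaf
        System.out.println("leaf output = " + leafOutput(new int[]{1, 2}, lambdas, weights));
    }
}
```

When all scores are equal, every rho is 0.5, so the lambdas are driven purely by |deltaNDCG|. In the example, the label-2 document receives the largest positive lambda and the label-0 document a negative one. The lambdas of a single query sum to zero, because each pair adds and subtracts the same amount.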
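Copyright notice:
This article is owned by 笨兔勿應 and published at http://www.cnblogs.com/bentuwuying. If you repost it, please credit the source. Using this article for commercial purposes without the author's consent will result in legal action.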