Alink is a new-generation machine learning algorithm platform developed by Alibaba on top of Flink, the real-time computing engine, and is the industry's first machine learning platform to support both batch and streaming algorithms. This article walks you through the implementation of Quantile in Alink.
Since there is very little public documentation on Alink, everything below is my own deduction. There are bound to be omissions and mistakes; corrections are welcome, and I will keep the article updated.
This article originated from an attempt to analyze GBDT: GBDT turns out to rely on Quantile, so Quantile has to be analyzed first.
Discretization maps a finite number of individuals from an infinite space into a finite space (binning). Discretization is mostly applied to continuous data; after processing, the value distribution changes from a continuous attribute to a discrete one.
The choice of discretization method affects subsequent data modeling and application effectiveness, and the discretized result of continuous data can be divided into two classes.
A quantile (also called a quantile point) is a cut point that divides the probability distribution of a random variable into equal-probability parts; commonly used ones are the median (the 2-quantile), quartiles, and percentiles.
Suppose we have 1000 (positive) numbers whose 5%, 30%, 50%, 70%, and 99% quantiles are [3.0, 5.0, 6.0, 9.0, 12.0]. This means that, for example, 5% of the numbers are less than or equal to 3.0, 30% are less than or equal to 5.0, and so on. That is the statistical reading of a quantile.
Therefore, to find the quantile rank of some number within a group of numbers, we only need to sort the group, count how many numbers are less than or equal to it, and divide by the total count.
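As a minimal sketch of this naive approach (standalone Java, not Alink code; the class and method names are made up for illustration):

import java.util.Arrays;

public class NaiveQuantileRank {
    // Quantile rank of v = (number of elements <= v) / (total count), after sorting.
    static double quantileRank(double[] data, double v) {
        double[] sorted = data.clone();
        Arrays.sort(sorted);
        int count = 0;
        while (count < sorted.length && sorted[count] <= v) {
            count++; // count the elements less than or equal to v
        }
        return (double) count / sorted.length;
    }

    public static void main(String[] args) {
        double[] data = {6, 1, 9, 3, 12, 5, 7, 2, 8, 4};
        System.out.println(quantileRank(data, 6)); // 0.6 -> 6 is the 60% quantile of this data
    }
}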
There are two common conventions for locating the position of the p-quantile: pos = (n + 1)p, or pos = 1 + (n − 1)p.
Here we use quartiles for further illustration.
Quartiles: sort the given unordered values in ascending order and split them into four equal parts; the values at the three split points are the quartiles.
The 1st quartile (Q1), also called the lower quartile, equals the value at the 25% position after all values in the sample are sorted in ascending order.
The 2nd quartile (Q2), also called the median, equals the value at the 50% position after all values in the sample are sorted in ascending order.
The 3rd quartile (Q3), also called the upper quartile, equals the value at the 75% position after all values in the sample are sorted in ascending order.
The interquartile range (IQR) is the difference between the 3rd quartile and the 1st quartile. A short worked example follows.
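For example, using the pos = 1 + (n − 1)p positioning convention (an assumption here; the other convention gives slightly different positions): for the sorted data 1, 2, ..., 9 (n = 9), Q1 sits at position 1 + 8 × 0.25 = 3, so Q1 = 3; Q2 sits at position 5, so Q2 = 5; Q3 sits at position 7, so Q3 = 7; hence IQR = Q3 − Q1 = 4.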
In Alink, the quantile functionality is implemented by QuantileDiscretizer. QuantileDiscretizer takes columns of continuous features as input and outputs binned categorical features. The number of bins is specified by numBuckets (bucket count), and the bin ranges are obtained with an approximate algorithm. The example code for this article is as follows.
public class QuantileDiscretizerExample {
    public static void main(String[] args) throws Exception {
        // Bin the continuous values 1001 ~ 2000
        NumSeqSourceBatchOp numSeqSourceBatchOp = new NumSeqSourceBatchOp(1001, 2000, "col0");
        Pipeline pipeline = new Pipeline()
            .add(new QuantileDiscretizer()
                .setNumBuckets(6) // specify the number of bins
                .setSelectedCols(new String[]{"col0"}));
        List<Row> result = pipeline.fit(numSeqSourceBatchOp)
            .transform(numSeqSourceBatchOp)
            .collect();
        System.out.println(result);
    }
}
Output:
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, ..... 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 ..... 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5]
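As a quick sanity check on this output: 1000 values split into 6 buckets means roughly 1000 / 6 ≈ 167 consecutive values per bucket, which matches the split points [1168, 1334, 1501, 1667, 1834] that we will see in the serialized model later.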
We first present a diagram of the overall logic:
-------------------------------- Preparation phase --------------------------------
          │
┌───────────────────┐
│  getSelectedCols  │ get the names of the columns to be quantized
└───────────────────┘
          │
┌─────────────────────┐
│     quantileNum     │ get the number of bins
└─────────────────────┘
          │
┌──────────────────────┐
│ Preprocessing.select │ select the data from the input by column name
└──────────────────────┘
          │
-------------------------------- Preprocessing phase --------------------------------
          │
┌──────────────────────┐
│       quantile       │ the subsequent steps compute the quantiles
└──────────────────────┘
          │
┌────────────────────────────────┐
│   countElementsPerPartition    │ get, within each partition, the element count of that partition
└────────────────────────────────┘
          │ <task id, count in this task>
┌──────────────────────┐
│        sum(1)        │ accumulate the second field, "count in this task", to get the total element count
└──────────────────────┘
          │
┌──────────────────────┐
│         map          │ extract the total element count; cnt is used later
└──────────────────────┘
          │
┌──────────────────────┐
│     missingCount     │ per partition, find which data in the selected columns is missing, e.g. zeroAsMissing, null, isNaN
└──────────────────────┘
          │
┌────────────────┐
│  mapPartition  │ flatten the input Rows: emit the sub-elements of each Row one by one, in Row order
└────────────────┘
          │ <idx in row, item in row>, i.e. <position within the Row, element>
┌──────────────┐
│    pSort     │ sort the flattened data
└──────────────┘
          │ returns a tuple2
          │ f0: dataset which is indexed by partition id
          │ f1: dataset which has partition id and count
          │
-------------------------------- Computation phase --------------------------------
          │
┌─────────────────┐
│  MultiQuantile  │ the subsequent steps are the concrete computation
└─────────────────┘
          │
┌─────────────────┐
│      open       │ fetch broadcast variables; pre-process counts (sorted), totalCnt, missingCounts (sorted)
└─────────────────┘
          │
┌─────────────────┐
│  mapPartition   │ the actual computation
└─────────────────┘
          │
┌─────────────────┐
│   groupBy(0)    │ group by column idx
└─────────────────┘
          │
┌─────────────────┐
│   reduceGroup   │ merge the sorted groups
└─────────────────┘
          │ set(Tuple2<column idx, actual data value>)
          │
-------------------------------- Model serialization --------------------------------
          │
┌──────────────┐
│ reduceGroup  │ group and merge
└──────────────┘
          │
┌─────────────────┐
│ SerializeModel  │ serialize the model
└─────────────────┘
QuantileDiscretizerTrainBatchOp.linkFrom is as follows:
public QuantileDiscretizerTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
    BatchOperator<?> in = checkAndGetFirst(inputs);

    // The example sets .setSelectedCols(new String[]{"col0"}), so quantileColNames is "col0"
    String[] quantileColNames = getSelectedCols();

    int[] quantileNum = null;

    // The example sets .setNumBuckets(6), so quantileNum = {int[1]@2705} 0 = 6
    if (getParams().contains(QuantileDiscretizerTrainParams.NUM_BUCKETS)) {
        quantileNum = new int[quantileColNames.length];
        Arrays.fill(quantileNum, getNumBuckets());
    } else {
        quantileNum = Arrays.stream(getNumBucketsArray()).mapToInt(Integer::intValue).toArray();
    }

    /* filter the selected column from input */
    // Select the chosen column "col0"
    DataSet<Row> input = Preprocessing.select(in, quantileColNames).getDataSet();

    // Compute the quantiles
    DataSet<Row> quantile = quantile(
        input, quantileNum,
        getParams().get(HasRoundMode.ROUND_MODE),
        getParams().get(Preprocessing.ZERO_AS_MISSING));

    // Serialize the model
    quantile = quantile.reduceGroup(
        new SerializeModel(
            getParams(),
            quantileColNames,
            TableUtil.findColTypesWithAssertAndHint(in.getSchema(), quantileColNames),
            BinTypes.BinDivideType.QUANTILE));

    /* set output */
    setOutput(quantile, new QuantileDiscretizerModelDataConverter().getModelSchema());

    return this;
}
The overall logic matches the diagram above: training is carried out by quantile, which roughly comprises the steps listed there. The concrete code is:
public static DataSet<Row> quantile(
        DataSet<Row> input,
        final int[] quantileNum,
        final HasRoundMode.RoundMode roundMode,
        final boolean zeroAsMissing) {

    /* instance count of dataset */
    // countElementsPerPartition obtains, within each partition, the element count of
    // that partition, returning <task id, count in this task>.
    DataSet<Long> cnt = DataSetUtils
        .countElementsPerPartition(input)
        .sum(1) // accumulate the second field, "count in this task", to get the total element count
        .map(new MapFunction<Tuple2<Integer, Long>, Long>() {
            @Override
            public Long map(Tuple2<Integer, Long> value) throws Exception {
                return value.f1; // extract the total element count
            }
        }); // cnt is used later

    /* missing count of columns */
    // Find which data in the selected columns is missing; from the code, this covers
    // the zeroAsMissing, null, and isNaN cases.
    DataSet<Tuple2<Integer, Long>> missingCount = input
        .mapPartition(new RichMapPartitionFunction<Row, Tuple2<Integer, Long>>() {
            public void mapPartition(Iterable<Row> values, Collector<Tuple2<Integer, Long>> out) {
                StreamSupport.stream(values.spliterator(), false)
                    .flatMap(x -> {
                        long[] counts = new long[x.getArity()];
                        Arrays.fill(counts, 0L);
                        // increment counts whenever a value turns out to be missing
                        for (int i = 0; i < x.getArity(); ++i) {
                            if (x.getField(i) == null
                                || (zeroAsMissing && ((Number) x.getField(i)).doubleValue() == 0.0)
                                || Double.isNaN(((Number) x.getField(i)).doubleValue())) {
                                counts[i]++;
                            }
                        }
                        return IntStream.range(0, x.getArity())
                            .mapToObj(y -> Tuple2.of(y, counts[y]));
                    })
                    .collect(Collectors.groupingBy(
                        x -> x.f0,
                        Collectors.mapping(x -> x.f1, Collectors.reducing((a, b) -> a + b))))
                    .entrySet()
                    .stream()
                    .map(x -> Tuple2.of(x.getKey(), x.getValue().get()))
                    .forEach(out::collect);
            }
        })
        .groupBy(0) // group by the first field
        .reduce(new RichReduceFunction<Tuple2<Integer, Long>>() {
            @Override
            public Tuple2<Integer, Long> reduce(Tuple2<Integer, Long> value1, Tuple2<Integer, Long> value2) {
                return Tuple2.of(value1.f0, value1.f1 + value2.f1); // accumulate the sums
            }
        });

    /* flatten dataset to 1d */
    // Flatten the input data.
    DataSet<PairComparable> flatten = input
        .mapPartition(new RichMapPartitionFunction<Row, PairComparable>() {
            PairComparable pairBuff;

            public void mapPartition(Iterable<Row> values, Collector<PairComparable> out) {
                for (Row value : values) { // iterate over all input elements of this partition
                    for (int i = 0; i < value.getArity(); ++i) { // a Row may itself contain several sub-elements
                        pairBuff.first = i; // emit the sub-elements one by one, in Row order: this is what flattens the Row type
                        if (value.getField(i) == null
                            || (zeroAsMissing && ((Number) value.getField(i)).doubleValue() == 0.0)
                            || Double.isNaN(((Number) value.getField(i)).doubleValue())) {
                            pairBuff.second = null;
                        } else {
                            pairBuff.second = (Number) value.getField(i);
                        }
                        out.collect(pairBuff); // emit <idx in row, item in row>, i.e. <position within the Row, element>
                    }
                }
            }
        });

    /* sort data */
    // Sort the flattened data. pSort is a large-scale partitioned sort; nothing is classified yet.
    // pSort returns a tuple2: f0 is the dataset indexed by partition id,
    // f1 is the dataset holding <partition id, count>.
    Tuple2<DataSet<PairComparable>, DataSet<Tuple2<Integer, Long>>> sortedData
        = SortUtilsNext.pSort(flatten);

    /* calculate quantile */
    return sortedData.f0 // f0: dataset which is indexed by partition id
        .mapPartition(new MultiQuantile(quantileNum, roundMode))
        .withBroadcastSet(sortedData.f1, "counts") // f1: dataset which has partition id and count
        .withBroadcastSet(cnt, "totalCnt")
        .withBroadcastSet(missingCount, "missingCounts")
        .groupBy(0) // group by column idx
        .reduceGroup(new RichGroupReduceFunction<Tuple2<Integer, Number>, Row>() {
            @Override
            public void reduce(Iterable<Tuple2<Integer, Number>> values, Collector<Row> out) {
                TreeSet<Number> set = new TreeSet<>(new Comparator<Number>() {
                    @Override
                    public int compare(Number o1, Number o2) {
                        return SortUtils.OBJECT_COMPARATOR.compare(o1, o2);
                    }
                });
                int id = -1;
                for (Tuple2<Integer, Number> val : values) { // Tuple2<column idx, data>
                    id = val.f0;
                    set.add(val.f1);
                }
                // Runtime variables:
                // set = {TreeSet@9379} size = 5
                //   0 = {Long@9389} 167 // the first split index of column 0
                //   1 = {Long@9392} 333 // the second split index of column 0
                //   2 = {Long@9393} 500
                //   3 = {Long@9394} 667
                //   4 = {Long@9382} 833
                out.collect(Row.of(id, set.toArray(new Number[0])));
            }
        });
}
Below we explain several key functions.
The role of countElementsPerPartition is to obtain, within each partition, the number of elements in that partition.
public static <T> DataSet<Tuple2<Integer, Long>> countElementsPerPartition(DataSet<T> input) {
    return input.mapPartition(new RichMapPartitionFunction<T, Tuple2<Integer, Long>>() {
        @Override
        public void mapPartition(Iterable<T> values, Collector<Tuple2<Integer, Long>> out) throws Exception {
            long counter = 0;
            for (T value : values) {
                counter++; // count the elements of this partition
            }
            out.collect(new Tuple2<>(getRuntimeContext().getIndexOfThisSubtask(), counter));
        }
    });
}
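For intuition, here is a minimal standalone Flink sketch of this utility (assuming a parallelism of 4; the exact per-partition counts depend on how the runtime distributes the data):

import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.utils.DataSetUtils;

public class CountPerPartitionDemo {
    public static void main(String[] args) throws Exception {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(4);
        DataSet<Long> data = env.generateSequence(1001, 2000); // the same 1001 ~ 2000 range as the example
        // one <subtask id, element count> pair per partition, e.g. (0,250) (1,250) (2,250) (3,250)
        DataSet<Tuple2<Integer, Long>> counts = DataSetUtils.countElementsPerPartition(data);
        counts.print();
    }
}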
MultiQuantile computes the concrete quantile points.
The open function fetches the broadcast variables and pre-processes counts (sorted), totalCnt, missingCounts (sorted), and so on.
The mapPartition function then performs the actual computation; its main steps were listed in the diagram above. The concrete code is:
public static class MultiQuantile
        extends RichMapPartitionFunction<PairComparable, Tuple2<Integer, Number>> {

    private List<Tuple2<Integer, Long>> counts;
    private List<Tuple2<Integer, Long>> missingCounts;
    private long totalCnt = 0;
    private int[] quantileNum;
    private HasRoundMode.RoundMode roundType;
    private int taskId;

    @Override
    public void open(Configuration parameters) throws Exception {
        // Fetch the broadcast variables; pre-process counts (sorted), totalCnt, missingCounts (sorted).
        // The broadcast was set earlier via .withBroadcastSet(sortedData.f1, "counts"), where f1 is
        // the dataset holding <partition id, count>, so the sort below is by partition id.
        this.counts = getRuntimeContext().getBroadcastVariableWithInitializer(
            "counts",
            new BroadcastVariableInitializer<Tuple2<Integer, Long>, List<Tuple2<Integer, Long>>>() {
                @Override
                public List<Tuple2<Integer, Long>> initializeBroadcastVariable(
                        Iterable<Tuple2<Integer, Long>> data) {
                    ArrayList<Tuple2<Integer, Long>> sortedData = new ArrayList<>();
                    for (Tuple2<Integer, Long> datum : data) {
                        sortedData.add(datum);
                    }
                    // sort by partition id
                    sortedData.sort(Comparator.comparing(o -> o.f0));
                    // Runtime data: this machine has 4 cores, so the data is split into 4
                    // partitions holding 251, 250, 250 and 250 elements respectively:
                    // sortedData = {ArrayList@9347} size = 4
                    //   0 = {Tuple2@9350} "(0,251)" // partition 0 holds 251 elements
                    //   1 = {Tuple2@9351} "(1,250)"
                    //   2 = {Tuple2@9352} "(2,250)"
                    //   3 = {Tuple2@9353} "(3,250)"
                    return sortedData;
                }
            });

        this.totalCnt = getRuntimeContext().getBroadcastVariableWithInitializer("totalCnt",
            new BroadcastVariableInitializer<Long, Long>() {
                @Override
                public Long initializeBroadcastVariable(Iterable<Long> data) {
                    return data.iterator().next();
                }
            });

        this.missingCounts = getRuntimeContext().getBroadcastVariableWithInitializer(
            "missingCounts",
            new BroadcastVariableInitializer<Tuple2<Integer, Long>, List<Tuple2<Integer, Long>>>() {
                @Override
                public List<Tuple2<Integer, Long>> initializeBroadcastVariable(
                        Iterable<Tuple2<Integer, Long>> data) {
                    return StreamSupport.stream(data.spliterator(), false)
                        .sorted(Comparator.comparing(o -> o.f0))
                        .collect(Collectors.toList());
                }
            });

        taskId = getRuntimeContext().getIndexOfThisSubtask();

        // Runtime data:
        // this = {QuantileDiscretizerTrainBatchOp$MultiQuantile@9348}
        //   counts = {ArrayList@9347} size = 4
        //     0 = {Tuple2@9350} "(0,251)"
        //     1 = {Tuple2@9351} "(1,250)"
        //     2 = {Tuple2@9352} "(2,250)"
        //     3 = {Tuple2@9353} "(3,250)"
        //   missingCounts = {ArrayList@9375} size = 1
        //     0 = {Tuple2@9381} "(0,0)"
        //   totalCnt = 1001
        //   quantileNum = {int[1]@9376} 0 = 6
        //   roundType = {HasRoundMode$RoundMode@9377} "ROUND"
        //   taskId = 2
    }

    @Override
    public void mapPartition(Iterable<PairComparable> values,
                             Collector<Tuple2<Integer, Number>> out) throws Exception {
        long start = 0;
        long end;
        int curListIndex = -1;
        int size = counts.size(); // the data was split into 4 partitions, so this is 4

        for (int i = 0; i < size; ++i) {
            int curId = counts.get(i).f0; // the partition id of this entry
            if (curId == taskId) {
                curListIndex = i; // which partition id the current task corresponds to
                break; // we have reached the current task, so we can stop
            }
            start += counts.get(i).f1; // accumulate to get this task's start offset, i.e. from which of the 1000 elements to begin
        }

        // From taskId and counts we now know which data this task should process, i.e. its
        // start and end positions. This partition is 0, holding 251 elements.
        end = start + counts.get(curListIndex).f1; // end = start offset + element count of this partition

        ArrayList<PairComparable> allRows = new ArrayList<>((int) (end - start));
        for (PairComparable value : values) {
            allRows.add(value); // value can be viewed as <column idx, actual data>
        }
        allRows.sort(Comparator.naturalOrder());

        // Runtime variables:
        // start = 0
        // curListIndex = 0
        // size = 4
        // end = 251
        // allRows = {ArrayList@9406} size = 251
        //   0 = {PairComparable@9408} first = 0, second = 0
        //   1 = {PairComparable@9409} first = 0, second = 1
        //   2 = {PairComparable@9410} first = 0, second = 2
        //   ......

        // size = ((251 - 1) / 1001 - 0 / 1001) + 1 = 1
        size = (int) ((end - 1) / totalCnt - start / totalCnt) + 1;
        int localStart = 0;

        for (int i = 0; i < size; ++i) {
            int fIdx = (int) (start / totalCnt + i);
            int subStart = 0;
            int subEnd = (int) totalCnt;

            if (i == 0) {
                subStart = (int) (start % totalCnt); // 0
            }
            if (i == size - 1) {
                subEnd = (int) (end % totalCnt == 0 ? totalCnt : end % totalCnt); // 251
            }
            if (totalCnt - missingCounts.get(fIdx).f1 == 0) {
                localStart += subEnd - subStart;
                continue;
            }

            QIndex qIndex = new QIndex(
                totalCnt - missingCounts.get(fIdx).f1, quantileNum[fIdx], roundType);

            // Runtime variables:
            // qIndex = {QuantileDiscretizerTrainBatchOp$QIndex@9548}
            //   totalCount = 1001.0
            //   q1 = 0.16666666666666666
            //   roundMode = {HasRoundMode$RoundMode@9377} "ROUND"

            // Iterate up to the number of bins.
            for (int j = 1; j < quantileNum[fIdx]; ++j) {
                // Get the index of each split. j = 1 ---> index = 167: the 1001 elements are
                // divided into 6 segments, and the first segment ends at 167.
                long index = qIndex.genIndex(j);

                // For this task = 0: subStart = 0, subEnd = 251; index = 167, so the 167th element
                // is fetched directly from allRows. Because the continuous range is 1001 ~ 2000,
                // the 167th element corresponds to the value 1168.
                // For task = 1: subStart = 251, subEnd = 501; index = 333, so element
                // (333 + 0 - 251) = 82 is fetched from allRows; its value is 1334.
                if (index >= subStart && index < subEnd) { // the index falls exactly into this partition's data
                    PairComparable pairComparable = allRows.get(
                        (int) (index + localStart - subStart));
                    // Runtime variables:
                    // pairComparable = {PairComparable@9581}
                    //   first = {Integer@9507} 0 // first is the column idx
                    //   second = {Long@9584} 167 // the actual data
                    out.collect(Tuple2.of(pairComparable.first, pairComparable.second));
                }
            }
            localStart += subEnd - subStart;
        }
    }
}
QIndex is the crux of this article: it is what actually computes the quantile positions.
public static class QIndex {
    private double totalCount;
    private double q1;
    private HasRoundMode.RoundMode roundMode;

    public QIndex(double totalCount, int quantileNum, HasRoundMode.RoundMode type) {
        this.totalCount = totalCount; // 1001, the total element count
        // 1.0 / 6 = 0.16666666666666666. quantileNum is the number of segments, and q1 is
        // the size of each segment: with 6 segments, each one covers 1/6.
        this.q1 = 1.0 / (double) quantileNum;
        this.roundMode = type;
    }

    public long genIndex(int k) {
        // Still assuming 6 segments: for the first split, k = 1, and its index is
        // round(1/6 * (1001 - 1) * 1) = 167.
        return roundMode.calc(this.q1 * (this.totalCount - 1.0) * (double) k);
    }
}
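To make the arithmetic concrete, here is a tiny standalone sketch that reproduces genIndex for the example above (assuming ROUND simply means Math.round; the real rounding is delegated to Alink's HasRoundMode.RoundMode):

public class QIndexDemo {
    public static void main(String[] args) {
        double totalCount = 1001.0; // total element count from the example
        int quantileNum = 6;        // number of bins
        double q1 = 1.0 / quantileNum;
        for (int k = 1; k < quantileNum; ++k) {
            long index = Math.round(q1 * (totalCount - 1.0) * k);
            System.out.println("split " + k + " -> index " + index);
        }
        // prints 167, 333, 500, 667, 833 -- matching the TreeSet seen at runtime
    }
}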
The model output is produced by calling SerializeModel via reduceGroup.
The invoking logic is:
// Serialize the model
quantile = quantile.reduceGroup(
    new SerializeModel(
        getParams(),
        quantileColNames,
        TableUtil.findColTypesWithAssertAndHint(in.getSchema(), quantileColNames),
        BinTypes.BinDivideType.QUANTILE));
The concrete implementation of SerializeModel is:
public static class SerializeModel implements GroupReduceFunction<Row, Row> {
    private Params meta;
    private String[] colNames;
    private TypeInformation<?>[] colTypes;
    private BinTypes.BinDivideType binDivideType;

    @Override
    public void reduce(Iterable<Row> values, Collector<Row> out) throws Exception {
        Map<String, FeatureBorder> m = new HashMap<>();
        for (Row val : values) {
            int index = (int) val.getField(0);
            Number[] splits = (Number[]) val.getField(1);
            m.put(
                colNames[index],
                QuantileDiscretizerModelDataConverter.arraySplit2FeatureBorder(
                    colNames[index],
                    colTypes[index],
                    splits,
                    meta.get(QuantileDiscretizerTrainParams.LEFT_OPEN),
                    binDivideType));
        }

        for (int i = 0; i < colNames.length; ++i) {
            if (m.containsKey(colNames[i])) {
                continue;
            }
            m.put(
                colNames[i],
                QuantileDiscretizerModelDataConverter.arraySplit2FeatureBorder(
                    colNames[i],
                    colTypes[i],
                    null,
                    meta.get(QuantileDiscretizerTrainParams.LEFT_OPEN),
                    binDivideType));
        }

        QuantileDiscretizerModelDataConverter model
            = new QuantileDiscretizerModelDataConverter(m, meta);
        model.save(model, out);
    }
}
The FeatureBorder class is used here.
Data binning classifies data according to some rule, much like fruit can be sorted by size and sold at different prices.
FeatureBorder is dedicated to exactly this: it represents a feature border for binning, covering both discrete and continuous feature borders.
From the runtime data we can see the column name, the bin indices, and the split points of this binning:
m = {HashMap@9380} size = 1
  "col0" -> {FeatureBorder@9438} "{"binDivideType":"QUANTILE","featureName":"col0","bin":{"NORM":[{"index":0},{"index":1},{"index":2},{"index":3},{"index":4},{"index":5}],"NULL":{"index":6}},"featureType":"BIGINT","splitsArray":[1168,1334,1501,1667,1834],"isLeftOpen":true,"binCount":6}"
Prediction is done in QuantileDiscretizerModelMapper.
The model data is:
model = {QuantileDiscretizerModelDataConverter@9582}
  meta = {Params@9670} "Params {selectedCols=["col0"], version="v2", numBuckets=6}"
  data = {HashMap@9584} size = 1
    "col0" -> {FeatureBorder@9676} "{"binDivideType":"QUANTILE","featureName":"col0","bin":{"NORM":[{"index":0},{"index":1},{"index":2},{"index":3},{"index":4},{"index":5}],"NULL":{"index":6}},"featureType":"BIGINT","splitsArray":[1168,1334,1501,1667,1834],"isLeftOpen":true,"binCount":6}"
loadModel performs the loading:
@Override
public void loadModel(List<Row> modelRows) {
    QuantileDiscretizerModelDataConverter model = new QuantileDiscretizerModelDataConverter();
    model.load(modelRows);

    for (int i = 0; i < mapperBuilder.paramsBuilder.selectedCols.length; i++) {
        FeatureBorder border = model.data.get(mapperBuilder.paramsBuilder.selectedCols[i]);
        List<Bin.BaseBin> norm = border.bin.normBins;
        int size = norm.size();
        Long maxIndex = norm.get(0).getIndex();
        Long lastIndex = norm.get(size - 1).getIndex();

        for (int j = 0; j < norm.size(); ++j) {
            if (maxIndex < norm.get(j).getIndex()) {
                maxIndex = norm.get(j).getIndex();
            }
        }

        long maxIndexWithNull = Math.max(maxIndex, border.bin.nullBin.getIndex());

        switch (mapperBuilder.paramsBuilder.handleInvalidStrategy) {
            case KEEP:
                mapperBuilder.vectorSize.put(i, maxIndexWithNull + 1);
                break;
            case SKIP:
            case ERROR:
                mapperBuilder.vectorSize.put(i, maxIndex + 1);
                break;
            default:
                throw new UnsupportedOperationException("Unsupported now.");
        }

        if (mapperBuilder.paramsBuilder.dropLast) {
            mapperBuilder.dropIndex.put(i, lastIndex);
        }

        mapperBuilder.discretizers[i] = createQuantileDiscretizer(border, model.meta);
    }

    mapperBuilder.setAssembledVectorSize();
}
At the end of loading, createQuantileDiscretizer is called to build a LongQuantileDiscretizer, the discretizer for the Long type:
public static class LongQuantileDiscretizer implements NumericQuantileDiscretizer {
    long[] bounds;
    boolean isLeftOpen;
    int[] boundIndex;
    int nullIndex;
    boolean zeroAsMissing;

    @Override
    public int findIndex(Object number) {
        if (number == null) {
            return nullIndex;
        }
        long lVal = ((Number) number).longValue();
        if (isMissing(lVal, zeroAsMissing)) {
            return nullIndex;
        }
        // locate the value among the bin boundaries
        int hit = Arrays.binarySearch(bounds, lVal);
        if (isLeftOpen) {
            hit = hit >= 0 ? hit - 1 : -hit - 2;
        } else {
            hit = hit >= 0 ? hit : -hit - 2;
        }
        return boundIndex[hit];
    }
}
Its values are as follows:
this = {QuantileDiscretizerModelMapper$LongQuantileDiscretizer@9768}
  bounds = {long[7]@9757}
    0 = -9223372036854775807
    1 = 1168
    2 = 1334
    3 = 1501
    4 = 1667
    5 = 1834
    6 = 9223372036854775807
  isLeftOpen = true
  boundIndex = {int[7]@9743}
    0 = 0 // values between -9223372036854775807 and 1168 map to the final discrete bin 0
    1 = 1
    2 = 2
    3 = 3
    4 = 4
    5 = 5
    6 = 5 // values between 1834 and 9223372036854775807 map to the final discrete bin 5
  nullIndex = 6
  zeroAsMissing = false
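Here is a small sketch that replays the findIndex arithmetic by hand for the input value 1003 against these bounds (standalone Java, not Alink code):

import java.util.Arrays;

public class FindIndexDemo {
    public static void main(String[] args) {
        long[] bounds = {-9223372036854775807L, 1168, 1334, 1501, 1667, 1834, 9223372036854775807L};
        int[] boundIndex = {0, 1, 2, 3, 4, 5, 5};
        long lVal = 1003L;
        int hit = Arrays.binarySearch(bounds, lVal); // not found: returns -(insertion point) - 1 = -2
        hit = hit >= 0 ? hit - 1 : -hit - 2;         // left-open handling: -(-2) - 2 = 0
        System.out.println(boundIndex[hit]);         // 0 -> the value 1003 falls into bin 0
    }
}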
The prediction itself is carried out by the DiscretizerMapperBuilder of QuantileDiscretizerModelMapper:
Row map(Row row) {
    // Example input here: row = {Row@9743} "1003"
    for (int i = 0; i < paramsBuilder.selectedCols.length; i++) {
        int colIdxInData = selectedColIndicesInData[i];
        Object val = row.getField(colIdxInData);
        // Find the index for 1003 by calling the discretizer; here foundIndex is 0
        int foundIndex = discretizers[i].findIndex(val);
        predictIndices[i] = (long) foundIndex;
    }
    return paramsBuilder.outputColsHelper.getResultRow(
        row,
        setResultRow(
            predictIndices,
            paramsBuilder.encode,
            dropIndex,
            vectorSize,
            paramsBuilder.dropLast,
            assembledVectorSize)); // the discrete value 0 is returned in the end
}

// Runtime variables:
// this = {QuantileDiscretizerModelMapper$DiscretizerMapperBuilder@9744}
//   paramsBuilder = {QuantileDiscretizerModelMapper$DiscretizerParamsBuilder@9752}
//   selectedColIndicesInData = {int[1]@9754}
//   vectorSize = {HashMap@9758} size = 1
//   dropIndex = {HashMap@9759} size = 1
//   assembledVectorSize = {Integer@9760} 6
//   discretizers = {QuantileDiscretizerModelMapper$NumericQuantileDiscretizer[1]@9761}
//     0 = {QuantileDiscretizerModelMapper$LongQuantileDiscretizer@9768}
//       bounds = {long[7]@9776}
//       isLeftOpen = true
//       boundIndex = {int[7]@9777}
//       nullIndex = 6
//       zeroAsMissing = false
//   predictIndices = {Long[1]@9763}