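Alink is a new-generation machine learning algorithm platform developed by Alibaba on top of the real-time computing engine Flink, and is described as the industry's first machine learning platform to support both batch and streaming algorithms. Binary classification evaluation measures how well the predictions of a binary classifier perform. This article walks through the corresponding code in Alink.

If some concepts in this article are unclear, see the earlier article "[Plain-language analysis] Sorting out concepts through examples: Accuracy, Precision, Recall and F-Measure".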
    public class EvalBinaryClassExample {

        AlgoOperator getData(boolean isBatch) {
            Row[] rows = new Row[]{
                Row.of("prefix1", "{\"prefix1\": 0.9, \"prefix0\": 0.1}"),
                Row.of("prefix1", "{\"prefix1\": 0.8, \"prefix0\": 0.2}"),
                Row.of("prefix1", "{\"prefix1\": 0.7, \"prefix0\": 0.3}"),
                Row.of("prefix0", "{\"prefix1\": 0.75, \"prefix0\": 0.25}"),
                Row.of("prefix0", "{\"prefix1\": 0.6, \"prefix0\": 0.4}")
            };
            String[] schema = new String[]{"label", "detailInput"};
            if (isBatch) {
                return new MemSourceBatchOp(rows, schema);
            } else {
                return new MemSourceStreamOp(rows, schema);
            }
        }

        public static void main(String[] args) throws Exception {
            EvalBinaryClassExample test = new EvalBinaryClassExample();
            BatchOperator batchData = (BatchOperator) test.getData(true);

            BinaryClassMetrics metrics = new EvalBinaryClassBatchOp()
                .setLabelCol("label")
                .setPredictionDetailCol("detailInput")
                .linkFrom(batchData)
                .collectMetrics();

            System.out.println("RocCurve:" + metrics.getRocCurve());
            System.out.println("AUC:" + metrics.getAuc());
            System.out.println("KS:" + metrics.getKs());
            System.out.println("PRC:" + metrics.getPrc());
            System.out.println("Accuracy:" + metrics.getAccuracy());
            System.out.println("Macro Precision:" + metrics.getMacroPrecision());
            System.out.println("Micro Recall:" + metrics.getMicroRecall());
            System.out.println("Weighted Sensitivity:" + metrics.getWeightedSensitivity());
        }
    }
Program output:

    RocCurve:([0.0, 0.0, 0.0, 0.5, 0.5, 1.0, 1.0],[0.0, 0.3333333333333333, 0.6666666666666666, 0.6666666666666666, 1.0, 1.0, 1.0])
    AUC:0.8333333333333333
    KS:0.6666666666666666
    PRC:0.9027777777777777
    Accuracy:0.6
    Macro Precision:0.3
    Micro Recall:0.6
    Weighted Sensitivity:0.6
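The AUC and KS values printed above can be checked by hand from the ROC points. The sketch below is not Alink code: the class and method names are made up for illustration, and it assumes the area is taken with the trapezoidal rule and that KS is the largest gap between TPR and FPR along the curve. Both assumptions happen to reproduce the printed numbers for this example.

```java
class CurveAreaCheck {
    // Trapezoidal area under the piecewise-linear curve through (x[i], y[i]).
    static double trapezoidArea(double[] x, double[] y) {
        double area = 0.0;
        for (int i = 1; i < x.length; i++) {
            area += (x[i] - x[i - 1]) * (y[i] + y[i - 1]) / 2.0;
        }
        return area;
    }

    // KS statistic: the largest value of TPR - FPR over all curve points.
    static double ks(double[] fpr, double[] tpr) {
        double best = 0.0;
        for (int i = 0; i < fpr.length; i++) {
            best = Math.max(best, tpr[i] - fpr[i]);
        }
        return best;
    }

    public static void main(String[] args) {
        // ROC points copied from the "RocCurve" line of the program output.
        double[] fpr = {0.0, 0.0, 0.0, 0.5, 0.5, 1.0, 1.0};
        double[] tpr = {0.0, 1.0 / 3, 2.0 / 3, 2.0 / 3, 1.0, 1.0, 1.0};
        System.out.println("AUC: " + trapezoidArea(fpr, tpr));
        System.out.println("KS:  " + ks(fpr, tpr));
    }
}
```

Only two segments of the curve have non-zero width (from x = 0 to 0.5 at height 2/3, and from x = 0.5 to 1 at height 1), which gives 1/3 + 1/2 = 5/6 for the area.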
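In Alink, binary classification evaluation has two implementations, batch and streaming, which are introduced in turn below. (Part of Alink's complexity lies in its many fine-grained data structures, so the text below prints many program variables to make them easier to follow.)
Split [0, 1] into buckets (bins), 100000 of them here. This gives two arrays of length 100000, positiveBin and negativeBin.
Fill positiveBin / negativeBin from the input. positiveBin holds the samples whose actual label is positive (TP + FN) and negativeBin holds the samples whose actual label is negative (TN + FP), each indexed by the predicted probability of the positive class. These are the basis of all later calculations.
Traverse every meaningful point in the bins, compute totalTrue and totalFalse, and at each point compute the confusion matrix, the tpr, and the corresponding points of rocCurve, recallPrecisionCurve and liftChart.
Compute and store AUC/PRC/KS from the curves.
A detailed overview of the call relationships follows later.
EvalBinaryClassBatchOp is the implementation of binary classification evaluation; it computes the evaluation metrics of a binary classifier.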
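There are two kinds of input: a label column together with a prediction-result column, or a label column together with a prediction-detail column (the `PRED_RESULT` and `PRED_DETAIL` evaluation types that appear in the code below).
In our example, "prefix1" is the label and "{\"prefix1\": 0.9, \"prefix0\": 0.1}" is the predDetail:

    Row.of("prefix1", "{\"prefix1\": 0.9, \"prefix0\": 0.1}")

An excerpt of the class: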
    public class EvalBinaryClassBatchOp extends BaseEvalClassBatchOp<EvalBinaryClassBatchOp>
        implements BinaryEvaluationParams<EvalBinaryClassBatchOp>,
                   EvaluationMetricsCollector<BinaryClassMetrics> {

        @Override
        public BinaryClassMetrics collectMetrics() {
            return new BinaryClassMetrics(this.collect().get(0));
        }
    }
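As can be seen, the main work is done in the base class BaseEvalClassBatchOp, so we look at BaseEvalClassBatchOp first.
We again start from the linkFrom function. It selects the label and prediction-detail columns from the input, computes evaluation metrics per partition, reduces the partial results into one, and writes the final values to the output table.
The code is as follows: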
    @Override
    public T linkFrom(BatchOperator<?>... inputs) {
        BatchOperator<?> in = checkAndGetFirst(inputs);
        String labelColName = this.get(MultiEvaluationParams.LABEL_COL);
        String positiveValue = this.get(BinaryEvaluationParams.POS_LABEL_VAL_STR);

        // Judge the evaluation type from params.
        ClassificationEvaluationUtil.Type type = ClassificationEvaluationUtil.judgeEvaluationType(this.getParams());

        DataSet<BaseMetricsSummary> res;
        switch (type) {
            case PRED_DETAIL: {
                String predDetailColName = this.get(MultiEvaluationParams.PREDICTION_DETAIL_COL);
                // Select the needed columns from the input: "label", "detailInput"
                DataSet<Row> data = in.select(new String[] {labelColName, predDetailColName}).getDataSet();
                // Compute the evaluation metrics separately for each partition
                res = calLabelPredDetailLocal(data, positiveValue, binary);
                break;
            }
            ......
        }

        // Reduce the per-partition results into one
        DataSet<BaseMetricsSummary> metrics = res
            .reduce(new EvaluationUtil.ReduceBaseMetrics());

        // Write the final values to the output table
        this.setOutput(metrics.flatMap(new EvaluationUtil.SaveDataAsParams()),
            new String[] {DATA_OUTPUT}, new TypeInformation[] {Types.STRING});

        return (T) this;
    }

    // Some variables at run time:
    // labelColName = "label"
    // predDetailColName = "detailInput"
    // type = {ClassificationEvaluationUtil$Type@2532} "PRED_DETAIL"
    // binary = true
    // positiveValue = null
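Because the call relationships that follow are fairly involved, here is the call chain first: linkFrom calls calLabelPredDetailLocal, whose flatMap extracts the label names, whose reduceGroup (DistinctLabelIndexMap, via buildLabelIndexLabelArray) builds the label index, and whose mapPartition (CalLabelDetailLocal, via getDetailStatistics and updateBinaryMetricsSummary) fills the bins. linkFrom then merges the partial results with ReduceBaseMetrics (BinaryMetricsSummary.merge) and serializes them with SaveDataAsParams (BinaryMetricsSummary.toMetrics, which calls extractMatrixThreCurve).
calLabelPredDetailLocal computes the evaluation metrics separately for each partition. The code is short, but one point is easy to overlook, as is often the case in simple places:
The result of the first statement, labels, is a parameter of the second statement (a broadcast set), not the dataset the second statement operates on. Both statements operate on the same dataset, data.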
    private static DataSet<BaseMetricsSummary> calLabelPredDetailLocal(DataSet<Row> data,
                                                                        final String positiveValue,
                                                                        boolean binary) {
        // Statement 1: collect the distinct label names and index them
        DataSet<Tuple2<Map<String, Integer>, String[]>> labels = data.flatMap(
            new FlatMapFunction<Row, String>() {
                @Override
                public void flatMap(Row row, Collector<String> collector) {
                    TreeMap<String, Double> labelProbMap;
                    if (EvaluationUtil.checkRowFieldNotNull(row)) {
                        labelProbMap = EvaluationUtil.extractLabelProbMap(row);
                        labelProbMap.keySet().forEach(collector::collect);
                        collector.collect(row.getField(0).toString());
                    }
                }
            }).reduceGroup(new EvaluationUtil.DistinctLabelIndexMap(binary, positiveValue));

        // Statement 2: process data again, with labels passed in as a broadcast set
        return data
            .rebalance()
            .mapPartition(new CalLabelDetailLocal(binary))
            .withBroadcastSet(labels, LABELS);
    }
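calLabelPredDetailLocal consists of three steps: a flatMap that extracts all label names, a reduceGroup that builds the label index map, and a mapPartition that calls CalLabelDetailLocal on each partition.
Let's look at each in detail.
The flatMap takes all labels from the label column and the prediction column (note that it takes the label names) and sends them to the downstream operator.
EvaluationUtil.extractLabelProbMap parses the input JSON to obtain the information in detailInput.
The downstream operator is reduceGroup, which receives all of these label names in a single group; duplicates are removed by the HashSet in DistinctLabelIndexMap. If you are interested in this part, see my earlier article on reduce, "[Source code analysis] What Flink's groupBy and reduce actually do" (published on CSDN and cnblogs).
The program variables are as follows: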
    row = {Row@8922} "prefix1,{"prefix1": 0.9, "prefix0": 0.1}"
     fields = {Object[2]@8925}
      0 = "prefix1"
      1 = "{"prefix1": 0.9, "prefix0": 0.1}"

    labelProbMap = {TreeMap@9008} size = 2
     "prefix0" -> {Double@9015} 0.1
     "prefix1" -> {Double@9017} 0.9

    labelProbMap.keySet().forEach(collector::collect); // sends "prefix0", "prefix1"
    collector.collect(row.getField(0).toString());     // sends "prefix1"
    // The following reduceGroup de-duplicates these labels
The main job here is to de-duplicate the labels and then, via buildLabelIndexLabelArray, give each label an ID; the final result is a <labels, ID> map.

    reduceGroup(new EvaluationUtil.DistinctLabelIndexMap(binary, positiveValue));

DistinctLabelIndexMap gets all the distinct labels from the label column and the prediction column, and returns the map of labels and their IDs. Judging from the later code, this map is only used for multi-class evaluation.
As mentioned above, duplicates among the incoming rows are removed here, by adding them to a HashSet.
    public static class DistinctLabelIndexMap implements
        GroupReduceFunction<String, Tuple2<Map<String, Integer>, String[]>> {
        ......
        @Override
        public void reduce(Iterable<String> rows,
                           Collector<Tuple2<Map<String, Integer>, String[]>> collector) throws Exception {
            HashSet<String> labels = new HashSet<>();
            rows.forEach(labels::add);
            collector.collect(buildLabelIndexLabelArray(labels, binary, positiveValue));
        }
    }

    // Variables:
    // labels = {HashSet@9008} size = 2
    //  0 = "prefix1"
    //  1 = "prefix0"
    // binary = true
buildLabelIndexLabelArray gives each label an ID to obtain a <labels, ID> map, and returns the pair (map, labels), i.e. ({prefix1=0, prefix0=1},[prefix1, prefix0]).
    // Give each label an ID, return a map of label and ID.
    public static Tuple2<Map<String, Integer>, String[]> buildLabelIndexLabelArray(HashSet<String> set,
                                                                                   boolean binary,
                                                                                   String positiveValue) {
        String[] labels = set.toArray(new String[0]);
        Arrays.sort(labels, Collections.reverseOrder());

        Map<String, Integer> map = new HashMap<>(labels.length);
        if (binary && null != positiveValue) {
            // Make sure the positive value ends up at index 0
            if (labels[1].equals(positiveValue)) {
                labels[1] = labels[0];
                labels[0] = positiveValue;
            }
            map.put(labels[0], 0);
            map.put(labels[1], 1);
        } else {
            for (int i = 0; i < labels.length; i++) {
                map.put(labels[i], i);
            }
        }
        return Tuple2.of(map, labels);
    }

    // Program variables:
    // labels = {String[2]@9013}
    //  0 = "prefix1"
    //  1 = "prefix0"
    // map = {HashMap@9014} size = 2
    //  "prefix1" -> {Integer@9020} 0
    //  "prefix0" -> {Integer@9021} 1
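The ordering rule is easier to see in isolation. The following is a minimal stand-alone sketch of the same idea using only the JDK; `LabelIndexSketch` and `orderLabels` are illustrative names, not Alink APIs. Labels are sorted in reverse order, and if a positive value is given and sits at index 1, it is swapped to index 0.

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;

class LabelIndexSketch {
    // Returns the labels in the order used for indexing: index 0 is the positive label.
    static String[] orderLabels(Set<String> distinct, String positiveValue) {
        String[] labels = distinct.toArray(new String[0]);
        Arrays.sort(labels, Collections.reverseOrder());
        // Only meaningful for the binary case (exactly two labels).
        if (positiveValue != null && labels.length == 2 && labels[1].equals(positiveValue)) {
            labels[1] = labels[0];
            labels[0] = positiveValue;
        }
        return labels;
    }

    public static void main(String[] args) {
        Set<String> distinct = new HashSet<>(Arrays.asList("prefix0", "prefix1", "prefix1"));
        // No positive value given: reverse lexicographic order decides.
        System.out.println(Arrays.toString(orderLabels(distinct, null)));
        // Explicit positive value: it is moved to index 0.
        System.out.println(Arrays.toString(orderLabels(distinct, "prefix0")));
    }
}
```

With no positive value set, as in the example program, "prefix1" sorts first and is therefore treated as the positive label.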
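The main job here is to call CalLabelDetailLocal on each partition, in preparation for computing the confusion matrix later.

    return data
        .rebalance()
        .mapPartition(new CalLabelDetailLocal(binary)) // the business logic is here
        .withBroadcastSet(labels, LABELS);

The actual work is done by CalLabelDetailLocal, which calls getDetailStatistics on each partition.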
    // Calculate the confusion matrix based on the label and predResult.
    static class CalLabelDetailLocal extends RichMapPartitionFunction<Row, BaseMetricsSummary> {
        private Tuple2<Map<String, Integer>, String[]> map;
        private boolean binary;

        // Constructor matching the call new CalLabelDetailLocal(binary) above
        public CalLabelDetailLocal(boolean binary) {
            this.binary = binary;
        }

        @Override
        public void open(Configuration parameters) throws Exception {
            List<Tuple2<Map<String, Integer>, String[]>> list = getRuntimeContext().getBroadcastVariable(LABELS);
            this.map = list.get(0); // the (map, labels) pair built earlier
        }

        @Override
        public void mapPartition(Iterable<Row> rows, Collector<BaseMetricsSummary> collector) {
            // This is where getDetailStatistics is called
            collector.collect(getDetailStatistics(rows, binary, map));
        }
    }
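getDetailStatistics initializes the base classification evaluation metrics and accumulates the data needed for the confusion matrix. It iterates over the rows, takes each item (for example "prefix1,{"prefix1": 0.8, "prefix0": 0.2}"), and adds it to the running counts from which the confusion matrix is computed.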
    // Initialize the base classification evaluation metrics.
    // There are two cases: BinaryClassMetrics and MultiClassMetrics.
    private static BaseMetricsSummary getDetailStatistics(Iterable<Row> rows,
                                                          boolean binary,
                                                          Tuple2<Map<String, Integer>, String[]> tuple) {
        BinaryMetricsSummary binaryMetricsSummary = null;
        MultiMetricsSummary multiMetricsSummary = null;
        Tuple2<Map<String, Integer>, String[]> labelIndexLabelArray = tuple; // the (map, labels) pair built earlier

        Iterator<Row> iterator = rows.iterator();
        Row row = null;
        while (iterator.hasNext() && !checkRowFieldNotNull(row)) {
            row = iterator.next();
        }

        Map<String, Integer> labelIndexMap = null;
        if (binary) {
            // The binary case
            binaryMetricsSummary = new BinaryMetricsSummary(
                new long[ClassificationEvaluationUtil.DETAIL_BIN_NUMBER],
                new long[ClassificationEvaluationUtil.DETAIL_BIN_NUMBER],
                labelIndexLabelArray.f1, 0.0, 0L);
        } else {
            // The <labels, ID> map built earlier appears to be used only for multi-class
            labelIndexMap = labelIndexLabelArray.f0;
            multiMetricsSummary = new MultiMetricsSummary(
                new long[labelIndexMap.size()][labelIndexMap.size()],
                labelIndexLabelArray.f1, 0.0, 0L);
        }

        while (null != row) {
            if (checkRowFieldNotNull(row)) {
                TreeMap<String, Double> labelProbMap = extractLabelProbMap(row);
                String label = row.getField(0).toString();
                if (ArrayUtils.indexOf(labelIndexLabelArray.f1, label) >= 0) {
                    if (binary) {
                        // The binary case
                        updateBinaryMetricsSummary(labelProbMap, label, binaryMetricsSummary);
                    } else {
                        updateMultiMetricsSummary(labelProbMap, label, labelIndexMap, multiMetricsSummary);
                    }
                }
            }
            row = iterator.hasNext() ? iterator.next() : null;
        }

        return binary ? binaryMetricsSummary : multiMetricsSummary;
    }

    // Variables:
    // tuple = {Tuple2@9252} "({prefix1=0, prefix0=1},[prefix1, prefix0])"
    //  f0 = {HashMap@9257} size = 2
    //   "prefix1" -> {Integer@9264} 0
    //   "prefix0" -> {Integer@9266} 1
    //  f1 = {String[2]@9258}
    //   0 = "prefix1"
    //   1 = "prefix0"
    // row = {Row@9271} "prefix1,{"prefix1": 0.8, "prefix0": 0.2}"
    //  fields = {Object[2]@9276}
    //   0 = "prefix1"
    //   1 = "{"prefix1": 0.8, "prefix0": 0.2}"
    // labelIndexLabelArray = {Tuple2@9240} "({prefix1=0, prefix0=1},[prefix1, prefix0])"
    //  f0 = {HashMap@9288} size = 2
    //   "prefix1" -> {Integer@9294} 0
    //   "prefix0" -> {Integer@9296} 1
    //  f1 = {String[2]@9242}
    //   0 = "prefix1"
    //   1 = "prefix0"
    // labelProbMap = {TreeMap@9342} size = 2
    //  "prefix0" -> {Double@9378} 0.1
    //  "prefix1" -> {Double@9380} 0.9
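First, recall the confusion matrix:

| | Predicted 0 | Predicted 1 |
|---|---|---|
| Actual 0 | TN | FP |
| Actual 1 | FN | TP |

For this confusion matrix, BinaryMetricsSummary saves the evaluation data for binary classification. The calculation works as follows:
Split [0, 1] into ClassificationEvaluationUtil.DETAIL_BIN_NUMBER (100000) bins, so positiveBin and negativeBin of binaryMetricsSummary are two arrays of length 100000. If the probability that a sample is positive is p, the sample's bin index is p * 100000. If the sample's actual label is the positive value, positiveBin[index]++; otherwise its actual label is the negative value and negativeBin[index]++. positiveBin therefore holds all actual positives (TP + FN) and negativeBin holds all actual negatives (TN + FP).
So the input is traversed. Taking "prefix1", "{\"prefix1\": 0.9, \"prefix0\": 0.1}" as an example, 0.9 is the probability of prefix1 (the positive class) and 0.1 is the probability of prefix0 (the negative class).
For the 5 samples in our example code, the assignment is: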
    Row.of("prefix1", "{\"prefix1\": 0.9, \"prefix0\": 0.1}"),   // positiveBin[90000] + 1
    Row.of("prefix1", "{\"prefix1\": 0.8, \"prefix0\": 0.2}"),   // positiveBin[80000] + 1
    Row.of("prefix1", "{\"prefix1\": 0.7, \"prefix0\": 0.3}"),   // positiveBin[70000] + 1
    Row.of("prefix0", "{\"prefix1\": 0.75, \"prefix0\": 0.25}"), // negativeBin[75000] + 1
    Row.of("prefix0", "{\"prefix1\": 0.6, \"prefix0\": 0.4}")    // negativeBin[60000] + 1
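The binning rule can be tried out on its own. The sketch below mirrors the index formula used by updateBinaryMetricsSummary (floor of p times the bin count, with p = 1.0 mapped to the last bin). The class name `BinIndexSketch` and the helper `fillExample` are illustrative and not part of Alink.

```java
class BinIndexSketch {
    static final int BIN_NUMBER = 100000;

    // Bin index for a predicted positive-class probability p in [0, 1].
    static int binIndex(double p) {
        return p == 1.0 ? BIN_NUMBER - 1 : (int) Math.floor(p * BIN_NUMBER);
    }

    // Fills the two histograms for the five example rows.
    // Returns {positiveBin, negativeBin}.
    static long[][] fillExample() {
        long[] positiveBin = new long[BIN_NUMBER];
        long[] negativeBin = new long[BIN_NUMBER];
        double[] probOfPositive = {0.9, 0.8, 0.7, 0.75, 0.6};
        boolean[] actualPositive = {true, true, true, false, false};
        for (int i = 0; i < probOfPositive.length; i++) {
            int idx = binIndex(probOfPositive[i]);
            // The actual label, not the prediction, decides which histogram is used.
            if (actualPositive[i]) {
                positiveBin[idx]++;
            } else {
                negativeBin[idx]++;
            }
        }
        return new long[][] {positiveBin, negativeBin};
    }

    public static void main(String[] args) {
        long[][] bins = fillExample();
        System.out.println("positiveBin[90000] = " + bins[0][90000]);
        System.out.println("negativeBin[75000] = " + bins[1][75000]);
    }
}
```

Keeping two histograms indexed by score means a threshold can be moved later without revisiting the raw rows, which is what the curve computation further down relies on.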
The code is as follows:
    public static void updateBinaryMetricsSummary(TreeMap<String, Double> labelProbMap,
                                                  String label,
                                                  BinaryMetricsSummary binaryMetricsSummary) {
        binaryMetricsSummary.total++;
        binaryMetricsSummary.logLoss += extractLogloss(labelProbMap, label);

        // Probability of the positive label, labels[0]
        double d = labelProbMap.get(binaryMetricsSummary.labels[0]);
        int idx = d == 1.0 ? ClassificationEvaluationUtil.DETAIL_BIN_NUMBER - 1 :
            (int) Math.floor(d * ClassificationEvaluationUtil.DETAIL_BIN_NUMBER);
        if (idx >= 0 && idx < ClassificationEvaluationUtil.DETAIL_BIN_NUMBER) {
            if (label.equals(binaryMetricsSummary.labels[0])) {
                binaryMetricsSummary.positiveBin[idx] += 1;
            } else if (label.equals(binaryMetricsSummary.labels[1])) {
                binaryMetricsSummary.negativeBin[idx] += 1;
            } else {
                .....
            }
        }
    }

    private static double extractLogloss(TreeMap<String, Double> labelProbMap, String label) {
        Double prob = labelProbMap.get(label);
        prob = null == prob ? 0. : prob;
        return -Math.log(Math.max(Math.min(prob, 1 - LOG_LOSS_EPS), LOG_LOSS_EPS));
    }

    // Variables:
    // ClassificationEvaluationUtil.DETAIL_BIN_NUMBER = 100000

    // For "prefix1", "{\"prefix1\": 0.9, \"prefix0\": 0.1}":
    // labelProbMap = {TreeMap@9305} size = 2
    //  "prefix0" -> {Double@9331} 0.1
    //  "prefix1" -> {Double@9333} 0.9
    // d = 0.9
    // idx = 90000
    // binaryMetricsSummary = {BinaryMetricsSummary@9262}
    //  labels = {String[2]@9242}
    //   0 = "prefix1"
    //   1 = "prefix0"
    //  total = 1
    //  positiveBin = {long[100000]@9263} // +1 at 90000
    //  negativeBin = {long[100000]@9264}
    //  logLoss = 0.10536051565782628

    // For "prefix0", "{\"prefix1\": 0.6, \"prefix0\": 0.4}":
    // labelProbMap = {TreeMap@9514} size = 2
    //  "prefix0" -> {Double@9546} 0.4
    //  "prefix1" -> {Double@9547} 0.6
    // d = 0.6
    // idx = 60000
    // binaryMetricsSummary = {BinaryMetricsSummary@9262}
    //  labels = {String[2]@9242}
    //   0 = "prefix1"
    //   1 = "prefix0"
    //  total = 2
    //  positiveBin = {long[100000]@9263}
    //  negativeBin = {long[100000]@9264} // +1 at 60000
    //  logLoss = 1.0216512475319812
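The two logLoss values in the variable dump can be reproduced with a few lines. This is a sketch of the clamped negative log-likelihood used by extractLogloss; the value of `EPS` below is an assumption made for the sketch (the article does not show the value of LOG_LOSS_EPS), and `LogLossSketch` is an illustrative name.

```java
class LogLossSketch {
    // Assumed clamp value; the real LOG_LOSS_EPS is not shown in the excerpt.
    static final double EPS = 1e-15;

    // Negative log of the probability assigned to the true label,
    // clamped away from 0 and 1 so the result stays finite.
    static double logLoss(double probOfTrueLabel) {
        return -Math.log(Math.max(Math.min(probOfTrueLabel, 1 - EPS), EPS));
    }

    public static void main(String[] args) {
        // Row 1: true label prefix1, probability 0.9.
        double first = logLoss(0.9);
        // A later row: true label prefix0, probability 0.4.
        double running = first + logLoss(0.4);
        System.out.println("after first row: " + first);
        System.out.println("running sum:     " + running);
        // A missing label is treated as probability 0 and hits the clamp.
        System.out.println("clamped:         " + logLoss(0.0));
    }
}
```

Note that the probability looked up is that of the row's true label, so a confident wrong prediction contributes a large term to the sum.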
ReduceBaseMetrics aggregates the locally computed BaseMetrics.

    DataSet<BaseMetricsSummary> metrics = res
        .reduce(new EvaluationUtil.ReduceBaseMetrics());

ReduceBaseMetrics is as follows:

    public static class ReduceBaseMetrics implements ReduceFunction<BaseMetricsSummary> {
        @Override
        public BaseMetricsSummary reduce(BaseMetricsSummary t1, BaseMetricsSummary t2) throws Exception {
            return null == t1 ? t2 : t1.merge(t2);
        }
    }

The actual computation is in BinaryMetricsSummary.merge, which merges the bins and adds up the logLoss.
    @Override
    public BinaryMetricsSummary merge(BinaryMetricsSummary binaryClassMetrics) {
        for (int i = 0; i < this.positiveBin.length; i++) {
            this.positiveBin[i] += binaryClassMetrics.positiveBin[i];
        }
        for (int i = 0; i < this.negativeBin.length; i++) {
            this.negativeBin[i] += binaryClassMetrics.negativeBin[i];
        }
        this.logLoss += binaryClassMetrics.logLoss;
        this.total += binaryClassMetrics.total;
        return this;
    }

    // Program variables:
    // this = {BinaryMetricsSummary@9316}
    //  labels = {String[2]@9322}
    //   0 = "prefix1"
    //   1 = "prefix0"
    //  total = 2
    //  positiveBin = {long[100000]@9320}
    //  negativeBin = {long[100000]@9323}
    //  logLoss = 1.742969305058623
    this.setOutput(metrics.flatMap(new EvaluationUtil.SaveDataAsParams()),
        new String[] {DATA_OUTPUT}, new TypeInformation[] {Types.STRING});

After all BaseMetrics have been merged we have the total BaseMetrics; the indexes are then computed and stored in params.

    public static class SaveDataAsParams implements FlatMapFunction<BaseMetricsSummary, Row> {
        @Override
        public void flatMap(BaseMetricsSummary t, Collector<Row> collector) throws Exception {
            collector.collect(t.toMetrics().serialize());
        }
    }

The actual work is done in BinaryMetricsSummary.toMetrics: based on the bins, it computes the confusionMatrix array, the threshold array, rocCurve/recallPrecisionCurve/LiftChart and so on, and stores them in params.
    public BinaryClassMetrics toMetrics() {
        Params params = new Params();
        // Build the curves, e.g. rocCurve/recallPrecisionCurve/LiftChart
        Tuple3<ConfusionMatrix[], double[], EvaluationCurve[]> matrixThreCurve =
            extractMatrixThreCurve(positiveBin, negativeBin, total);

        // Compute and store AUC/PRC/KS from the curves
        setCurveAreaParams(params, matrixThreCurve.f2);

        // Sample the generated rocCurve/recallPrecisionCurve/LiftChart output
        Tuple3<ConfusionMatrix[], double[], EvaluationCurve[]> sampledMatrixThreCurve = sample(
            PROBABILITY_INTERVAL, matrixThreCurve);

        // Store RocCurve/RecallPrecisionCurve/LiftChart from the sampled output
        setCurvePointsParams(params, sampledMatrixThreCurve);

        ConfusionMatrix[] matrices = sampledMatrixThreCurve.f0;

        // Store the metrics of the positive samples
        setComputationsArrayParams(params, sampledMatrixThreCurve.f1, sampledMatrixThreCurve.f0);

        // Store the Logloss
        setLoglossParams(params, logLoss, total);

        // Pick the middle point where threshold is 0.5.
        int middleIndex = getMiddleThresholdIndex(sampledMatrixThreCurve.f1);
        setMiddleThreParams(params, matrices[middleIndex], labels);

        return new BinaryClassMetrics(params);
    }
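The scalar metrics printed by the example (Accuracy 0.6, Macro Precision 0.3) come from the confusion matrix at the middle threshold 0.5. At that threshold all five example rows are predicted positive, which gives the matrix [[3, 2], [0, 0]] with rows as predicted labels and columns as actual labels. The sketch below recomputes the two numbers from that matrix. It is an illustration, not Alink code, and it assumes that a class with no predicted samples contributes a precision of 0, which is the convention that matches the printed 0.3.

```java
class MiddleThresholdMetrics {
    // m[predicted][actual]; accuracy is the diagonal divided by the total.
    static double accuracy(long[][] m) {
        long total = 0;
        long diagonal = 0;
        for (int i = 0; i < m.length; i++) {
            diagonal += m[i][i];
            for (int j = 0; j < m[i].length; j++) {
                total += m[i][j];
            }
        }
        return (double) diagonal / total;
    }

    // Unweighted mean of the per-class precision m[i][i] / rowSum(i).
    static double macroPrecision(long[][] m) {
        double sum = 0.0;
        for (int i = 0; i < m.length; i++) {
            long predicted = 0;
            for (long v : m[i]) {
                predicted += v;
            }
            // Assumption: an empty row counts as precision 0.
            sum += predicted == 0 ? 0.0 : (double) m[i][i] / predicted;
        }
        return sum / m.length;
    }

    public static void main(String[] args) {
        long[][] atHalf = {{3, 2}, {0, 0}};
        System.out.println("Accuracy:        " + accuracy(atHalf));
        System.out.println("Macro Precision: " + macroPrecision(atHalf));
    }
}
```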
extractMatrixThreCurve is the core of this article. It extracts the bins that are not empty (always keeping the middle threshold 0.5), initializes the RocCurve, Recall-Precision Curve and Lift Curve, and computes the ConfusionMatrix array, the threshold array, and rocCurve/recallPrecisionCurve/LiftChart.
    /**
     * Extract the bins who are not empty, keep the middle threshold 0.5.
     * Initialize the RocCurve, Recall-Precision Curve and Lift Curve.
     * RocCurve: (FPR, TPR), starts with (0,0).
     * Recall-Precision Curve: (recall, precision), starts with (0, p), p is the precision with the lowest.
     * LiftChart: (TP+FP/total, TP), starts with (0,0).
     * confusion matrix = [TP FP][FN * TN].
     *
     * @param positiveBin positiveBins.
     * @param negativeBin negativeBins.
     * @param total       sample number
     * @return ConfusionMatrix array, threshold array, rocCurve/recallPrecisionCurve/LiftChart.
     */
    static Tuple3<ConfusionMatrix[], double[], EvaluationCurve[]> extractMatrixThreCurve(long[] positiveBin,
                                                                                         long[] negativeBin,
                                                                                         long total) {
        ArrayList<Integer> effectiveIndices = new ArrayList<>();
        long totalTrue = 0, totalFalse = 0;

        // Compute totalTrue, totalFalse and effectiveIndices
        for (int i = 0; i < ClassificationEvaluationUtil.DETAIL_BIN_NUMBER; i++) {
            if (0L != positiveBin[i] || 0L != negativeBin[i]
                || i == ClassificationEvaluationUtil.DETAIL_BIN_NUMBER / 2) {
                effectiveIndices.add(i);
                totalTrue += positiveBin[i];
                totalFalse += negativeBin[i];
            }
        }

        // In our example:
        // effectiveIndices = [50000 (the middle point, always added), 60000, 70000, 75000, 80000, 90000]
        // totalTrue = 3
        // totalFalse = 2

        // Continue initializing and create the curves
        final int length = effectiveIndices.size();
        final int newLen = length + 1;
        final double m = 1.0 / ClassificationEvaluationUtil.DETAIL_BIN_NUMBER;
        EvaluationCurvePoint[] rocCurve = new EvaluationCurvePoint[newLen];
        EvaluationCurvePoint[] recallPrecisionCurve = new EvaluationCurvePoint[newLen];
        EvaluationCurvePoint[] liftChart = new EvaluationCurvePoint[newLen];
        ConfusionMatrix[] data = new ConfusionMatrix[newLen];
        double[] threshold = new double[newLen];
        long curTrue = 0;
        long curFalse = 0;

        // In our example: length = 6, newLen = 7, m = 1.0E-5

        // How rocCurve, recallPrecisionCurve and liftChart are computed can be read directly from the code
        for (int i = 1; i < newLen; i++) {
            int index = effectiveIndices.get(length - i);
            curTrue += positiveBin[index];
            curFalse += negativeBin[index];
            threshold[i] = index * m;
            // Compute the confusion matrix
            data[i] = new ConfusionMatrix(
                new long[][] {{curTrue, curFalse}, {totalTrue - curTrue, totalFalse - curFalse}});
            double tpr = (totalTrue == 0 ? 1.0 : 1.0 * curTrue / totalTrue);

            // For example at the point 90000: i = 1, index = 90000, curTrue = 1, curFalse = 0,
            // totalTrue = 3, totalFalse = 2, tpr = 0.3333333333333333.
            // Since TPR = TP / (TP + FN), tpr = 1 / 3.

            rocCurve[i] = new EvaluationCurvePoint(totalFalse == 0 ? 1.0 : 1.0 * curFalse / totalFalse, tpr,
                threshold[i]);
            // The guard protects the division by (curTrue + curFalse)
            recallPrecisionCurve[i] = new EvaluationCurvePoint(tpr,
                curTrue + curFalse == 0 ? 1.0 : 1.0 * curTrue / (curTrue + curFalse), threshold[i]);
            liftChart[i] = new EvaluationCurvePoint(1.0 * (curTrue + curFalse) / total, curTrue, threshold[i]);
        }

        // In our example, after the loop:
        // curTrue = 3
        // curFalse = 2
        //
        // threshold = [0.0, 0.9, 0.8, 0.7500000000000001, 0.7000000000000001, 0.6000000000000001, 0.5]
        //
        // rocCurve (index: x, y, p)
        //  1: x = 0.0, y = 0.3333333333333333, p = 0.9
        //  2: x = 0.0, y = 0.6666666666666666, p = 0.8
        //  3: x = 0.5, y = 0.6666666666666666, p = 0.7500000000000001
        //  4: x = 0.5, y = 1.0, p = 0.7000000000000001
        //  5: x = 1.0, y = 1.0, p = 0.6000000000000001
        //  6: x = 1.0, y = 1.0, p = 0.5
        //
        // recallPrecisionCurve (index: x, y, p)
        //  1: x = 0.3333333333333333, y = 1.0, p = 0.9
        //  2: x = 0.6666666666666666, y = 1.0, p = 0.8
        //  3: x = 0.6666666666666666, y = 0.6666666666666666, p = 0.7500000000000001
        //  4: x = 1.0, y = 0.75, p = 0.7000000000000001
        //  5: x = 1.0, y = 0.6, p = 0.6000000000000001
        //  6: x = 1.0, y = 0.6, p = 0.5
        //
        // liftChart (index: x, y, p)
        //  1: x = 0.2, y = 1.0, p = 0.9
        //  2: x = 0.4, y = 2.0, p = 0.8
        //  3: x = 0.6, y = 2.0, p = 0.7500000000000001
        //  4: x = 0.8, y = 3.0, p = 0.7000000000000001
        //  5: x = 1.0, y = 3.0, p = 0.6000000000000001
        //  6: x = 1.0, y = 3.0, p = 0.5
        //
        // data[0]: matrix = [[0, 0], [3, 2]], rowNum = 2, colNum = 2, labelCnt = 2, total = 5
        //  actualLabelFrequency = [3, 2], predictLabelFrequency = [0, 5]
        //  tpCount = 2.0, tnCount = 2.0, fpCount = 3.0, fnCount = 3.0
        // data[1]: matrix = [[1, 0], [2, 2]], rowNum = 2, colNum = 2, labelCnt = 2, total = 5
        //  actualLabelFrequency = [3, 2], predictLabelFrequency = [1, 4]
        //  tpCount = 3.0, tnCount = 3.0, fpCount = 2.0, fnCount = 2.0
        // ......

        threshold[0] = 1.0;
        data[0] = new ConfusionMatrix(new long[][] {{0, 0}, {totalTrue, totalFalse}});
        rocCurve[0] = new EvaluationCurvePoint(0, 0, threshold[0]);
        recallPrecisionCurve[0] = new EvaluationCurvePoint(0, recallPrecisionCurve[1].getY(), threshold[0]);
        liftChart[0] = new EvaluationCurvePoint(0, 0, threshold[0]);

        return Tuple3.of(data, threshold, new EvaluationCurve[] {new EvaluationCurve(rocCurve),
            new EvaluationCurve(recallPrecisionCurve), new EvaluationCurve(liftChart)});
    }
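To make the threshold sweep concrete without the Alink types, here is a stripped-down sketch of the same loop that produces only the ROC points. `RocSweepSketch`, `rocPoints` and `exampleRoc` are illustrative names; the sketch keeps the idea of walking the non-empty bins from the highest score down and accumulating the counts, but it is not the Alink implementation.

```java
import java.util.ArrayList;
import java.util.List;

class RocSweepSketch {
    static final int BINS = 100000;

    // Returns {fpr[], tpr[]}; point 0 is the (0, 0) start of the curve.
    static double[][] rocPoints(long[] positiveBin, long[] negativeBin) {
        List<Integer> effective = new ArrayList<>();
        long totalTrue = 0;
        long totalFalse = 0;
        for (int i = 0; i < BINS; i++) {
            // Keep non-empty bins, plus the middle bin for threshold 0.5.
            if (positiveBin[i] != 0 || negativeBin[i] != 0 || i == BINS / 2) {
                effective.add(i);
                totalTrue += positiveBin[i];
                totalFalse += negativeBin[i];
            }
        }
        int n = effective.size() + 1;
        double[] fpr = new double[n];
        double[] tpr = new double[n];
        long curTrue = 0;
        long curFalse = 0;
        for (int i = 1; i < n; i++) {
            // Walk from the highest-scoring bin to the lowest.
            int index = effective.get(effective.size() - i);
            curTrue += positiveBin[index];
            curFalse += negativeBin[index];
            tpr[i] = totalTrue == 0 ? 1.0 : (double) curTrue / totalTrue;
            fpr[i] = totalFalse == 0 ? 1.0 : (double) curFalse / totalFalse;
        }
        return new double[][] {fpr, tpr};
    }

    // The five example rows, already binned.
    static double[][] exampleRoc() {
        long[] positiveBin = new long[BINS];
        long[] negativeBin = new long[BINS];
        positiveBin[90000] = 1;
        positiveBin[80000] = 1;
        positiveBin[70000] = 1;
        negativeBin[75000] = 1;
        negativeBin[60000] = 1;
        return rocPoints(positiveBin, negativeBin);
    }

    public static void main(String[] args) {
        double[][] roc = exampleRoc();
        for (int i = 0; i < roc[0].length; i++) {
            System.out.println("(" + roc[0][i] + ", " + roc[1][i] + ")");
        }
    }
}
```

Each step lowers the threshold by one occupied bin, so every sample with a score at or above that bin is counted as predicted positive; this is why cumulative sums are enough and no sorting of raw rows is needed.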
Let's go over how the confusion matrix is computed once more, because the logic is a little convoluted.
It is created here:

    // Call site
    data[i] = new ConfusionMatrix(
        new long[][] {{curTrue, curFalse}, {totalTrue - curTrue, totalFalse - curFalse}});

    // Values at the time of the call
    // i = 1
    // index = 90000
    // totalTrue = 3
    // totalFalse = 2
    // curTrue = 1
    // curFalse = 0

This gives the raw matrix. Every entry involves a cur value, which means it only describes the current point.

| curTrue = 1 | curFalse = 0 |
|---|---|
| totalTrue - curTrue = 2 | totalFalse - curFalse = 2 |

In the subsequent ConfusionMatrix computation, we then get:

    actualLabelFrequency = longMatrix.getColSums();
    predictLabelFrequency = longMatrix.getRowSums();

    // actualLabelFrequency = {long[2]@9322}
    //  0 = 3
    //  1 = 2
    // predictLabelFrequency = {long[2]@9323}
    //  0 = 1
    //  1 = 4

As can be seen, Alink treats each column sum as belonging to an actual label and each row sum as belonging to a predicted label.
The extended matrix is:

| | | | predictLabelFrequency |
|---|---|---|---|
| | curTrue = 1 | curFalse = 0 | 1 = curTrue + curFalse |
| | totalTrue - curTrue = 2 | totalFalse - curFalse = 2 | 4 = total - curTrue - curFalse |
| actualLabelFrequency | 3 = totalTrue | 2 = totalFalse | |
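The subsequent calculations are based on the following:
The calculation uses the diagonal of longMatrix, i.e. longMatrix(0)(0) and longMatrix(1)(1). Note that everything here refers to the current state, that is, the current threshold (this point is important).
longMatrix(0)(0): curTrue
longMatrix(1)(1): totalFalse - curFalse
totalFalse: all actual negatives (TN + FP)
totalTrue: all actual positives (TP + FN)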
    // Each method treats the label at labelIndex as the class of interest.
    // Rows of longMatrix are predicted labels, columns are actual labels.

    double numTrueNegative(Integer labelIndex) {
        // labelIndex 0: returns 1 + 5 - 1 - 3 = 2
        // labelIndex 1: returns 2 + 5 - 4 - 2 = 1
        return null == labelIndex ? tnCount :
            longMatrix.getValue(labelIndex, labelIndex) + total - predictLabelFrequency[labelIndex]
                - actualLabelFrequency[labelIndex];
    }

    double numTruePositive(Integer labelIndex) {
        // labelIndex 0: returns 1 = curTrue. The actual label is positive and it is predicted
        //   positive at this threshold, so this is TP.
        // labelIndex 1: returns 2 = totalFalse - curFalse. These are actual negatives that are
        //   not predicted positive at this threshold; with label 1 as the class of interest
        //   they count as its TP in the current state.
        return null == labelIndex ? tpCount : longMatrix.getValue(labelIndex, labelIndex);
    }

    double numFalseNegative(Integer labelIndex) {
        // labelIndex 0: returns 3 - 1 = 2.
        //   actualLabelFrequency[0] = totalTrue, so this is totalTrue - curTrue: actual
        //   positives that are not predicted positive at this threshold, i.e. FN.
        // labelIndex 1: returns 2 - 2 = 0.
        //   actualLabelFrequency[1] = totalFalse, so this is
        //   totalFalse - (totalFalse - curFalse) = curFalse.
        return null == labelIndex ? fnCount :
            actualLabelFrequency[labelIndex] - longMatrix.getValue(labelIndex, labelIndex);
    }

    double numFalsePositive(Integer labelIndex) {
        // labelIndex 0: returns 1 - 1 = 0.
        //   predictLabelFrequency[0] = curTrue + curFalse, so this is
        //   curTrue + curFalse - curTrue = curFalse: actual negatives predicted positive, i.e. FP.
        // labelIndex 1: returns 4 - 2 = 2.
        //   predictLabelFrequency[1] = total - curTrue - curFalse, so this is
        //   total - curTrue - curFalse - (totalFalse - curFalse) = totalTrue - curTrue:
        //   actual positives predicted negative, which is FP from the viewpoint of label 1.
        return null == labelIndex ? fpCount :
            predictLabelFrequency[labelIndex] - longMatrix.getValue(labelIndex, labelIndex);
    }

    // Final result (summed over both labels):
    // tpCount = 3.0
    // tnCount = 3.0
    // fpCount = 2.0
    // fnCount = 2.0
    // The computation itself
    public ConfusionMatrix(LongMatrix longMatrix) {
        // longMatrix = [[1, 0], [2, 2]]

        this.longMatrix = longMatrix;
        labelCnt = this.longMatrix.getRowNum();
        // Column sums go with the actual labels, row sums with the predicted labels
        actualLabelFrequency = longMatrix.getColSums();
        predictLabelFrequency = longMatrix.getRowSums();

        // actualLabelFrequency = [3, 2]
        // predictLabelFrequency = [1, 4]
        // labelCnt = 2
        // total = 5

        total = longMatrix.getTotal();
        for (int i = 0; i < labelCnt; i++) {
            tnCount += numTrueNegative(i);
            tpCount += numTruePositive(i);
            fnCount += numFalseNegative(i);
            fpCount += numFalsePositive(i);
        }
    }
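The four counts can be recomputed outside Alink with plain arrays. The sketch below applies the same formulas to a square matrix whose rows are predicted labels and whose columns are actual labels, summing over all labels as the constructor does. `ConfusionCountsSketch` and `counts` are illustrative names, not Alink APIs.

```java
class ConfusionCountsSketch {
    // m[predicted][actual]; returns {tp, tn, fp, fn} summed over all labels.
    static long[] counts(long[][] m) {
        int n = m.length;
        long[] rowSum = new long[n]; // per predicted label
        long[] colSum = new long[n]; // per actual label
        long total = 0;
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                rowSum[i] += m[i][j];
                colSum[j] += m[i][j];
                total += m[i][j];
            }
        }
        long tp = 0, tn = 0, fp = 0, fn = 0;
        for (int i = 0; i < n; i++) {
            tp += m[i][i];
            fn += colSum[i] - m[i][i];
            fp += rowSum[i] - m[i][i];
            tn += m[i][i] + total - rowSum[i] - colSum[i];
        }
        return new long[] {tp, tn, fp, fn};
    }

    public static void main(String[] args) {
        // The matrix at threshold 0.9 in the example: {{curTrue, curFalse}, {rest...}}.
        long[] c = counts(new long[][] {{1, 0}, {2, 2}});
        System.out.println("tp=" + c[0] + " tn=" + c[1] + " fp=" + c[2] + " fn=" + c[3]);
    }
}
```

Because each label takes a turn as the class of interest, the summed tp equals the summed tn and the summed fp equals the summed fn in the binary case, which is why the dump shows 3, 3, 2, 2.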
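In Alink's original Python example, the stream part produces no output, because MemSourceStreamOp is not tied to time and Alink does not provide a time-based StreamOperator. I therefore wrote one myself, modelled on MemSourceBatchOp. The code is a bit ugly, but at least it produces output, which makes debugging possible.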
    public class EvalBinaryClassExampleStream {

        AlgoOperator getData(boolean isBatch) {
            Row[] rows = new Row[]{
                Row.of("prefix1", "{\"prefix1\": 0.9, \"prefix0\": 0.1}")
            };
            String[] schema = new String[]{"label", "detailInput"};
            if (isBatch) {
                return new MemSourceBatchOp(rows, schema);
            } else {
                return new TimeMemSourceStreamOp(rows, schema, new EvalBinaryStreamSource());
            }
        }

        public static void main(String[] args) throws Exception {
            EvalBinaryClassExampleStream test = new EvalBinaryClassExampleStream();
            StreamOperator streamData = (StreamOperator) test.getData(false);
            StreamOperator sOp = new EvalBinaryClassStreamOp()
                .setLabelCol("label")
                .setPredictionDetailCol("detailInput")
                .setTimeInterval(1)
                .linkFrom(streamData);
            sOp.print();
            StreamOperator.execute();
        }
    }
I put this class together myself, borrowing from MemSourceStreamOp.
    public final class TimeMemSourceStreamOp extends StreamOperator<TimeMemSourceStreamOp> {

        public TimeMemSourceStreamOp(Row[] rows, String[] colNames, EvalBinaryStreamSource source) {
            super(null);
            init(source, Arrays.asList(rows), colNames);
        }

        private void init(EvalBinaryStreamSource source, List<Row> rows, String[] colNames) {
            // Infer the column types from the first row
            Row first = rows.iterator().next();
            int arity = first.getArity();
            TypeInformation<?>[] types = new TypeInformation[arity];

            for (int i = 0; i < arity; ++i) {
                types[i] = TypeExtractor.getForObject(first.getField(i));
            }

            init(source, colNames, types);
        }

        private void init(EvalBinaryStreamSource source, String[] colNames, TypeInformation<?>[] colTypes) {
            DataStream<Row> dastr = MLEnvironmentFactory.get(getMLEnvironmentId())
                .getStreamExecutionEnvironment().addSource(source);
            StringBuilder sbd = new StringBuilder();
            sbd.append(colNames[0]);

            for (int i = 1; i < colNames.length; i++) {
                sbd.append(",").append(colNames[i]);
            }
            this.setOutput(dastr, colNames, colTypes);
        }

        @Override
        public TimeMemSourceStreamOp linkFrom(StreamOperator<?>... inputs) {
            return null;
        }
    }
The source below emits Rows periodically; a random number is added so that the probability varies.
    class EvalBinaryStreamSource extends RichSourceFunction[Row] {
      // Cleared by cancel() so that run() can exit
      @volatile private var running = true

      override def run(ctx: SourceFunction.SourceContext[Row]): Unit = {
        while (running) {
          val rdm = Math.random() // a random number, so that the probability varies
          val rows: Array[Row] = Array[Row](
            Row.of("prefix1", "{\"prefix1\": " + rdm + ", \"prefix0\": " + (1 - rdm) + "}"),
            Row.of("prefix1", "{\"prefix1\": 0.8, \"prefix0\": 0.2}"),
            Row.of("prefix1", "{\"prefix1\": 0.7, \"prefix0\": 0.3}"),
            Row.of("prefix0", "{\"prefix1\": 0.75, \"prefix0\": 0.25}"),
            Row.of("prefix0", "{\"prefix1\": 0.6, \"prefix0\": 0.4}"))
          for (row <- rows) {
            println(s"current value: $row")
            ctx.collect(row)
          }
          Thread.sleep(1000)
        }
      }

      override def cancel(): Unit = {
        running = false
      }
    }
Alink's stream processing class is EvalBinaryClassStreamOp. Most of the work is in its base class BaseEvalClassStreamOp, so we focus on the latter.
    public class BaseEvalClassStreamOp<T extends BaseEvalClassStreamOp<T>> extends StreamOperator<T> {

        @Override
        public T linkFrom(StreamOperator<?>... inputs) {
            StreamOperator<?> in = checkAndGetFirst(inputs);
            String labelColName = this.get(MultiEvaluationStreamParams.LABEL_COL);
            String positiveValue = this.get(BinaryEvaluationStreamParams.POS_LABEL_VAL_STR);
            Integer timeInterval = this.get(MultiEvaluationStreamParams.TIME_INTERVAL);

            ClassificationEvaluationUtil.Type type = ClassificationEvaluationUtil.judgeEvaluationType(this.getParams());

            DataStream<BaseMetricsSummary> statistics;

            switch (type) {
                case PRED_RESULT: {
                    ......
                }
                case PRED_DETAIL: {
                    String predDetailColName = this.get(MultiEvaluationStreamParams.PREDICTION_DETAIL_COL);
                    PredDetailLabel eval = new PredDetailLabel(positiveValue, binary);

                    // Read the input; the key point is timeWindowAll
                    statistics = in.select(new String[] {labelColName, predDetailColName})
                        .getDataStream()
                        .timeWindowAll(Time.of(timeInterval, TimeUnit.SECONDS))
                        .apply(eval);
                    break;
                }
            }

            // Accumulate the data of every window into totalStatistics (note: a new variable)
            DataStream<BaseMetricsSummary> totalStatistics = statistics
                .map(new EvaluationUtil.AllDataMerge())
                .setParallelism(1); // parallelism is set to 1

            // Compute and serialize from the two bins: the statistics of the current window
            DataStream<Row> windowOutput = statistics.map(
                new EvaluationUtil.SaveDataStream(ClassificationEvaluationUtil.WINDOW.f0));

            // Compute and serialize from the bins: the accumulated totalStatistics
            DataStream<Row> allOutput = totalStatistics.map(
                new EvaluationUtil.SaveDataStream(ClassificationEvaluationUtil.ALL.f0));

            // Union "current" and "accumulated", then return
            DataStream<Row> union = windowOutput.union(allOutput);

            this.setOutput(union,
                new String[] {ClassificationEvaluationUtil.STATISTICS_OUTPUT, DATA_OUTPUT},
                new TypeInformation[] {Types.STRING, Types.STRING});

            return (T) this;
        }
    }
The per-window logic is as follows:
    static class PredDetailLabel implements AllWindowFunction<Row, BaseMetricsSummary, TimeWindow> {
        @Override
        public void apply(TimeWindow timeWindow, Iterable<Row> rows,
                          Collector<BaseMetricsSummary> collector) throws Exception {
            HashSet<String> labels = new HashSet<>();
            // First collect the label names, as before
            for (Row row : rows) {
                if (EvaluationUtil.checkRowFieldNotNull(row)) {
                    labels.addAll(EvaluationUtil.extractLabelProbMap(row).keySet());
                    labels.add(row.getField(0).toString());
                }
            }

            // labels = {HashSet@9757} size = 2
            //  0 = "prefix1"
            //  1 = "prefix0"

            // As described earlier, buildLabelIndexLabelArray takes the de-duplicated label names
            // and gives each label an ID; the result is a <labels, ID> map.
            // getDetailStatistics iterates over the rows and accumulates the data needed for the
            // confusion matrix ("TP + FN" / "TN + FP").
            if (labels.size() > 0) {
                collector.collect(
                    getDetailStatistics(rows, binary, buildLabelIndexLabelArray(labels, binary, positiveValue)));
            }
        }
    }
EvaluationUtil.AllDataMerge accumulates the data of all windows.
    /**
     * Merge data from different windows.
     */
    public static class AllDataMerge implements MapFunction<BaseMetricsSummary, BaseMetricsSummary> {
        private BaseMetricsSummary statistics;

        @Override
        public BaseMetricsSummary map(BaseMetricsSummary value) {
            this.statistics = (null == this.statistics ? value : this.statistics.merge(value));
            return this.statistics;
        }
    }
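The pattern here is a stateful map that keeps a running total and emits it after every window. The following JDK-only sketch shows the idea with a tiny summary type. `WindowMergeSketch`, `Summary` and `RunningMerge` are illustrative names, and the bin count of 10 is chosen only to keep the example small. Unlike the excerpt above, the sketch copies the first window's summary instead of keeping a reference to it; that is a design choice made here so that the per-window object is never modified by later merges, not something the source prescribes.

```java
class WindowMergeSketch {
    // A tiny stand-in for a per-window summary: two histograms and a count.
    static final class Summary {
        final long[] positiveBin;
        final long[] negativeBin;
        long total;

        Summary(long[] positiveBin, long[] negativeBin, long total) {
            this.positiveBin = positiveBin;
            this.negativeBin = negativeBin;
            this.total = total;
        }

        Summary copy() {
            return new Summary(positiveBin.clone(), negativeBin.clone(), total);
        }

        // Element-wise addition of the bins, plus the sample count.
        Summary merge(Summary other) {
            for (int i = 0; i < positiveBin.length; i++) {
                positiveBin[i] += other.positiveBin[i];
                negativeBin[i] += other.negativeBin[i];
            }
            total += other.total;
            return this;
        }
    }

    // Keeps the running total across windows.
    static final class RunningMerge {
        private Summary accumulated;

        Summary map(Summary window) {
            accumulated = (accumulated == null) ? window.copy() : accumulated.merge(window);
            return accumulated;
        }
    }

    public static void main(String[] args) {
        RunningMerge merge = new RunningMerge();
        Summary w1 = new Summary(new long[] {0, 0, 0, 0, 0, 0, 0, 1, 1, 1},
                                 new long[] {0, 0, 0, 0, 0, 0, 1, 1, 0, 0}, 5);
        Summary w2 = new Summary(new long[] {0, 0, 0, 1, 0, 0, 0, 1, 1, 0},
                                 new long[] {0, 0, 0, 0, 0, 0, 1, 1, 0, 0}, 5);
        System.out.println("after window 1: total = " + merge.map(w1).total);
        System.out.println("after window 2: total = " + merge.map(w2).total);
    }
}
```

A single instance holds the running state, which is consistent with the parallelism of 1 set on this operator in linkFrom.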
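The functions that SaveDataStream calls were covered in the batch section: the actual work is in BinaryMetricsSummary.toMetrics, which computes from the bins and stores the results in params.
The difference from batch processing is that the constructed metric information is returned to the user directly.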
    public static class SaveDataStream implements MapFunction<BaseMetricsSummary, Row> {
        @Override
        public Row map(BaseMetricsSummary baseMetricsSummary) throws Exception {
            BaseMetricsSummary metrics = baseMetricsSummary;
            BaseMetrics baseMetrics = metrics.toMetrics();
            Row row = baseMetrics.serialize();
            return Row.of(funtionName, row.getField(0));
        }
    }

    // The resulting row is the metric information that is finally returned to the user
    row = {Row@10008} "{"PRC":"0.9164636268708667","SensitivityArray":"[0.38461538461538464,0.6923076923076923,0.6923076923076923,1.0,1.0,1.0]","ConfusionMatrix":"[[13,8],[0,0]]","MacroRecall":"0.5","MacroSpecificity":"0.5","FalsePositiveRateArray":"[0.0,0.0,0.5,0.5,1.0,1.0]" ...... and many more
    DataStream<Row> windowOutput = statistics.map(
        new EvaluationUtil.SaveDataStream(ClassificationEvaluationUtil.WINDOW.f0));
    DataStream<Row> allOutput = totalStatistics.map(
        new EvaluationUtil.SaveDataStream(ClassificationEvaluationUtil.ALL.f0));

    DataStream<Row> union = windowOutput.union(allOutput);

Finally, two kinds of statistics are returned:
all|{"PRC":"0.7341146115890359","SensitivityArray":"[0.3333333333333333,0.3333333333333333,0.6666666666666666,0.7333333333333333,0.8,0.8,0.8666666666666667,0.8666666666666667,0.9333333333333333,1.0]","ConfusionMatrix":"[[13,10],[2,0]]","MacroRecall":"0.43333333333333335","MacroSpecificity":"0.43333333333333335","FalsePositiveRateArray":"[0.0,0.5,0.5,0.5,0.5,1.0,1.0,1.0,1.0,1.0]","TruePositiveRateArray":"[0.3333333333333333,0.3333333333333333,0.6666666666666666,0.7333333333333333,0.8,0.8,0.8666666666666667,0.8666666666666667,0.9333333333333333,1.0]","AUC":"0.5666666666666667","MacroAccuracy":"0.52", ......
window|{"PRC":"0.7638888888888888","SensitivityArray":"[0.3333333333333333,0.3333333333333333,0.6666666666666666,1.0,1.0,1.0]","ConfusionMatrix":"[[3,2],[0,0]]","MacroRecall":"0.5","MacroSpecificity":"0.5","FalsePositiveRateArray":"[0.0,0.5,0.5,0.5,1.0,1.0]","TruePositiveRateArray":"[0.3333333333333333,0.3333333333333333,0.6666666666666666,1.0,1.0,1.0]","AUC":"0.6666666666666666","MacroAccuracy":"0.6","RecallArray":"[0.3333333333333333,0.3333333333333333,0.6666666666666666,1.0,1.0,1.0]","KappaArray":"[0.28571428571428564,-0.15384615384615377,0.1666666666666666,0.5454545454545455,0.0,0.0]","MicroFalseNegativeRate":"0.4","WeightedRecall":"0.6","WeightedPrecision":"0.36","Recall":"1.0","MacroPrecision":"0.3",......
Reference: "[Plain-language analysis] Sorting out concepts through examples: Accuracy, Precision, Recall and F-Measure"