Alink漫談(十三) :在線學習算法FTRL 之 具體實現

Alink漫談(十三) :在線學習算法FTRL 之 具體實現

0x00 摘要

Alink 是阿里巴巴基於實時計算引擎 Flink 研發的新一代機器學習算法平臺,是業界首個同時支持批式算法、流式算法的機器學習平臺。本文和上文一塊兒介紹了在線學習算法 FTRL 在Alink中是如何實現的,但願對你們有所幫助。html

0x01 回顧

書接上回 Alink漫談(十二) :在線學習算法FTRL 之 總體設計 。到目前爲止,已經處理完畢輸入,接下來就是在線訓練。訓練優化的主要目標是找到一個方向,參數朝這個方向移動以後使得損失函數的值可以減少,這個方向每每由一階偏導或者二階偏導各類組合求得。java

爲了讓你們更好理解,咱們再次貼出總體流程圖:算法

在這裏插入圖片描述

0x02 在線訓練

在線訓練主要邏輯是:apache

  • 1)加載初始化模型到 dataBridge;dataBridge = DirectReader.collect(model);
  • 2)獲取相關參數。好比vectorSize默認是30000,是否 hasInterceptItem;
  • 3)獲取切分信息。splitInfo = getSplitInfo(featureSize, hasInterceptItem, parallelism); 下面立刻會用到。
  • 4)切分高維向量。初始化數據作了特徵哈希,會產生高維向量,這裏須要進行切割。 initData.flatMap(new SplitVector(splitInfo, hasInterceptItem, vectorSize,vectorTrainIdx, featureIdx, labelIdx));
  • 5)構建一個 IterativeStream.ConnectedIterativeStreams iteration,這樣會構建(或者說鏈接)兩個數據流:反饋流和訓練流;
  • 6)用iteration來構建迭代體 iterativeBody,其包括兩部分:CalcTask,ReduceTask;
    • 6.1)CalcTask分紅兩個部分。flatMap1 是分佈計算FTRL迭代須要的predict,flatMap2 是FTRL的更新參數部分;
    • 6.2)ReduceTask分爲兩個功能:「歸併這些predict計算結果「 / 」若是知足條件則歸併模型 & 向下遊算子輸出模型「;
  • 7)result = iterativeBody.filter;基本是以時間間隔爲標準來判斷(也能夠認爲是時間驅動),"時間未過時&向量有意義" 的數據將被髮送回反饋數據流,繼續迭代,回到步驟 6),進入flatMap2
  • 8)output = iterativeBody.filter;符合標準(時間過時了)的數據將跳出迭代,而後算法會調用WriteModel將LineModelData轉換爲多條Row,轉發給下游operator(也就是在線預測階段);即定時把模型更新給在線預測階段

2.1 預置模型

前面說到,FTRL先要訓練出一個邏輯迴歸模型做爲FTRL算法的初始模型,這是爲了系統冷啓動的須要。api

2.1.1 訓練模型

具體邏輯迴歸模型設定/訓練是 :app

// train initial batch model
LogisticRegressionTrainBatchOp lr = new LogisticRegressionTrainBatchOp()
            .setVectorCol(vecColName)
            .setLabelCol(labelColName)
            .setWithIntercept(true)
            .setMaxIter(10);
BatchOperator<?> initModel = featurePipelineModel.transform(trainBatchData).link(lr);

訓練好以後,模型信息是DataSet 類型,位於變量 BatchOperator<?> initModel之中,這是一個批處理算子。 機器學習

2.1.2 加載模型

FtrlTrainStreamOp將initModel做爲初始化參數。分佈式

FtrlTrainStreamOp model = new FtrlTrainStreamOp(initModel)

在FtrlTrainStreamOp構造函數中會加載這個模型;ide

dataBridge = DirectReader.collect(initModel);

具體加載時經過MemoryDataBridge直接獲取初始化模型DataSet中的數據。函數

public MemoryDataBridge generate(BatchOperator batchOperator, Params globalParams) {
   return new MemoryDataBridge(batchOperator.collect());
}

2.2 分割高維向量

從前文可知,Alink的FTRL算法設置的特徵向量維度是30000。因此算法第一步就是切分高維度向量,以便分佈式計算。

String vecColName = "vec";
int numHashFeatures = 30000;

首先要獲取切分信息,代碼以下,就是將特徵數目featureSize 除以 並行度parallelism,而後獲得了每一個task對應係數的初始位置。

private static int[] getSplitInfo(int featureSize, boolean hasInterceptItem, int parallelism) {
    int coefSize = (hasInterceptItem) ? featureSize + 1 : featureSize;
    int subSize = coefSize / parallelism;
    int[] poses = new int[parallelism + 1];
    int offset = coefSize % parallelism;
    for (int i = 0; i < offset; ++i) {
        poses[i + 1] = poses[i] + subSize + 1;
    }
    for (int i = offset; i < parallelism; ++i) {
        poses[i + 1] = poses[i] + subSize;
    }
    return poses;
}
//程序運行時變量以下
featureSize = 30000
hasInterceptItem = true
parallelism = 4
coefSize = 30001
subSize = 7500
poses = {int[5]@11660} 
 0 = 0
 1 = 7501
 2 = 15001
 3 = 22501
 4 = 30001
offset = 1

而後根據切分信息對高維向量進行切割。

// Tuple5<SampleId, taskId, numSubVec, SubVec, label>
DataStream<Tuple5<Long, Integer, Integer, Vector, Object>> input
          = initData.flatMap(new SplitVector(splitInfo, hasInterceptItem, vectorSize,
                            vectorTrainIdx, featureIdx, labelIdx))
          .partitionCustom(new CustomBlockPartitioner(), 1);

具體切分在SplitVector.flatMap函數完成,結果就是把一個高維度向量分割給各個CalcTask

代碼摘要以下:

public void flatMap(Row row, Collector<Tuple5<Long, Integer, Integer, Vector, Object>> collector) throws Exception {
				long sampleId = counter;
        counter += parallelism;
        Vector vec;
        if (vectorTrainIdx == -1) {
           .....
        } else {
            // 輸入row的第vectorTrainIdx個field就是那個30000大小的係數向量
            vec = VectorUtil.getVector(row.getField(vectorTrainIdx));
        }

        if (vec instanceof SparseVector) {
            Map<Integer, Vector> tmpVec = new HashMap<>();
            for (int i = 0; i < indices.length; ++i) {
              .....
              // 此處迭代完成後,tmpVec中就是task number個元素,每個元素是分割好的係數向量。
            }
            for (Integer key : tmpVec.keySet()) {
                //此處遍歷,給後面全部CalcTask發送五元組數據。
                collector.collect(Tuple5.of(sampleId, key, subNum, tmpVec.get(key), row.getField(labelIdx)));
            }
        } else {
         ......
        }
    }
}

這個Tuple5.of(sampleId, key, subNum, tmpVec.get(key), row.getField(labelIdx) )就是後面CalcTask的輸入。

2.3 迭代訓練

此處理論上有如下幾個重點:

  • 預測方法:在每一輪t中,針對特徵樣本xt,以及迭代後(第一次則是給定初值)的模型參數wt,咱們能夠預測該樣本的標記值:pt=σ(wt,xt),其中σ(a)=1/(1+exp(−a))是一個sigmoid函數。

  • 損失函數:對一個特徵樣本xt,其對應的標記爲yt ∈ 0,1,則經過 logistic loss 來做爲損失函數。

  • 迭代公式:咱們的目的是使得損失函數儘量的小,便可以採用極大似然估計來求解參數。首先求梯度,而後使用FTRL進行迭代。

僞代碼思路大體以下

double p = learner.predict(x); //預測
learner.updateModel(x, p, y);  //更新模型
double loss = LogLossEvalutor.calLogLoss(p, y); //計算損失
evalutor.addLogLoss(loss); //更新損失
totalLoss += loss;
trainedNum += 1;

具體實施上Alink有本身的特色和調整。

機器學習都須要迭代訓練,Alink這裏利用了Flink Stream的迭代功能

IterativeStream的實例是經過DataStream的iterate方法建立的˙。iterate方法存在兩個重載形式:

  • 一種是無參的,表示不限定最大等待時間;
  • 一種提供一個長整型maxWaitTimeMillis參數,容許用戶指定等待反饋邊的下一個輸入元素的最大時間間隔。

Alink選擇了第二種。

在建立ConnectedIterativeStreams時候,用迭代流的初始輸入做爲第一個輸入流,用反饋流做爲第二個輸入

每一種數據流(DataStream)都會有與之對應的流轉換(StreamTransformation)。IterativeStream對應的轉換是FeedbackTransformation。

迭代流(IterativeStream)對應的轉換是反饋轉換(FeedbackTransformation),它表示拓撲中的一個反饋點(也即迭代頭)。一個反饋點包含一個輸入邊以及若干個反饋邊,且Flink要求每一個反饋邊的並行度必須跟輸入邊的並行度一致,這一點在往該轉換中加入反饋邊時會進行校驗。

當IterativeStream對象被構造時,FeedbackTransformation的實例會被建立並傳遞給DataStream的構造方法。

迭代的關閉是經過調用IterativeStream的實例方法closeWith來實現的。這個函數指定了某個流將成爲迭代程序的結束,而且這個流將做爲輸入的第二部分(second input)被反饋回迭代。

2.3.2 迭代構建

對於Alink來講,迭代構建代碼是:

// train data format = <sampleId, subSampleTaskId, subNum, SparseVector(subSample), label>
// feedback format = Tuple7<sampleId, subSampleTaskId, subNum, SparseVector(subSample), label, wx, timeStamps>
IterativeStream.ConnectedIterativeStreams<
    Tuple5<Long, Integer, Integer, Vector, Object>,
    Tuple7<Long, Integer, Integer, Vector, Object, Double, Long>>
    iteration = input.iterate(Long.MAX_VALUE)
    .withFeedbackType(TypeInformation
    .of(new TypeHint<Tuple7<Long, Integer, Integer, Vector, Object, Double, Long>>() {}));

// 即iteration是一個 IterativeStream.ConnectedIterativeStreams<...>
2.3.2.1 迭代的輸入

從代碼和註釋能夠看出,迭代的兩種輸入是:

  • train data format = <sampleId, subSampleTaskId, subNum, SparseVector(subSample), label>;這種實際上是訓練數據
  • Tuple7<sampleId, subSampleTaskId, subNum, SparseVector(subSample), label, wx, timeStamps>;這種實際上是反饋數據,就是「迭代的反饋流」做爲這個第二輸入 (second input);
2.3.2.2 迭代的反饋

反饋流的設置是經過調用IterativeStream的實例方法closeWith來實現的。Alink這裏是

DataStream<Tuple7<Long, Integer, Integer, Vector, Object, Double, Long>>
		result = iterativeBody.filter(
            return (t3.f0 > 0 && t3.f2 > 0); // 這裏是省略版本代碼
            );

iteration.closeWith(result);

前面已經提到過,result filter 的判斷是 return (t3.f0 > 0 && t3.f2 > 0)若是知足條件,則說明時間未過時&向量有意義,因此此時應該反饋回去,繼續訓練

反饋流的格式是:

  • Tuple7<sampleId, subSampleTaskId, subNum, SparseVector(subSample), label, wx, timeStamps>;

2.3.3 迭代體 CalcTask / ReduceTask

迭代體由兩部分構成:CalcTask / ReduceTask。

CalcTask每個實例都擁有初始化模型dataBridge

DataStream iterativeBody = iteration.flatMap(
    new CalcTask(dataBridge, splitInfo, getParams()))
2.3.3.1 迭代初始化

迭代是由 CalcTask.open 函數開始,主要作以下幾件事

  • 設定各類參數,好比
    • 工做task個數,numWorkers = getRuntimeContext().getNumberOfParallelSubtasks();
    • 本task的id,workerId = getRuntimeContext().getIndexOfThisSubtask();
  • 讀取初始化模型
    • List modelRows = DirectReader.directRead(dataBridge);
    • 把Row類型數據轉換爲線性模型 LinearModelData model = new LinearModelDataConverter().load(modelRows);
  • 讀取本task對應的係數 coef[i - startIdx],這裏就是把整個模型切分到numWorkers這麼多的Task中,並行更新
  • 指定本task的開始時間 startTime = System.currentTimeMillis();
2.3.3.2 處理輸入數據

CalcTask.flatMap1主要實現的是FTRL算法中的predict部分(注意,不是FTRL預測)。

解釋:pt=σ(Xt⋅w)是LR的預測函數,求出pt的惟一目的是爲了求出目標函數(在LR中採用交叉熵損失函數做爲目標函數)對參數w的一階導數g,gi=(pt−yt)xi。此步驟一樣適用於FTRL優化其餘目標函數,惟一的不一樣就是求次梯度g(次梯度是左導和右導之間的集合,函數可導--左導等於右導時,次梯度就等於一階梯度)的方法不一樣。

函數的輸入是 "訓練輸入數據",即SplitVector.flatMap的輸出 ----> CalcCalcTask的輸入。輸入數據是一個五元組,其格式爲 train data format = <sampleId, subSampleTaskId, subNum, SparseVector(subSample), label>;

有三點須要注意:

  • 是若是是第一次進入,則須要savedFristModel;
  • 這裏是有輸入就處理,而後當即輸出(和flatMap2不一樣,flatMap2有輸入就處理,但不是當即輸出,而是當時間到期了再輸出);
  • predict的實現:((SparseVector)vec).getValues()[i] * coef[indices[i] - startIdx];

你們會說,不對!predict函數應該是 sigmoid = 1.0 / (1.0 + np.exp(-w.dot(x)))。是的,這裏尚未作 sigmoid 操做。當ReduceTask作了聚合以後,會把聚合好的 p 反饋回迭代體,而後在 CalcTask.flatMap2 中才會作 sigmoid 操做

public void flatMap1(Tuple5<Long, Integer, Integer, Vector, Object> value,
                     Collector<Tuple7<Long, Integer, Integer, Vector, Object, Double, Long>> out) throws Exception {
    if (!savedFristModel) { //第一次進入須要存模型
        out.collect(Tuple7.of(-1L, 0, getRuntimeContext().getIndexOfThisSubtask(),
            new DenseVector(coef), labelValues, -1.0, modelId++));
        savedFristModel = true;
    }
    Long timeStamps = System.currentTimeMillis();
    double wx = 0.0;
    Long sampleId = value.f0;
    Vector vec = value.f3;
    if (vec instanceof SparseVector) {
        int[] indices = ((SparseVector)vec).getIndices();
        // 這裏就是具體的Predict
        for (int i = 0; i < indices.length; ++i) {
            wx += ((SparseVector)vec).getValues()[i] * coef[indices[i] - startIdx];
        }
    } else {
       ......
    }
    //處理了就輸出
    out.collect(Tuple7.of(sampleId, value.f1, value.f2, value.f3, value.f4, wx, timeStamps));
}
2.3.3.3 歸併數據

ReduceTask.flatMap 負責歸併數據。

public static class ReduceTask extends
    RichFlatMapFunction<Tuple7<Long, Integer, Integer, Vector, Object, Double, Long>,
        Tuple7<Long, Integer, Integer, Vector, Object, Double, Long>> {
    private int parallelism;
    private int[] poses;
    private Map<Long, List<Object>> buffer;
    private Map<Long, List<Tuple2<Integer, DenseVector>>> models = new HashMap<>();
}

flatMap函數大體完成以下功能,即兩種歸併:

  • 爲了輸出模型使用。判斷是否時間過時 if (value.f0 < 0),若是過時,則歸併模型
    • 生成一個List<Tuple2<Integer, DenseVector>> model = models.get(value.f6); 以value.f6,即時間戳爲key,插入到HashMap中。
    • 若是所有收集完成,則向下遊算子輸出模型,而且從HashMap中刪除暫存的模型。
  • 爲了歸併predict使用。歸併每一個CalcTask計算的predict,造成一個 lable y;
    • 用 label y 更新 Tuple7的f5,即Tuple7<sampleId, subSampleTaskId, subNum, SparseVector(subSample), label, wx, timeStamps> 中的 label,也就是預測的 y。
    • 給每一個下游算子(就是每一個CalcTask了,不過是做爲flatMap2的輸入)發送這個新Tuple7;

當具體用做輸出模型使用時,其變量以下:

models = {HashMap@13258}  size = 1
 {Long@13456} 1 -> {ArrayList@13678}  size = 1
  key = {Long@13456} 1
  value = {ArrayList@13678}  size = 1
   0 = {Tuple2@13698} "(1,0.0 -8.244533295515879E-5 0.0 -1.103997743166529E-4 0.0 -3.336931546279811E-5....."
2.3.3.4 判斷是否反饋

這個 filter result 是用來判斷是否反饋的。這裏t3.f0 是sampleId, t3.f2是subNum。

DataStream<Tuple7<Long, Integer, Integer, Vector, Object, Double, Long>>
    result = iterativeBody.filter(
    new FilterFunction<Tuple7<Long, Integer, Integer, Vector, Object, Double, Long>>() {
        @Override
        public boolean filter(Tuple7<Long, Integer, Integer, Vector, Object, Double, Long> t3)
            throws Exception {
            // if t3.f0 > 0 && t3.f2 > 0 then feedback
            return (t3.f0 > 0 && t3.f2 > 0);
        }
    });

對於 t3.f0,有兩處代碼會設置爲負值。

  • 會在savedFirstModel 這裏設置一次"-1";即

    if (!savedFristModel) {
    		out.collect(Tuple7.of(-1L, 0, getRuntimeContext().getIndexOfThisSubtask(),
                        new DenseVector(coef), labelValues, -1.0, modelId++));
        savedFristModel = true;
    }
  • 也會在時間過時時候設置爲 "-1"。

    if (System.currentTimeMillis() - startTime > modelSaveTimeInterval) {
        startTime = System.currentTimeMillis();
        out.collect(Tuple7.of(-1L, 0, getRuntimeContext().getIndexOfThisSubtask(),
            new DenseVector(coef), labelValues, -1.0, modelId++));
    }

對於 t3.f2,若是 subNum 大於零,說明在高維向量切分時候,是獲得了有意義的數值。

所以 return (t3.f0 > 0 && t3.f2 > 0) 說明時間未過時&向量有意義,因此此時應該反饋回去,繼續訓練。

2.3.3.5 判斷是否輸出模型

這裏是filter output。

value.f0 < 0 說明時間到期了,應該輸出模型。

DataStream<Row> output = iterativeBody.filter(
    new FilterFunction<Tuple7<Long, Integer, Integer, Vector, Object, Double, Long>>() {
        @Override
        public boolean filter(Tuple7<Long, Integer, Integer, Vector, Object, Double, Long> value) 
        {
            /* if value.f0 small than 0, then output */
            return value.f0 < 0;
        }
    }).flatMap(new WriteModel(labelType, getVectorCol(), featureCols, hasInterceptItem));
2.3.3.6 處理反饋數據/更新參數

CalcTask.flatMap2實際完成的是FTRL算法的其他部分,即更新參數部分。主要邏輯以下:

  • 計算時間間隔 timeInterval = System.currentTimeMillis() - value.f6;
  • 正式計算predict, p = 1 / (1 + Math.exp(-p)); 即sigmoid 操做;
  • 計算梯度 g = (p - label) * values[i] / Math.sqrt(timeInterval); 這裏除以了時間間隔;
  • 更新參數;
  • 輸入。注意,這裏是有輸入就處理,但 不是當即輸出,而是累積參數,當時間到期了再輸出,也就是作到了按期輸出模型;

Logistic Regression 中,sigmoid函數是σ(a) = 1 / (1 + exp(-a)) ,預估 pt = σ(xt . wt), 則 LogLoss 函數是

\[l_t(w_t) = -y_t log(p_t) - (1-y_t)log(1-p_t) \]

直接計算能夠獲得

\[∇l(w) = (σ(w.x_t) - y_t)x_t = (p_t - y_t)x_t \]

具體 LR + FTRL 算法實現以下:

@Override
public void flatMap2(Tuple7<Long, Integer, Integer, Vector, Object, Double, Long> value,
                     Collector<Tuple7<Long, Integer, Integer, Vector, Object, Double, Long>> out)
    throws Exception {
    double p = value.f5;
    // 計算時間間隔 
    long timeInterval = System.currentTimeMillis() - value.f6;
    Vector vec = value.f3;

    /* eta */
    // 正式計算predict,以前只是計算了一半,這裏計算後半部,即
    p = 1 / (1 + Math.exp(-p));
    .....

    if (vec instanceof SparseVector) {
        // 這裏是更新參數
        int[] indices = ((SparseVector)vec).getIndices();
        double[] values = ((SparseVector)vec).getValues();

        for (int i = 0; i < indices.length; ++i) {
            // update zParam nParam
            int id = indices[i] - startIdx;
            // values[i]是xi
            // 下面的計算基本和Google僞代碼一致
            double g = (p - label) * values[i] / Math.sqrt(timeInterval);
            double sigma = (Math.sqrt(nParam[id] + g * g) - Math.sqrt(nParam[id])) / alpha;
            zParam[id] += g - sigma * coef[id];
            nParam[id] += g * g;

            // update model coefficient
            if (Math.abs(zParam[id]) <= l1) {
                coef[id] = 0.0;
            } else {
                coef[id] = ((zParam[id] < 0 ? -1 : 1) * l1 - zParam[id])
                    / ((beta + Math.sqrt(nParam[id]) / alpha + l2));
            }
        }
    } else {
      ......
    }

    // 當時間到期了再輸出,即作到了按期輸出模型
    if (System.currentTimeMillis() - startTime > modelSaveTimeInterval) {
        startTime = System.currentTimeMillis();
        out.collect(Tuple7.of(-1L, 0, getRuntimeContext().getIndexOfThisSubtask(),
            new DenseVector(coef), labelValues, -1.0, modelId++));
    }
}

2.4 輸出模型

WriteModel 類實現了輸出模型功能,大體邏輯以下:

  • 生成一個LinearModelData,用訓練好的Tuple7來填充這個 LinearModelData。其中兩個重要點:
    • modelData.coefVector = (DenseVector)value.f3;
    • modelData.labelValues = (Object[])value.f4;
  • 把模型數據轉換成List rows。LinearModelDataConverter().save(modelData, listCollector);
  • 序列化,發送給下游算子。由於模型可能會很大,因此這裏打散以後分佈發送給下游算子
public void flatMap(Tuple7<Long, Integer, Integer, Vector, Object, Double, Long> value, Collector<Row> out){
  
//輸入value變量打印以下:
value = {Tuple7@13296} 
 f0 = {Long@13306} -1
 f1 = {Integer@13307} 0
 f2 = {Integer@13308} 2
 f3 = {DenseVector@13309} "-0.7383426732137565 0.0 0.0 0.0 1.5885293675862715E-4 -4.834608575902742E-5 0.0 0.0 -6.754208708318647E-5 ......"
  data = {double[30001]@13314} 
 f4 = {Object[2]@13310} 
 f5 = {Double@13311} -1.0
 f6 = {Long@13312} 0  
  
        //生成模型
        LinearModelData modelData = new LinearModelData();
        ......
        modelData.coefVector = (DenseVector)value.f3;
        modelData.labelValues = (Object[])value.f4;

        //把模型數據轉換成List<Row> rows
        RowCollector listCollector = new RowCollector();
        new LinearModelDataConverter().save(modelData, listCollector);
        List<Row> rows = listCollector.getRows();

        for (Row r : rows) {
            int rowSize = r.getArity();
            for (int j = 0; j < rowSize; ++j) {
 							.....
              //序列化
            }
            out.collect(row);
        }

        iter++;
    }
}

0x03 在線預測

預測功能是在 FtrlPredictStreamOp 完成的。

// ftrl predict
FtrlPredictStreamOp predictResult = new FtrlPredictStreamOp(initModel)
        .setVectorCol(vecColName)
        .setPredictionCol("pred")
        .setReservedCols(new String[]{labelColName})
        .setPredictionDetailCol("details")
        .linkFrom(model, featurePipelineModel.transform(splitter.getSideOutput(0)));

從上面代碼咱們能夠看到

  • FtrlPredict 功能一樣須要初始模型 initModel,咱們也是把邏輯迴歸模型賦予它。這樣也是爲了冷啓動,即當FTRL訓練模塊尚未產生模型以前,FTRL預測模塊也是能夠對其輸入數據作預測的。
  • model 是 FtrlTrainStreamOp 的輸出,即 FTRL 的訓練輸出。因此 WriteModel 就直接把輸出傳給了 FtrlPredict功能。
  • splitter.getSideOutput(0) 這裏是前面提到的測試輸入,就是測試數據集。

linkFrom函數完成了業務邏輯,大體功能以下:

  • 使用 inputs[0].getDataStream().flatMap ------> partition ----> map ----> flatMap(new CollectModel()) 獲得了模型 LinearModelData modelstr;
  • 使用 DataStream.connect 把輸入的測試數據集 和 模型 LinearModelData modelstr關聯起來,這樣每一個task都擁有了在線模型 modelstr,就能夠經過 flatMap(new PredictProcess(...) 進行分佈式預測;
  • 使用 setOutputTable 和 LinearModelMapper 把預測結果輸出;

FTRL的預測功能有三個輸入

  • 初始模型 initModel ----->  最後被 PredictProcess.open 加載,做爲冷啓動的預測模型;
  • 測試數據流 -----> 被 PredictProcess.flatMap1處理,進行預測;
  • FTRL訓練階段產生的模型數據流 ----> 被 PredictProcess.flatMap2 處理,進行在線模型更新;

3.1 初始化

構造函數中完成了初始化,即獲取事先訓練好的邏輯迴歸模型。

public FtrlPredictStreamOp(BatchOperator model) {
    super(new Params());
    if (model != null) {
        dataBridge = DirectReader.collect(model);
    } else {
        throw new IllegalArgumentException("Ftrl algo: initial model is null. Please set a valid initial model.");
    }
}

3.2 獲取在線訓練模型

CollectModel完成了 獲取在線訓練模型 功能。

其邏輯主要是:模型被分紅若干塊,其中 (long)inRow.getField(1) 這裏記錄了具體有多少塊。因此 flatMap 函數會把這些塊累積起來,最後組裝成模型,統一發送給下游算子

具體是經過一個 HashMap<> buffers 來完成臨時拼裝/最後組裝的。

public static class CollectModel implements FlatMapFunction<Row, LinearModelData> {

    private Map<Long, List<Row>> buffers = new HashMap<>(0);

    @Override
    public void flatMap(Row inRow, Collector<LinearModelData> out) throws Exception {
      
// 輸入參數以下      
inRow = {Row@13389} "0,19,0,{"hasInterceptItem":"true","vectorCol":"\"vec\"","modelName":"\"Logistic Regression\"","labelCol":null,"linearModelType":"\"LR\"","vectorSize":"30000"},null"
 fields = {Object[5]@13405} 
  0 = {Long@13406} 0
  1 = {Long@13403} 19
  2 = {Long@13406} 0
  3 = "{"hasInterceptItem":"true","vectorCol":"\"vec\"","modelName":"\"Logistic Regression\"","labelCol":null,"linearModelType":"\"LR\"","vectorSize":"30000"}"      
" 
        long id = (long)inRow.getField(0);
        Long nTab = (long)inRow.getField(1);

        Row row = new Row(inRow.getArity() - 2);

        for (int i = 0; i < row.getArity(); ++i) {
            row.setField(i, inRow.getField(i + 2));
        }

        if (buffers.containsKey(id) && buffers.get(id).size() == nTab.intValue() - 1) {
            buffers.get(id).add(row);
            // 若是累積完成,則組裝成模型
            LinearModelData ret = new LinearModelDataConverter().load(buffers.get(id));
            buffers.get(id).clear();
            // 發送給下游算子。
            out.collect(ret);
        } else {            
            if (buffers.containsKey(id)) {
                //若是有key。則往list添加。
                buffers.get(id).add(row);
            } else {
                // 若是沒有key,則添加list
                List<Row> buffer = new ArrayList<>(0);
                buffer.add(row);
                buffers.put(id, buffer);
            }
        }
    }
}

//變量相似這種
this = {FtrlPredictStreamOp$CollectModel@13388} 
 buffers = {HashMap@13393}  size = 1
  {Long@13406} 0 -> {ArrayList@13431}  size = 2
   key = {Long@13406} 0
    value = 0
   value = {ArrayList@13431}  size = 2
    0 = {Row@13409} "0,{"hasInterceptItem":"true","vectorCol":"\"vec\"","modelName":"\"Logistic Regression\"","labelCol":null,"linearModelType":"\"LR\"","vectorSize":"30000"},null"
    1 = {Row@13471} "1048576,{"featureColNames":null,"featureColTypes":null,"coefVector":{"data":[-0.7383426732137549,0.0,0.0,0.0,1.5885293675862704E-4,-4.834608575902738E-5,0.0,0.0,-6.754208708318643E-5,-1.5904172331763155E-4,0.0,-1.315219790338925E-4,0.0,-4.994749246390495E-4,0.0,2.755456604395511E-4,-9.616429481614131E-4,-9.601054004112163E-5,0.0,-1.6679174640370486E-4,0.0,......"

3.3 在線預測

PredictProcess 完成了在線預測功能,LinearModelMapper 是具體預測實現。

public static class PredictProcess extends RichCoFlatMapFunction<Row, LinearModelData, Row> {
    private LinearModelMapper predictor = null;
    private String modelSchemaJson;
    private String dataSchemaJson;
    private Params params;
    private int iter = 0;
    private DataBridge dataBridge;
}

3.3.1 加載預設置模型

其構造函數得到了 FtrlPredictStreamOp 類的 dataBridge,即事先訓練好的邏輯迴歸模型。每個Task都擁有完整的模型。

open函數會加載邏輯迴歸模型。

public void open(Configuration parameters) throws Exception {
    this.predictor = new LinearModelMapper(TableUtil.fromSchemaJson(modelSchemaJson),
        TableUtil.fromSchemaJson(dataSchemaJson), this.params);
    if (dataBridge != null) {
        // read init model
        List<Row> modelRows = DirectReader.directRead(dataBridge);
        LinearModelData model = new LinearModelDataConverter().load(modelRows);
        this.predictor.loadModel(model);
    }
}

3.3.2 在線預測

FtrlPredictStreamOp.flatMap1 函數完成了在線預測。

public void flatMap1(Row row, Collector<Row> collector) throws Exception {
    collector.collect(this.predictor.map(row));
}

調用棧以下:

predictWithProb:157, LinearModelMapper (com.alibaba.alink.operator.common.linear)
predictResultDetail:114, LinearModelMapper (com.alibaba.alink.operator.common.linear)
map:90, RichModelMapper (com.alibaba.alink.common.mapper)
flatMap1:174, FtrlPredictStreamOp$PredictProcess (com.alibaba.alink.operator.stream.onlinelearning)
flatMap1:143, FtrlPredictStreamOp$PredictProcess (com.alibaba.alink.operator.stream.onlinelearning)
processElement1:53, CoStreamFlatMap (org.apache.flink.streaming.api.operators.co)
processRecord1:135, StreamTwoInputProcessor (org.apache.flink.streaming.runtime.io)

具體是經過 LinearModelMapper 完成。

public abstract class RichModelMapper extends ModelMapper {
    public Row map(Row row) throws Exception {
        if (isPredDetail) { 
            // 咱們的示例代碼在這裏
            Tuple2<Object, String> t2 = predictResultDetail(row);
            return this.outputColsHelper.getResultRow(row, Row.of(t2.f0, t2.f1));
        } else {
            return this.outputColsHelper.getResultRow(row, Row.of(predictResult(row)));
        }
    }  
}

預測代碼以下,能夠看出來使用了sigmoid。

/**
 * Predict the label information with the probability of each label.
 */
public Tuple2 <Object, Double[]> predictWithProb(Vector vector) {
   double dotValue = MatVecOp.dot(vector, model.coefVector);
   switch (model.linearModelType) {
      case LR:
      case SVM:
         double prob = sigmoid(dotValue);
         return new Tuple2 <>(dotValue >= 0 ? model.labelValues[0] : model.labelValues[1],
            new Double[] {prob, 1 - prob});
   }
}

3.3.3 在線更新模型

FtrlPredictStreamOp.flatMap2 函數完成了處理在線訓練輸出的模型數據流,在線更新模型。

LinearModelData參數是由CollectModel完成加載而且傳輸出來的。

在模型加載過程當中,是不能預測的,沒有看到相關保護機制。若是我疏漏請你們指出。

public void flatMap2(LinearModelData linearModel, Collector<Row> collector) throws Exception {
    this.predictor.loadModel(linearModel);
}

0x04 問題解答

針對以前咱們提出的問題,如今總結概括以下:

  • 訓練階段和預測階段都有預製模型以應對"冷啓動"嘛?都有預製模型
  • 訓練階段和預測階段是如何關聯起來的?用 linkFrom 直接把訓練階段和預測階段的算子連在一塊兒
  • 如何把訓練出來的模型傳給預測階段?訓練階段用 Flink collector.collect 把模型發給下游算子
  • 輸出模型時候,模型過大怎麼處理?在線訓練會 模型打散 以後分佈發送給下游算子
  • 在線訓練的模型經過什麼機制實現更新?是定時驅動更新嘛?定時更新
  • 預測階段加載模型過程當中,還能夠預測嘛?有沒有機制保證這段時間內也能預測?目前沒有發現相似保護機制
  • 訓練階段中,有哪些階段用到了並行處理?訓練過程當中主要是FTRL算法的"預測predict" 和 "更新參數"兩個部分,以及發送模型
  • 預測階段中,有哪些階段用到了並行處理?預測過程當中主要是分佈式接受模型和分佈式預測
  • 遇到高維向量如何處理?切分開嘛?切分處理

0xFF 參考

【機器學習】邏輯迴歸(很是詳細)

邏輯迴歸(logistics regression)

【機器學習】LR的分佈式(並行化)實現

並行邏輯迴歸

機器學習算法及其並行化討論

Online LR—— FTRL 算法理解

在線優化算法 FTRL 的原理與實現

LR+FTRL算法原理以及工程化實現

Flink流處理之迭代API分析

FTRL公式推導

FTRL論文筆記

在線機器學習FTRL(Follow-the-regularized-Leader)算法介紹

FTRL代碼實現

FTRL實戰之LR+FTRL(代碼採用的稠密數據)

在線學習算法FTRL-Proximal原理

基於FTRL的在線CTR預測算法

CTR預測算法之FTRL-Proximal

各大公司普遍使用的在線學習算法FTRL詳解

在線最優化求解(Online Optimization)之五:FTRL

FOLLOW THE REGULARIZED LEADER (FTRL) 算法總結

相關文章
相關標籤/搜索