Alink 是阿里巴巴基於實時計算引擎 Flink 研發的新一代機器學習算法平臺,是業界首個同時支持批式算法、流式算法的機器學習平臺。本文和上文一塊兒介紹了在線學習算法 FTRL 在Alink中是如何實現的,但願對你們有所幫助。html
書接上回 Alink漫談(十二) :在線學習算法FTRL 之 總體設計 。到目前爲止,已經處理完畢輸入,接下來就是在線訓練。訓練優化的主要目標是找到一個方向,參數朝這個方向移動以後使得損失函數的值可以減少,這個方向每每由一階偏導或者二階偏導各類組合求得。java
爲了讓你們更好理解,咱們再次貼出總體流程圖:算法
在線訓練主要邏輯是:apache
前面說到,FTRL先要訓練出一個邏輯迴歸模型做爲FTRL算法的初始模型,這是爲了系統冷啓動的須要。api
具體邏輯迴歸模型設定/訓練是 :app
// train initial batch model LogisticRegressionTrainBatchOp lr = new LogisticRegressionTrainBatchOp() .setVectorCol(vecColName) .setLabelCol(labelColName) .setWithIntercept(true) .setMaxIter(10); BatchOperator<?> initModel = featurePipelineModel.transform(trainBatchData).link(lr);
訓練好以後,模型信息是DataSet
FtrlTrainStreamOp將initModel做爲初始化參數。分佈式
FtrlTrainStreamOp model = new FtrlTrainStreamOp(initModel)
在FtrlTrainStreamOp構造函數中會加載這個模型;ide
dataBridge = DirectReader.collect(initModel);
具體加載時經過MemoryDataBridge直接獲取初始化模型DataSet中的數據。函數
public MemoryDataBridge generate(BatchOperator batchOperator, Params globalParams) { return new MemoryDataBridge(batchOperator.collect()); }
從前文可知,Alink的FTRL算法設置的特徵向量維度是30000。因此算法第一步就是切分高維度向量,以便分佈式計算。
String vecColName = "vec"; int numHashFeatures = 30000;
首先要獲取切分信息,代碼以下,就是將特徵數目featureSize 除以 並行度parallelism,而後獲得了每一個task對應係數的初始位置。
private static int[] getSplitInfo(int featureSize, boolean hasInterceptItem, int parallelism) { int coefSize = (hasInterceptItem) ? featureSize + 1 : featureSize; int subSize = coefSize / parallelism; int[] poses = new int[parallelism + 1]; int offset = coefSize % parallelism; for (int i = 0; i < offset; ++i) { poses[i + 1] = poses[i] + subSize + 1; } for (int i = offset; i < parallelism; ++i) { poses[i + 1] = poses[i] + subSize; } return poses; } //程序運行時變量以下 featureSize = 30000 hasInterceptItem = true parallelism = 4 coefSize = 30001 subSize = 7500 poses = {int[5]@11660} 0 = 0 1 = 7501 2 = 15001 3 = 22501 4 = 30001 offset = 1
而後根據切分信息對高維向量進行切割。
// Tuple5<SampleId, taskId, numSubVec, SubVec, label> DataStream<Tuple5<Long, Integer, Integer, Vector, Object>> input = initData.flatMap(new SplitVector(splitInfo, hasInterceptItem, vectorSize, vectorTrainIdx, featureIdx, labelIdx)) .partitionCustom(new CustomBlockPartitioner(), 1);
具體切分在SplitVector.flatMap函數完成,結果就是把一個高維度向量分割給各個CalcTask。
代碼摘要以下:
public void flatMap(Row row, Collector<Tuple5<Long, Integer, Integer, Vector, Object>> collector) throws Exception { long sampleId = counter; counter += parallelism; Vector vec; if (vectorTrainIdx == -1) { ..... } else { // 輸入row的第vectorTrainIdx個field就是那個30000大小的係數向量 vec = VectorUtil.getVector(row.getField(vectorTrainIdx)); } if (vec instanceof SparseVector) { Map<Integer, Vector> tmpVec = new HashMap<>(); for (int i = 0; i < indices.length; ++i) { ..... // 此處迭代完成後,tmpVec中就是task number個元素,每個元素是分割好的係數向量。 } for (Integer key : tmpVec.keySet()) { //此處遍歷,給後面全部CalcTask發送五元組數據。 collector.collect(Tuple5.of(sampleId, key, subNum, tmpVec.get(key), row.getField(labelIdx))); } } else { ...... } } }
這個Tuple5.of(sampleId, key, subNum, tmpVec.get(key), row.getField(labelIdx) )就是後面CalcTask的輸入。
此處理論上有如下幾個重點:
預測方法:在每一輪t中,針對特徵樣本xt,以及迭代後(第一次則是給定初值)的模型參數wt,咱們能夠預測該樣本的標記值:pt=σ(wt,xt),其中σ(a)=1/(1+exp(−a))是一個sigmoid函數。
損失函數:對一個特徵樣本xt,其對應的標記爲yt ∈ 0,1,則經過 logistic loss 來做爲損失函數。
迭代公式:咱們的目的是使得損失函數儘量的小,便可以採用極大似然估計來求解參數。首先求梯度,而後使用FTRL進行迭代。
僞代碼思路大體以下
double p = learner.predict(x); //預測 learner.updateModel(x, p, y); //更新模型 double loss = LogLossEvalutor.calLogLoss(p, y); //計算損失 evalutor.addLogLoss(loss); //更新損失 totalLoss += loss; trainedNum += 1;
具體實施上Alink有本身的特色和調整。
機器學習都須要迭代訓練,Alink這裏利用了Flink Stream的迭代功能。
IterativeStream的實例是經過DataStream的iterate方法建立的˙。iterate方法存在兩個重載形式:
Alink選擇了第二種。
在建立ConnectedIterativeStreams時候,用迭代流的初始輸入做爲第一個輸入流,用反饋流做爲第二個輸入。
每一種數據流(DataStream)都會有與之對應的流轉換(StreamTransformation)。IterativeStream對應的轉換是FeedbackTransformation。
迭代流(IterativeStream)對應的轉換是反饋轉換(FeedbackTransformation),它表示拓撲中的一個反饋點(也即迭代頭)。一個反饋點包含一個輸入邊以及若干個反饋邊,且Flink要求每一個反饋邊的並行度必須跟輸入邊的並行度一致,這一點在往該轉換中加入反饋邊時會進行校驗。
當IterativeStream對象被構造時,FeedbackTransformation的實例會被建立並傳遞給DataStream的構造方法。
迭代的關閉是經過調用IterativeStream的實例方法closeWith來實現的。這個函數指定了某個流將成爲迭代程序的結束,而且這個流將做爲輸入的第二部分(second input)被反饋回迭代。
對於Alink來講,迭代構建代碼是:
// train data format = <sampleId, subSampleTaskId, subNum, SparseVector(subSample), label> // feedback format = Tuple7<sampleId, subSampleTaskId, subNum, SparseVector(subSample), label, wx, timeStamps> IterativeStream.ConnectedIterativeStreams< Tuple5<Long, Integer, Integer, Vector, Object>, Tuple7<Long, Integer, Integer, Vector, Object, Double, Long>> iteration = input.iterate(Long.MAX_VALUE) .withFeedbackType(TypeInformation .of(new TypeHint<Tuple7<Long, Integer, Integer, Vector, Object, Double, Long>>() {})); // 即iteration是一個 IterativeStream.ConnectedIterativeStreams<...>
從代碼和註釋能夠看出,迭代的兩種輸入是:
反饋流的設置是經過調用IterativeStream的實例方法closeWith來實現的。Alink這裏是
DataStream<Tuple7<Long, Integer, Integer, Vector, Object, Double, Long>> result = iterativeBody.filter( return (t3.f0 > 0 && t3.f2 > 0); // 這裏是省略版本代碼 ); iteration.closeWith(result);
前面已經提到過,result filter 的判斷是 return (t3.f0 > 0 && t3.f2 > 0)
,若是知足條件,則說明時間未過時&向量有意義,因此此時應該反饋回去,繼續訓練。
反饋流的格式是:
迭代體由兩部分構成:CalcTask / ReduceTask。
CalcTask每個實例都擁有初始化模型dataBridge。
DataStream iterativeBody = iteration.flatMap( new CalcTask(dataBridge, splitInfo, getParams()))
迭代是由 CalcTask.open 函數開始,主要作以下幾件事
CalcTask.flatMap1主要實現的是FTRL算法中的predict部分(注意,不是FTRL預測)。
解釋:pt=σ(Xt⋅w)是LR的預測函數,求出pt的惟一目的是爲了求出目標函數(在LR中採用交叉熵損失函數做爲目標函數)對參數w的一階導數g,gi=(pt−yt)xi。此步驟一樣適用於FTRL優化其餘目標函數,惟一的不一樣就是求次梯度g(次梯度是左導和右導之間的集合,函數可導--左導等於右導時,次梯度就等於一階梯度)的方法不一樣。
函數的輸入是 "訓練輸入數據",即SplitVector.flatMap的輸出 ----> CalcCalcTask的輸入
。輸入數據是一個五元組,其格式爲 train data format = <sampleId, subSampleTaskId, subNum, SparseVector(subSample), label>;
有三點須要注意:
((SparseVector)vec).getValues()[i] * coef[indices[i] - startIdx];
你們會說,不對!predict函數應該是 sigmoid = 1.0 / (1.0 + np.exp(-w.dot(x)))
。是的,這裏尚未作 sigmoid 操做。當ReduceTask作了聚合以後,會把聚合好的 p 反饋回迭代體,而後在 CalcTask.flatMap2 中才會作 sigmoid 操做。
public void flatMap1(Tuple5<Long, Integer, Integer, Vector, Object> value, Collector<Tuple7<Long, Integer, Integer, Vector, Object, Double, Long>> out) throws Exception { if (!savedFristModel) { //第一次進入須要存模型 out.collect(Tuple7.of(-1L, 0, getRuntimeContext().getIndexOfThisSubtask(), new DenseVector(coef), labelValues, -1.0, modelId++)); savedFristModel = true; } Long timeStamps = System.currentTimeMillis(); double wx = 0.0; Long sampleId = value.f0; Vector vec = value.f3; if (vec instanceof SparseVector) { int[] indices = ((SparseVector)vec).getIndices(); // 這裏就是具體的Predict for (int i = 0; i < indices.length; ++i) { wx += ((SparseVector)vec).getValues()[i] * coef[indices[i] - startIdx]; } } else { ...... } //處理了就輸出 out.collect(Tuple7.of(sampleId, value.f1, value.f2, value.f3, value.f4, wx, timeStamps)); }
ReduceTask.flatMap 負責歸併數據。
public static class ReduceTask extends RichFlatMapFunction<Tuple7<Long, Integer, Integer, Vector, Object, Double, Long>, Tuple7<Long, Integer, Integer, Vector, Object, Double, Long>> { private int parallelism; private int[] poses; private Map<Long, List<Object>> buffer; private Map<Long, List<Tuple2<Integer, DenseVector>>> models = new HashMap<>(); }
flatMap函數大體完成以下功能,即兩種歸併:
當具體用做輸出模型使用時,其變量以下:
models = {HashMap@13258} size = 1 {Long@13456} 1 -> {ArrayList@13678} size = 1 key = {Long@13456} 1 value = {ArrayList@13678} size = 1 0 = {Tuple2@13698} "(1,0.0 -8.244533295515879E-5 0.0 -1.103997743166529E-4 0.0 -3.336931546279811E-5....."
這個 filter result 是用來判斷是否反饋的。這裏t3.f0 是sampleId, t3.f2是subNum。
DataStream<Tuple7<Long, Integer, Integer, Vector, Object, Double, Long>> result = iterativeBody.filter( new FilterFunction<Tuple7<Long, Integer, Integer, Vector, Object, Double, Long>>() { @Override public boolean filter(Tuple7<Long, Integer, Integer, Vector, Object, Double, Long> t3) throws Exception { // if t3.f0 > 0 && t3.f2 > 0 then feedback return (t3.f0 > 0 && t3.f2 > 0); } });
對於 t3.f0,有兩處代碼會設置爲負值。
會在savedFirstModel 這裏設置一次"-1";即
if (!savedFristModel) { out.collect(Tuple7.of(-1L, 0, getRuntimeContext().getIndexOfThisSubtask(), new DenseVector(coef), labelValues, -1.0, modelId++)); savedFristModel = true; }
也會在時間過時時候設置爲 "-1"。
if (System.currentTimeMillis() - startTime > modelSaveTimeInterval) { startTime = System.currentTimeMillis(); out.collect(Tuple7.of(-1L, 0, getRuntimeContext().getIndexOfThisSubtask(), new DenseVector(coef), labelValues, -1.0, modelId++)); }
對於 t3.f2,若是 subNum 大於零,說明在高維向量切分時候,是獲得了有意義的數值。
所以 return (t3.f0 > 0 && t3.f2 > 0)
說明時間未過時&向量有意義,因此此時應該反饋回去,繼續訓練。
這裏是filter output。
value.f0 < 0
說明時間到期了,應該輸出模型。
DataStream<Row> output = iterativeBody.filter( new FilterFunction<Tuple7<Long, Integer, Integer, Vector, Object, Double, Long>>() { @Override public boolean filter(Tuple7<Long, Integer, Integer, Vector, Object, Double, Long> value) { /* if value.f0 small than 0, then output */ return value.f0 < 0; } }).flatMap(new WriteModel(labelType, getVectorCol(), featureCols, hasInterceptItem));
CalcTask.flatMap2實際完成的是FTRL算法的其他部分,即更新參數部分。主要邏輯以下:
在 Logistic Regression 中,sigmoid函數是σ(a) = 1 / (1 + exp(-a)) ,預估 pt = σ(xt . wt), 則 LogLoss 函數是
直接計算能夠獲得
具體 LR + FTRL 算法實現以下:
@Override public void flatMap2(Tuple7<Long, Integer, Integer, Vector, Object, Double, Long> value, Collector<Tuple7<Long, Integer, Integer, Vector, Object, Double, Long>> out) throws Exception { double p = value.f5; // 計算時間間隔 long timeInterval = System.currentTimeMillis() - value.f6; Vector vec = value.f3; /* eta */ // 正式計算predict,以前只是計算了一半,這裏計算後半部,即 p = 1 / (1 + Math.exp(-p)); ..... if (vec instanceof SparseVector) { // 這裏是更新參數 int[] indices = ((SparseVector)vec).getIndices(); double[] values = ((SparseVector)vec).getValues(); for (int i = 0; i < indices.length; ++i) { // update zParam nParam int id = indices[i] - startIdx; // values[i]是xi // 下面的計算基本和Google僞代碼一致 double g = (p - label) * values[i] / Math.sqrt(timeInterval); double sigma = (Math.sqrt(nParam[id] + g * g) - Math.sqrt(nParam[id])) / alpha; zParam[id] += g - sigma * coef[id]; nParam[id] += g * g; // update model coefficient if (Math.abs(zParam[id]) <= l1) { coef[id] = 0.0; } else { coef[id] = ((zParam[id] < 0 ? -1 : 1) * l1 - zParam[id]) / ((beta + Math.sqrt(nParam[id]) / alpha + l2)); } } } else { ...... } // 當時間到期了再輸出,即作到了按期輸出模型 if (System.currentTimeMillis() - startTime > modelSaveTimeInterval) { startTime = System.currentTimeMillis(); out.collect(Tuple7.of(-1L, 0, getRuntimeContext().getIndexOfThisSubtask(), new DenseVector(coef), labelValues, -1.0, modelId++)); } }
WriteModel 類實現了輸出模型功能,大體邏輯以下:
public void flatMap(Tuple7<Long, Integer, Integer, Vector, Object, Double, Long> value, Collector<Row> out){ //輸入value變量打印以下: value = {Tuple7@13296} f0 = {Long@13306} -1 f1 = {Integer@13307} 0 f2 = {Integer@13308} 2 f3 = {DenseVector@13309} "-0.7383426732137565 0.0 0.0 0.0 1.5885293675862715E-4 -4.834608575902742E-5 0.0 0.0 -6.754208708318647E-5 ......" data = {double[30001]@13314} f4 = {Object[2]@13310} f5 = {Double@13311} -1.0 f6 = {Long@13312} 0 //生成模型 LinearModelData modelData = new LinearModelData(); ...... modelData.coefVector = (DenseVector)value.f3; modelData.labelValues = (Object[])value.f4; //把模型數據轉換成List<Row> rows RowCollector listCollector = new RowCollector(); new LinearModelDataConverter().save(modelData, listCollector); List<Row> rows = listCollector.getRows(); for (Row r : rows) { int rowSize = r.getArity(); for (int j = 0; j < rowSize; ++j) { ..... //序列化 } out.collect(row); } iter++; } }
預測功能是在 FtrlPredictStreamOp 完成的。
// ftrl predict FtrlPredictStreamOp predictResult = new FtrlPredictStreamOp(initModel) .setVectorCol(vecColName) .setPredictionCol("pred") .setReservedCols(new String[]{labelColName}) .setPredictionDetailCol("details") .linkFrom(model, featurePipelineModel.transform(splitter.getSideOutput(0)));
從上面代碼咱們能夠看到
linkFrom函數完成了業務邏輯,大體功能以下:
inputs[0].getDataStream().flatMap ------> partition ----> map ----> flatMap(new CollectModel())
獲得了模型 LinearModelData modelstr;flatMap(new PredictProcess(...)
進行分佈式預測;即 FTRL的預測功能有三個輸入:
構造函數中完成了初始化,即獲取事先訓練好的邏輯迴歸模型。
public FtrlPredictStreamOp(BatchOperator model) { super(new Params()); if (model != null) { dataBridge = DirectReader.collect(model); } else { throw new IllegalArgumentException("Ftrl algo: initial model is null. Please set a valid initial model."); } }
CollectModel完成了 獲取在線訓練模型 功能。
其邏輯主要是:模型被分紅若干塊,其中 (long)inRow.getField(1) 這裏記錄了具體有多少塊。因此 flatMap 函數會把這些塊累積起來,最後組裝成模型,統一發送給下游算子。
具體是經過一個 HashMap<> buffers 來完成臨時拼裝/最後組裝的。
public static class CollectModel implements FlatMapFunction<Row, LinearModelData> { private Map<Long, List<Row>> buffers = new HashMap<>(0); @Override public void flatMap(Row inRow, Collector<LinearModelData> out) throws Exception { // 輸入參數以下 inRow = {Row@13389} "0,19,0,{"hasInterceptItem":"true","vectorCol":"\"vec\"","modelName":"\"Logistic Regression\"","labelCol":null,"linearModelType":"\"LR\"","vectorSize":"30000"},null" fields = {Object[5]@13405} 0 = {Long@13406} 0 1 = {Long@13403} 19 2 = {Long@13406} 0 3 = "{"hasInterceptItem":"true","vectorCol":"\"vec\"","modelName":"\"Logistic Regression\"","labelCol":null,"linearModelType":"\"LR\"","vectorSize":"30000"}" " long id = (long)inRow.getField(0); Long nTab = (long)inRow.getField(1); Row row = new Row(inRow.getArity() - 2); for (int i = 0; i < row.getArity(); ++i) { row.setField(i, inRow.getField(i + 2)); } if (buffers.containsKey(id) && buffers.get(id).size() == nTab.intValue() - 1) { buffers.get(id).add(row); // 若是累積完成,則組裝成模型 LinearModelData ret = new LinearModelDataConverter().load(buffers.get(id)); buffers.get(id).clear(); // 發送給下游算子。 out.collect(ret); } else { if (buffers.containsKey(id)) { //若是有key。則往list添加。 buffers.get(id).add(row); } else { // 若是沒有key,則添加list List<Row> buffer = new ArrayList<>(0); buffer.add(row); buffers.put(id, buffer); } } } } //變量相似這種 this = {FtrlPredictStreamOp$CollectModel@13388} buffers = {HashMap@13393} size = 1 {Long@13406} 0 -> {ArrayList@13431} size = 2 key = {Long@13406} 0 value = 0 value = {ArrayList@13431} size = 2 0 = {Row@13409} "0,{"hasInterceptItem":"true","vectorCol":"\"vec\"","modelName":"\"Logistic Regression\"","labelCol":null,"linearModelType":"\"LR\"","vectorSize":"30000"},null" 1 = {Row@13471} "1048576,{"featureColNames":null,"featureColTypes":null,"coefVector":{"data":[-0.7383426732137549,0.0,0.0,0.0,1.5885293675862704E-4,-4.834608575902738E-5,0.0,0.0,-6.754208708318643E-5,-1.5904172331763155E-4,0.0,-1.315219790338925E-4,0.0,-4.994749246390495E-4,0.0,2.755456604395511E-4,-9.616429481614131E-4,-9.601054004112163E-5,0.0,-1.6679174640370486E-4,0.0,......"
PredictProcess 完成了在線預測功能,LinearModelMapper 是具體預測實現。
public static class PredictProcess extends RichCoFlatMapFunction<Row, LinearModelData, Row> { private LinearModelMapper predictor = null; private String modelSchemaJson; private String dataSchemaJson; private Params params; private int iter = 0; private DataBridge dataBridge; }
其構造函數得到了 FtrlPredictStreamOp 類的 dataBridge,即事先訓練好的邏輯迴歸模型。每個Task都擁有完整的模型。
open函數會加載邏輯迴歸模型。
public void open(Configuration parameters) throws Exception { this.predictor = new LinearModelMapper(TableUtil.fromSchemaJson(modelSchemaJson), TableUtil.fromSchemaJson(dataSchemaJson), this.params); if (dataBridge != null) { // read init model List<Row> modelRows = DirectReader.directRead(dataBridge); LinearModelData model = new LinearModelDataConverter().load(modelRows); this.predictor.loadModel(model); } }
FtrlPredictStreamOp.flatMap1 函數完成了在線預測。
public void flatMap1(Row row, Collector<Row> collector) throws Exception { collector.collect(this.predictor.map(row)); }
調用棧以下:
predictWithProb:157, LinearModelMapper (com.alibaba.alink.operator.common.linear) predictResultDetail:114, LinearModelMapper (com.alibaba.alink.operator.common.linear) map:90, RichModelMapper (com.alibaba.alink.common.mapper) flatMap1:174, FtrlPredictStreamOp$PredictProcess (com.alibaba.alink.operator.stream.onlinelearning) flatMap1:143, FtrlPredictStreamOp$PredictProcess (com.alibaba.alink.operator.stream.onlinelearning) processElement1:53, CoStreamFlatMap (org.apache.flink.streaming.api.operators.co) processRecord1:135, StreamTwoInputProcessor (org.apache.flink.streaming.runtime.io)
具體是經過 LinearModelMapper 完成。
public abstract class RichModelMapper extends ModelMapper { public Row map(Row row) throws Exception { if (isPredDetail) { // 咱們的示例代碼在這裏 Tuple2<Object, String> t2 = predictResultDetail(row); return this.outputColsHelper.getResultRow(row, Row.of(t2.f0, t2.f1)); } else { return this.outputColsHelper.getResultRow(row, Row.of(predictResult(row))); } } }
預測代碼以下,能夠看出來使用了sigmoid。
/** * Predict the label information with the probability of each label. */ public Tuple2 <Object, Double[]> predictWithProb(Vector vector) { double dotValue = MatVecOp.dot(vector, model.coefVector); switch (model.linearModelType) { case LR: case SVM: double prob = sigmoid(dotValue); return new Tuple2 <>(dotValue >= 0 ? model.labelValues[0] : model.labelValues[1], new Double[] {prob, 1 - prob}); } }
FtrlPredictStreamOp.flatMap2 函數完成了處理在線訓練輸出的模型數據流,在線更新模型。
LinearModelData參數是由CollectModel完成加載而且傳輸出來的。
在模型加載過程當中,是不能預測的,沒有看到相關保護機制。若是我疏漏請你們指出。
public void flatMap2(LinearModelData linearModel, Collector<Row> collector) throws Exception { this.predictor.loadModel(linearModel); }
針對以前咱們提出的問題,如今總結概括以下:
在線機器學習FTRL(Follow-the-regularized-Leader)算法介紹