Alink 是阿里巴巴基於實時計算引擎 Flink 研發的新一代機器學習算法平臺,是業界首個同時支持批式算法、流式算法的機器學習平臺。本文將爲你們展示Alink如何劃分訓練數據集和測試數據集。java
兩分法算法
通常作預測分析時,會將數據分爲兩大部分。一部分是訓練數據,用於構建模型,一部分是測試數據,用於檢驗模型。數組
三分法dom
但有時候模型的構建過程當中也須要檢驗模型/輔助模型構建,這時會將訓練數據再分爲兩個部分:1)訓練數據;2)驗證數據(Validation Data)。因此這種狀況下會把數據分爲三部分。機器學習
Training set是用來訓練模型或肯定模型參數的,如ANN中權值等;ide
Validation set是用來作模型選擇(model selection),即作模型的最終優化及肯定,如ANN的結構;函數
Test set則純粹是爲了測試已經訓練好的模型的推廣能力。固然test set並不能保證模型的正確性,他只是說類似的數據用此模型會得出類似的結果。學習
實際應用測試
實際應用中,通常只將數據集分紅兩類,即training set 和test set,大多數文章並不涉及validation set。咱們這裏也不涉及。你們經常使用的sklearn的train_test_split函數就是將矩陣隨機劃分爲訓練子集和測試子集,並返回劃分好的訓練集測試集樣本和訓練集測試集標籤。優化
首先咱們給出示例代碼,而後會深刻剖析:
public class SplitExample { public static void main(String[] args) throws Exception { String url = "iris.csv"; String schema = "sepal_length double, sepal_width double, petal_length double, petal_width double, category string"; //這裏是批處理 BatchOperator data = new CsvSourceBatchOp().setFilePath(url).setSchemaStr(schema); SplitBatchOp spliter = new SplitBatchOp().setFraction(0.8); spliter.linkFrom(data); BatchOperator trainData = spliter; BatchOperator testData = spliter.getSideOutput(0); // 這裏是流處理 CsvSourceStreamOp dataS = new CsvSourceStreamOp().setFilePath(url).setSchemaStr(schema); SplitStreamOp spliterS = new SplitStreamOp().setFraction(0.4); spliterS.linkFrom(dataS); StreamOperator train_data = spliterS; StreamOperator test_data = spliterS.getSideOutput(0); } }
SplitBatchOp是分割批處理的主要類,具體構建DAG的工做是在其linkFrom完成的。
整體思路比較簡單:
numTarget = totCount * fraction
task_n_count * fraction
totSelect = task_1_count * fraction + task_2_count * fraction + ... task_n_count * fraction
numTarget - totSelect
加入到某一個task中。若是要分割數據,首先必須知道數據集的記錄數。好比這個DataSet的記錄是1萬個?仍是十萬個?由於數據集可能會很大,因此這一步操做也使用了並行處理,即把數據分區,而後經過mapPartition操做獲得每個分區上元素的數目。
DataSet<Tuple2<Integer, Long>> countsPerPartition = DataSetUtils.countElementsPerPartition(rows); //返回哪一個task有哪些記錄數 DataSet<long[]> numPickedPerPartition = countsPerPartition .mapPartition(new CountInPartition(fraction)) //計算總數 .setParallelism(1) .name("decide_count_of_each_partition");
由於每一個分區就對應了一個task,因此咱們也能夠認爲,這是獲取了每一個task的記錄數。
具體工做是在 DataSetUtils.countElementsPerPartition 中完成的。返回類型是<index of this subtask, record count in this subtask>,好比3號task擁有30個記錄。
public static <T> DataSet<Tuple2<Integer, Long>> countElementsPerPartition(DataSet<T> input) { return input.mapPartition(new RichMapPartitionFunction<T, Tuple2<Integer, Long>>() { @Override public void mapPartition(Iterable<T> values, Collector<Tuple2<Integer, Long>> out) throws Exception { long counter = 0; for (T value : values) { counter++; //計算本task的記錄總數 } out.collect(new Tuple2<>(getRuntimeContext().getIndexOfThisSubtask(), counter)); } }); }
計算總數的工做實際上是在下一階段算子中完成的。
接下來的工做主要是在 CountInPartition.mapPartition 完成的,其做用是隨機決定每一個task選擇多少個記錄。
這時候就不須要並行了,因此 .setParallelism(1)
獲得了每一個分區記錄數以後,咱們遍歷每一個task的記錄數,而後累積獲得總記錄數 totCount(就是從上而下計算出來的總數)。
public void mapPartition(Iterable<Tuple2<Integer, Long>> values, Collector<long[]> out) throws Exception { long totCount = 0L; List<Tuple2<Integer, Long>> buffer = new ArrayList<>(); for (Tuple2<Integer, Long> value : values) { //遍歷輸入的全部分區記錄 totCount += value.f1; //f1是Long類型的記錄數 buffer.add(value); } ... //後續代碼在下面分析。 }
而後CountInPartition.mapPartition函數中會隨機決定每一個task會選擇的記錄數。mapPartition的參數 Iterable<Tuple2<Integer, Long>> values 就是前一階段的結果 :一個元祖<task id, 每一個task的記錄數目>。
把這些元祖結合在一塊兒,記錄在buffer這個列表中。
buffer = {ArrayList@8972} size = 4 0 = {Tuple2@8975} "(3,38)" // 3號task,其對應的partition記錄數是38個。 1 = {Tuple2@8976} "(2,0)" 2 = {Tuple2@8977} "(0,38)" 3 = {Tuple2@8978} "(1,74)"
系統的task數目就是buffer大小。
int npart = buffer.size(); // num tasks
而後,根據」記錄總數「計算出來 「隨機訓練數據的個數numTarget」。好比總數1萬,應該隨機分配20%,因而numTarget就應該是2千。這個數字之後會用到。
long numTarget = Math.round((totCount * fraction));
獲得每一個task的記錄數目,好比是上面buffer中的 38,0,38,仍是74,記錄在 eachCount 中。
for (Tuple2<Integer, Long> value : buffer) { eachCount[value.f0] = value.f1; }
獲得每一個task中隨機選中的訓練記錄數,記錄在 eachSelect 中。就是每一個task目前 「記錄數字 * fraction」。好比3號task記錄數是38個,應該選20%,則38*20%=8個。
而後把這些task本身的「隨機訓練記錄數」再累加起來獲得 totSelect(就是從下而上計算出來的總數)。
long totSelect = 0L; for (int i = 0; i < npart; i++) { eachSelect[i] = Math.round(Math.floor(eachCount[i] * fraction)); totSelect += eachSelect[i]; }
請注意,這時候 totSelect 和 以前計算的numTarget就有具體細微出入了,就是理論上的一個數字,可是咱們 從上而下 計算 和 從下而上 計算,其結果可能不同。經過下面咱們能夠看出來。
numTarget = all count * fraction totSelect = task_1_count * fraction + task_2_count * fraction + ...
因此咱們下一步要處理這個細微出入,就獲得remain,這是"整體算出來的隨機數目" numTarget 和 "從全部task選中的隨機訓練記錄數累積" totSelect 的差。
if (totSelect < numTarget) { long remain = numTarget - totSelect; remain = Math.min(remain, totCount - totSelect);
若是恰好個數相等,則就正常分配。
if (remain == totCount - totSelect) {
若是數目不等,隨機決定把"多出來的remain"加入到eachSelect數組中的隨便一個記錄上。
for (int i = 0; i < Math.min(remain, npart); i++) { int taskId = shuffle.get(i); while (eachSelect[taskId] >= eachCount[taskId]) { taskId = (taskId + 1) % npart; } eachSelect[taskId]++; }
最後給出全部信息
long[] statistics = new long[npart * 2]; for (int i = 0; i < npart; i++) { statistics[i] = eachCount[i]; statistics[i + npart] = eachSelect[i]; } out.collect(statistics); // 咱們這裏是4核,因此前面四項是eachCount,後面是eachSelect statistics = {long[8]@9003} 0 = 38 //eachCount 1 = 38 2 = 36 3 = 38 4 = 31 //eachSelect 5 = 31 6 = 28 7 = 30
這些信息是做爲廣播變量存儲起來的,立刻下面就會用到。
.withBroadcastSet(numPickedPerPartition, "counts")
CountInPartition.PickInPartition函數中會隨機在每一個task選擇記錄。
首先獲得task數目 和 以前存儲的廣播變量(就是以前剛剛存儲的)。
int npart = getRuntimeContext().getNumberOfParallelSubtasks(); List<long[]> bc = getRuntimeContext().getBroadcastVariable("counts");
分離count和select。
long[] eachCount = Arrays.copyOfRange(bc.get(0), 0, npart); long[] eachSelect = Arrays.copyOfRange(bc.get(0), npart, npart * 2);
獲得總task數目
int taskId = getRuntimeContext().getIndexOfThisSubtask();
獲得本身 task 對應的 count, select
long count = eachCount[taskId]; long select = eachSelect[taskId];
添加本task對應的記錄,隨機洗牌打亂順序
for (int i = 0; i < count; i++) { shuffle.add(i); //就是把count內的數字加到數組 } Collections.shuffle(shuffle, new Random(taskId)); //洗牌打亂順序 // suffle舉例 shuffle = {ArrayList@8987} size = 38 0 = {Integer@8994} 17 1 = {Integer@8995} 8 2 = {Integer@8996} 33 3 = {Integer@8997} 34 4 = {Integer@8998} 20 5 = {Integer@8999} 0 6 = {Integer@9000} 26 7 = {Integer@9001} 27 8 = {Integer@9002} 23 9 = {Integer@9003} 28 10 = {Integer@9004} 9 11 = {Integer@9005} 16 12 = {Integer@9006} 13 13 = {Integer@9007} 2 14 = {Integer@9008} 5 15 = {Integer@9009} 31 16 = {Integer@9010} 15 17 = {Integer@9011} 22 18 = {Integer@9012} 18 19 = {Integer@9013} 35 20 = {Integer@9014} 36 21 = {Integer@9015} 12 22 = {Integer@9016} 7 23 = {Integer@9017} 21 24 = {Integer@9018} 14 25 = {Integer@9019} 1 26 = {Integer@9020} 10 27 = {Integer@9021} 30 28 = {Integer@9022} 29 29 = {Integer@9023} 19 30 = {Integer@9024} 25 31 = {Integer@9025} 32 32 = {Integer@9026} 37 33 = {Integer@9027} 4 34 = {Integer@9028} 11 35 = {Integer@9029} 6 36 = {Integer@9030} 3 37 = {Integer@9031} 24
隨機選擇,把選擇後的再排序回來
for (int i = 0; i < select; i++) { selected[i] = shuffle.get(i); //這時候select看起來是按照順序選擇,可是實際上suffle裏面已是亂序 } Arrays.sort(selected); //此次再排序 // selected舉例,一共30個 selected = {int[30]@8991} 0 = 0 1 = 1 2 = 2 3 = 5 4 = 7 5 = 8 6 = 9 7 = 10 8 = 12 9 = 13 10 = 14 11 = 15 12 = 16 13 = 17 14 = 18 15 = 19 16 = 20 17 = 21 18 = 22 19 = 23 20 = 26 21 = 27 22 = 28 23 = 29 24 = 30 25 = 31 26 = 33 27 = 34 28 = 35 29 = 36
發送選擇的數據
if (numEmits < selected.length && iRow == selected[numEmits]) { out.collect(row); numEmits++; }
output是訓練數據集,SideOutput是測試數據集。由於這兩個數據集在Alink內部都是Table類型,因此直接使用了SQL算子 minusAll
來完成分割。
this.setOutput(out, in.getSchema()); this.setSideOutputTables(new Table[]{in.getOutputTable().minusAll(this.getOutputTable())});
訓練是在SplitStreamOp類完成的,其經過linkFrom完成了模型的構建。
流處理依賴SplitStream 和 SelectTransformation 這兩個類來完成分割流。具體並無創建一個物理操做,而只是影響了上游算子如何與下游算子聯繫,如何選擇記錄。
SplitStream <Row> splited = in.getDataStream().split(new RandomSelectorOp(getFraction()));
首先,用RandomSelectorOp來隨機決定輸出時候選擇哪一個流。咱們能夠看到,這裏就是隨便起了"a", "b" 這兩個名字而已。
class RandomSelectorOp implements OutputSelector <Row> { private double fraction; private Random random = null; @Override public Iterable <String> select(Row value) { if (null == random) { random = new Random(System.currentTimeMillis()); } List <String> output = new ArrayList <String>(1); output.add((random.nextDouble() < fraction ? "a" : "b")); //隨機選取數字分配,隨意起的名字 return output; } }
其次,獲得那兩個隨機生成的流。
DataStream <Row> partA = splited.select("a"); DataStream <Row> partB = splited.select("b");
最後把這兩個流分別設置爲output和sideOutput。
this.setOutput(partA, in.getSchema()); //訓練集 this.setSideOutputTables(new Table[]{ DataStreamConversionUtil.toTable(getMLEnvironmentId(), partB, in.getSchema())}); //驗證集
最後返回自己,這時候SplitStreamOp擁有兩個成員變量:
this.output就是訓練集。
this.sideOutPut就是驗證集。
return this;