Spark MLlib 之 aggregate和treeAggregate從原理到應用

在閱讀spark mllib源碼的時候,發現一個出鏡率很高的函數——aggregate和treeAggregate,好比matrix.columnSimilarities()中。爲了好好理解這兩個方法的使用,因而整理了本篇內容。html

因爲treeAggregate是在aggregate基礎上的優化版本,所以先來看看aggregate是什麼.sql

更多內容參考個人大數據學習之路apache

aggregate

先直接看一下代碼例子:app

import org.apache.spark.sql.SparkSession

object AggregateTest {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("tf-idf").getOrCreate()
    spark.sparkContext.setLogLevel("WARN")
    // 建立rdd,並分紅6個分區
    val rdd = spark.sparkContext.parallelize(1 to 12).repartition(6)
    // 輸出每一個分區的內容
    rdd.mapPartitionsWithIndex((index:Int,it:Iterator[Int])=>{
      Array((s" $index : ${it.toList.mkString(",")}")).toIterator
    }).foreach(println)
    // 執行agg
    val res1 = rdd.aggregate(0)(seqOp, combOp)
  }
  // 分區內執行的方法,直接加和
  def seqOp(s1:Int, s2:Int):Int = {
    println("seq: "+s1+":"+s2)
    s1 + s2
  }
  // 在driver端彙總
  def combOp(c1: Int, c2: Int): Int = {
    println("comb: "+c1+":"+c2)
    c1 + c2
  }
}

這段代碼的主要目的就是爲了求和。考慮到spark分區並行計算的特性,在每一個分區獨立加和,最後再彙總加和。函數

過程能夠參考下面的圖片:
學習

首先看一下map階段,即在每一個分區內計算加和。初始狀況如藍色方塊所示,內容爲:大數據

分區號:裏面的內容
如,0分區內的數據爲6和8

當執行seqop時,會說先用初始值0開始遍歷累加,原理相似以下:優化

rdd.mapPartitions((it:Iterator)=>{
    var sum = init_value // 默認爲0
    it.foreach(sum + _)
    sum
})

所以屏幕上會出現下面的內容,因爲分區之間是並行的,因此最後的結果是亂序的:ui

seq: 0:6
seq: 0:1
seq: 0:3
seq: 1:9
seq: 3:10
seq: 0:2
seq: 0:5
seq: 5:7
seq: 12:12
seq: 0:4
seq: 4:11
seq: 6:8

計算完成後,依次遍歷每一個分區結果,進行累加:this

comb: 0:10
comb: 10:13
comb: 23:2
comb: 25:24
comb: 49:15
comb: 64:14

aggregate的源碼也比較簡單:

def aggregate[U: ClassTag](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U = withScope {
    var jobResult = Utils.clone(zeroValue, sc.env.serializer.newInstance())
    val cleanSeqOp = sc.clean(seqOp)
    val cleanCombOp = sc.clean(combOp)
    val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
    val mergeResult = (index: Int, taskResult: U) => jobResult = combOp(jobResult, taskResult)
    sc.runJob(this, aggregatePartition, mergeResult)
    jobResult
  }

treeAggregate

treeAggregate在aggregate的基礎上作了一些優化,由於aggregate是在每一個分區計算完成後,把全部的數據拉倒driver端,進行統一的遍歷合併,這樣若是數據量很大,在driver端可能會OOM。

所以treeAggregate在中間多加了一層合併。

先來看看代碼,沒有任何的變化:

import org.apache.spark.sql.SparkSession

object TreeAggregateTest {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("tf-idf").getOrCreate()
    spark.sparkContext.setLogLevel("WARN")

    val rdd = spark.sparkContext.parallelize(1 to 12).repartition(6)
    rdd.mapPartitionsWithIndex((index:Int,it:Iterator[Int])=>{
      Array(s" $index : ${it.toList.mkString(",")}").toIterator
    }).foreach(println)

    val res1 = rdd.treeAggregate(0)(seqOp, combOp)
    println(res1)
  }

  def seqOp(s1:Int, s2:Int):Int = {
    println("seq: "+s1+":"+s2)
    s1 + s2
  }

  def combOp(c1: Int, c2: Int): Int = {
    println("comb: "+c1+":"+c2)
    c1 + c2
  }
}

輸出的結果則發生了變化,首先分區內的操做不變:

3 : 3,10
 2 : 2
 0 : 6,8
 1 : 1,9
 4 : 4,11
 5 : 5,7,12
seq: 0:3
seq: 0:6
seq: 3:10
seq: 6:8
seq: 0:2
seq: 0:1
seq: 1:9
seq: 0:4
seq: 4:11
seq: 0:5
seq: 5:7
seq: 12:12
...

在合併的時候發生了 變化:

comb: 10:13
comb: 23:24
comb: 14:2
comb: 16:15
comb: 47:31

配合下面的流程圖,能夠更好的理解:

搭配treeAggregate的源碼來看一下:

def treeAggregate[U: ClassTag](zeroValue: U)(
      seqOp: (U, T) => U,
      combOp: (U, U) => U,
      depth: Int = 2): U = withScope {
    require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
    if (partitions.length == 0) {
      Utils.clone(zeroValue, context.env.closureSerializer.newInstance())
    } else {
      // 這裏都沒什麼變化,在分區中遍歷數據累加
      val cleanSeqOp = context.clean(seqOp)
      val cleanCombOp = context.clean(combOp)
      val aggregatePartition =
        (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
      var partiallyAggregated = mapPartitions(it => Iterator(aggregatePartition(it)))

      // 關鍵是這下面的內容 !!!!
      // 首先得到當前的分區數
      var numPartitions = partiallyAggregated.partitions.length
      // 計算合適的並行度,我這裏至關於6^(1/2),也就是2.4左右,ceill向上取整後變成3.
      // max(3,2)獲得最後的結果爲3。即每一個樹的分枝有3個葉子節點
      val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2)
      
      // 遍歷分區,經過對scale取模進行合併計算
      // 這裏判斷一下,當前的分區數是否還夠分。若是少於條件值 scale+(p/scale),就中止分區
      while (numPartitions > scale + math.ceil(numPartitions.toDouble / scale)) {
        numPartitions /= scale
        val curNumPartitions = numPartitions
        // 從新定義分區id,並按照分區id從新分區,執行合併計算
        partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex {
          (i, iter) => iter.map((i % curNumPartitions, _))
        }.reduceByKey(new HashPartitioner(curNumPartitions), cleanCombOp).values
      }
      // 最後統計結果
      partiallyAggregated.reduce(cleanCombOp)
    }
  }

spark中的應用

// matrix求類似度
def columnSimilarities(threshold: Double): CoordinateMatrix = {
...              columnSimilaritiesDIMSUM(computeColumnSummaryStatistics().normL2.toArray, gamma)
}
// 統計每個向量的相關數據,裏面包含了min max 等等不少信息
def computeColumnSummaryStatistics(): MultivariateStatisticalSummary = {
  val summary = rows.treeAggregate(new MultivariateOnlineSummarizer)(
    (aggregator, data) => aggregator.add(data),
    (aggregator1, aggregator2) => aggregator1.merge(aggregator2))
  updateNumRows(summary.count)
  summary
}

瞭解了treeAggregate以後,後續就能夠看matrix的並行求解類似度的源碼了!敬請期待吧...

參考

相關文章
相關標籤/搜索