Spark源碼拜讀之RDD的迭代器串聯

1.迭代器模式

在計算時,爲了節省內存,不把全部的數據一次所有加載到內存中,有一種設計模式叫迭代器模式。設計模式

迭代器模式:在邏輯代碼執行時,真正的邏輯並未執行,而是建立了新的迭代器,新的迭代器保存着對當前迭代器的引用從而造成鏈表,每一個迭代器須要實現hasNext(),next()兩個方法。當觸發計算時,最後一個建立的迭代器會調用next方法,next方法會調用父迭代器的next方法。ide

例如:函數

val list = List("a a", "b d", "c e")
val it = list.iterator
it.flatMap(_.split(" ")).map((_, 1)).filter(_._1 != "").foreach(println)

這個例子中it是初始迭代器,後面每一個方法都會生成一個新的迭代器,但並不進行迭代計算,到最後foreach方法(相似action算子),開始執行迭代計算了。this

咱們依次展開:spa

def flatMap[B](f: A => GenTraversableOnce[B]): Iterator[B] = new AbstractIterator[B] {
  // f做用在上游單條數據的結果轉換成的iterator
  private var cur: Iterator[B] = empty
  private def nextCur() { cur = f(self.next()).toIterator }
  def hasNext: Boolean = {
    while (!cur.hasNext) {
      if (!self.hasNext) return false
      nextCur()
    }
    true
  }
  def next(): B = (if (hasNext) cur else empty).next()
}

flatMap方法是建立了一個AbstractIterator的匿名內部類,並實現了hasNext和next兩個方法。每當調用next時,會先調用hasNext,在hasNext中,調上游的iterator的next方法獲取上游這條數據的返回結果,再對這條結果執行用戶傳入的函數f並返回結果後,將其轉換爲iterator,再返回這個iterator的next的結果。線程

def map[B](f: A => B): Iterator[B] = new AbstractIterator[B] {
  def hasNext = self.hasNext
  def next() = f(self.next())
}

map與flatMap的代碼模板同樣,邏輯更簡單,只是對上游的next返回結果執行用戶傳入的函數,再返回。設計

def filter(p: A => Boolean): Iterator[A] = new AbstractIterator[A] {
  private var hd: A = _
  private var hdDefined: Boolean = false

  def hasNext: Boolean = hdDefined || {
    do {
      if (!self.hasNext) return false
      hd = self.next()
    } while (!p(hd))
    hdDefined = true
    true
  }

  def next() = if (hasNext) { hdDefined = false; hd } else empty.next()
}

filter中,調用hasNext時,先調用上游iterator的hasNext,若是返回false,那麼直接返回false。若是上游的hasNext返回true,就取出上游的next結果,並將用戶傳入的判斷函數p做用在這個結果上,若爲true,則退出循環,並將hdDefine置爲true;若p的結果爲false,則繼續從上游取下一條數據讓p判斷。code

def foreach[U](f: A => U) { while (hasNext) f(next()) }

遍歷迭代器,將每一個元素傳給用戶傳入的函數f中執行。繼承

2.RDD串聯

在spark的每一個任務中,都是以迭代器模式進行計算的。而每一個迭代器的鏈表對應每一個分區中的數據。RDD的每一個算子會生成一個新的RDD,新的RDD會保存對前一個RDD的引用,而且會保存傳入到算子中的用戶定義函數。ip

例如:

def map[U: ClassTag](f: T => U): RDD[U] = withScope {
  val cleanF = sc.clean(f)
  new MapPartitionsRDD[U, T](this, (context, pid, iter) => iter.map(cleanF))
}

這個map算子會返回一個MapPartitionsRDD,MapPartitionsRDD中含有當前this這個RDD的引用,並把用戶定義函數f轉換成做用於iterator的函數傳入到MapPartitionsRDD中。

RDD中有個抽象方法compute,MapPartitionsRDD中實現以下:

override def compute(split: Partition, context: TaskContext): Iterator[U] =
  f(context, split.index, firstParent[T].iterator(split, context))

從父RDD(firstParent[T])獲取迭代器,這個過程須要分區信息split和任務上下文。再map算子中轉換後的用戶定義函數做用在這個迭代器上。

compute方法同迭代器模式相似,也是不斷從上游RDD獲取的迭代器,這樣來得到一個迭代器的鏈表,這個鏈表就是一個task要執行的任務。

爲了說明這個過程,咱們從Executor源碼來找尋。

def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = {
  val tr = new TaskRunner(context, taskDescription)
  runningTasks.put(taskDescription.taskId, tr)
  threadPool.execute(tr)
}

Executor源碼中有個launchTask方法,會建立TaskRunner,將TaskRunner交給線程池執行。TaskRunner是什麼呢?

在Executor源碼中有一個內部類,TaskRunner,它是一個線程的任務:

class TaskRunner(
    execBackend: ExecutorBackend,
    private val taskDescription: TaskDescription)
  extends Runnable {

繼承Runnable必須實現run方法,找到run方法,在run方法中找到了以下代碼:

val res = task.run(
  taskAttemptId = taskId,
  attemptNumber = taskDescription.attemptNumber,
  metricsSystem = env.metricsSystem)
threwException = false

點進這裏task的run,會在Task類中找到runTask(context),這個runTask是Task類的抽象方法,會被Task的子類實現。好比ResultTask,這個子類是最後collect類型的action算子出發的任務類。在ResultTask中,runTask方法調用了rdd的iterator方法來獲取iterator,並將用戶定義的方法做用到這個iterator上。

override def runTask(context: TaskContext): U = {
    // Deserialize the RDD and the func using the broadcast variables.
    val threadMXBean = ManagementFactory.getThreadMXBean
    val deserializeStartTime = System.currentTimeMillis()
    val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
        threadMXBean.getCurrentThreadCpuTime
    } else 0L
    val ser = SparkEnv.get.closureSerializer.newInstance()
    val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)](
        ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
    _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime
    _executorDeserializeCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
        threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime
    } else 0L

    func(context, rdd.iterator(partition, context))
}

這個rdd的iterator方法會獲取父rdd的迭代器或調用compute方法。

final def iterator(split: Partition, context: TaskContext): Iterator[T] = {
  if (storageLevel != StorageLevel.NONE) {
    getOrCompute(split, context)
  } else {
    computeOrReadCheckpoint(split, context)
  }
}
private[spark] def computeOrReadCheckpoint(split: Partition, context: TaskContext): Iterator[T] =
{
    if (isCheckpointedAndMaterialized) {
        firstParent[T].iterator(split, context)
    } else {
        compute(split, context)
    }
}

小結

spark每一個任務都是由向前依賴串聯起來RDD鏈表生成的iterator鏈表構成的,任務執行由最後的一個iterator的迭代開始,調用上游的迭代器的next,直到迭代到第一個iterator。這樣避免了將全部數據先加載到內存中,而每次計算都只從源頭取一條數據,大大節省了內存。

相關文章
相關標籤/搜索