在計算時,爲了節省內存,不把全部的數據一次所有加載到內存中,有一種設計模式叫迭代器模式。設計模式
迭代器模式:在邏輯代碼執行時,真正的邏輯並未執行,而是建立了新的迭代器,新的迭代器保存着對當前迭代器的引用從而造成鏈表,每一個迭代器須要實現hasNext(),next()兩個方法。當觸發計算時,最後一個建立的迭代器會調用next方法,next方法會調用父迭代器的next方法。ide
例如:函數
val list = List("a a", "b d", "c e") val it = list.iterator it.flatMap(_.split(" ")).map((_, 1)).filter(_._1 != "").foreach(println)
這個例子中it是初始迭代器,後面每一個方法都會生成一個新的迭代器,但並不進行迭代計算,到最後foreach方法(相似action算子),開始執行迭代計算了。this
咱們依次展開:spa
def flatMap[B](f: A => GenTraversableOnce[B]): Iterator[B] = new AbstractIterator[B] { // f做用在上游單條數據的結果轉換成的iterator private var cur: Iterator[B] = empty private def nextCur() { cur = f(self.next()).toIterator } def hasNext: Boolean = { while (!cur.hasNext) { if (!self.hasNext) return false nextCur() } true } def next(): B = (if (hasNext) cur else empty).next() }
flatMap方法是建立了一個AbstractIterator的匿名內部類,並實現了hasNext和next兩個方法。每當調用next時,會先調用hasNext,在hasNext中,調上游的iterator的next方法獲取上游這條數據的返回結果,再對這條結果執行用戶傳入的函數f並返回結果後,將其轉換爲iterator,再返回這個iterator的next的結果。線程
def map[B](f: A => B): Iterator[B] = new AbstractIterator[B] { def hasNext = self.hasNext def next() = f(self.next()) }
map與flatMap的代碼模板同樣,邏輯更簡單,只是對上游的next返回結果執行用戶傳入的函數,再返回。設計
def filter(p: A => Boolean): Iterator[A] = new AbstractIterator[A] { private var hd: A = _ private var hdDefined: Boolean = false def hasNext: Boolean = hdDefined || { do { if (!self.hasNext) return false hd = self.next() } while (!p(hd)) hdDefined = true true } def next() = if (hasNext) { hdDefined = false; hd } else empty.next() }
filter中,調用hasNext時,先調用上游iterator的hasNext,若是返回false,那麼直接返回false。若是上游的hasNext返回true,就取出上游的next結果,並將用戶傳入的判斷函數p做用在這個結果上,若爲true,則退出循環,並將hdDefine置爲true;若p的結果爲false,則繼續從上游取下一條數據讓p判斷。code
def foreach[U](f: A => U) { while (hasNext) f(next()) }
遍歷迭代器,將每一個元素傳給用戶傳入的函數f中執行。繼承
在spark的每一個任務中,都是以迭代器模式進行計算的。而每一個迭代器的鏈表對應每一個分區中的數據。RDD的每一個算子會生成一個新的RDD,新的RDD會保存對前一個RDD的引用,而且會保存傳入到算子中的用戶定義函數。ip
例如:
def map[U: ClassTag](f: T => U): RDD[U] = withScope { val cleanF = sc.clean(f) new MapPartitionsRDD[U, T](this, (context, pid, iter) => iter.map(cleanF)) }
這個map算子會返回一個MapPartitionsRDD,MapPartitionsRDD中含有當前this這個RDD的引用,並把用戶定義函數f轉換成做用於iterator的函數傳入到MapPartitionsRDD中。
RDD中有個抽象方法compute,MapPartitionsRDD中實現以下:
override def compute(split: Partition, context: TaskContext): Iterator[U] = f(context, split.index, firstParent[T].iterator(split, context))
從父RDD(firstParent[T])獲取迭代器,這個過程須要分區信息split和任務上下文。再map算子中轉換後的用戶定義函數做用在這個迭代器上。
compute方法同迭代器模式相似,也是不斷從上游RDD獲取的迭代器,這樣來得到一個迭代器的鏈表,這個鏈表就是一個task要執行的任務。
爲了說明這個過程,咱們從Executor源碼來找尋。
def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = { val tr = new TaskRunner(context, taskDescription) runningTasks.put(taskDescription.taskId, tr) threadPool.execute(tr) }
Executor源碼中有個launchTask方法,會建立TaskRunner,將TaskRunner交給線程池執行。TaskRunner是什麼呢?
在Executor源碼中有一個內部類,TaskRunner,它是一個線程的任務:
class TaskRunner( execBackend: ExecutorBackend, private val taskDescription: TaskDescription) extends Runnable {
繼承Runnable必須實現run方法,找到run方法,在run方法中找到了以下代碼:
val res = task.run( taskAttemptId = taskId, attemptNumber = taskDescription.attemptNumber, metricsSystem = env.metricsSystem) threwException = false
點進這裏task的run,會在Task類中找到runTask(context),這個runTask是Task類的抽象方法,會被Task的子類實現。好比ResultTask,這個子類是最後collect類型的action算子出發的任務類。在ResultTask中,runTask方法調用了rdd的iterator方法來獲取iterator,並將用戶定義的方法做用到這個iterator上。
override def runTask(context: TaskContext): U = { // Deserialize the RDD and the func using the broadcast variables. val threadMXBean = ManagementFactory.getThreadMXBean val deserializeStartTime = System.currentTimeMillis() val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) { threadMXBean.getCurrentThreadCpuTime } else 0L val ser = SparkEnv.get.closureSerializer.newInstance() val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)]( ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader) _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime _executorDeserializeCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) { threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime } else 0L func(context, rdd.iterator(partition, context)) }
這個rdd的iterator方法會獲取父rdd的迭代器或調用compute方法。
final def iterator(split: Partition, context: TaskContext): Iterator[T] = { if (storageLevel != StorageLevel.NONE) { getOrCompute(split, context) } else { computeOrReadCheckpoint(split, context) } } private[spark] def computeOrReadCheckpoint(split: Partition, context: TaskContext): Iterator[T] = { if (isCheckpointedAndMaterialized) { firstParent[T].iterator(split, context) } else { compute(split, context) } }
spark每一個任務都是由向前依賴串聯起來RDD鏈表生成的iterator鏈表構成的,任務執行由最後的一個iterator的迭代開始,調用上游的迭代器的next,直到迭代到第一個iterator。這樣避免了將全部數據先加載到內存中,而每次計算都只從源頭取一條數據,大大節省了內存。