How Spark tasks run on the executor side

CoarseGrainedExecutorBackend

In the previous post we analyzed the job submission process, strictly speaking the driver-side part of it: once a job is submitted, the DAGScheduler splits it into stages according to shuffle dependencies and submits them one at a time, creating for each stage as many Tasks as it has partitions and wrapping them into a TaskSet that is handed to TaskSchedulerImpl. TaskSchedulerImpl then assigns tasks in a round-robin fashion to the compute resources (executors) offered by the SchedulerBackend, taking task locality, blacklists and the scheduling order of the pools into account, and returns the task-to-executor assignments to the SchedulerBackend wrapped as TaskDescriptions. The SchedulerBackend serializes each TaskDescription again and sends it to the corresponding executor for execution. In this post we look at how a Task actually runs on the executor.

Task execution entry point: Executor.launchTask

First of all, we know that CoarseGrainedExecutorBackend is the executor implementation used in YARN mode, and it is an RPC server. So, starting from the message sent by the RPC client, CoarseGrainedSchedulerBackend, we can locate the method that handles that message on the server side and follow it to the entry point of task execution. From the previous post we know that when launching tasks, CoarseGrainedSchedulerBackend sends a LaunchTask message; looking at CoarseGrainedExecutorBackend.receive, the handling of LaunchTask is as follows:

case LaunchTask(data) =>
  if (executor == null) {
    exitExecutor(1, "Received LaunchTask command but executor was null")
  } else {
    val taskDesc = TaskDescription.decode(data.value)
    logInfo("Got assigned task " + taskDesc.taskId)
    executor.launchTask(this, taskDesc)
  }

As we can see, the task is actually handed over to the internal Executor object, which carries the bulk of the executor-side logic. CoarseGrainedExecutorBackend can be seen as little more than an RPC message relay, i.e. an endpoint in Spark's RPC framework, while the actual task execution logic lives in the Executor object.

Executor overview

Let's first look at the class comment of Executor:

/**
 * Spark executor, backed by a threadpool to run tasks.
 *
 * This can be used with Mesos, YARN, and the standalone scheduler.
 * An internal RPC interface is used for communication with the driver,
 * except in the case of Mesos fine-grained mode.
 */

Executor runs tasks on an internal thread pool; Mesos, YARN and the standalone scheduler all use this class to run tasks. The Executor object also holds a reference to SparkEnv, through which it uses Spark's infrastructure, including RPC references.
As before, we trace the code of this class along the path of task execution.

Executor.launchTask

def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = {
  val tr = new TaskRunner(context, taskDescription)
  runningTasks.put(taskDescription.taskId, tr)
  threadPool.execute(tr)
}

There is nothing much to say about this code; it should be self-explanatory. So next we look at the TaskRunner class.
One thing this does show is that on the executor side, one task corresponds to one thread; a minimal sketch of such a task thread pool follows.
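Below is a small, self-contained sketch (not Spark's actual code) of what such a pool looks like: a cached thread pool whose named daemon workers each run one task, similar in spirit to the threadPool field inside Executor. The object and thread names here are purely illustrative.

import java.util.concurrent.{Executors, ThreadFactory}
import java.util.concurrent.atomic.AtomicInteger

// Illustrative sketch only: a cached thread pool with named daemon threads,
// mirroring the idea of Executor's internal task-launch worker pool.
object TaskThreadPoolSketch {
  private val counter = new AtomicInteger(0)

  private val factory = new ThreadFactory {
    override def newThread(r: Runnable): Thread = {
      val t = new Thread(r, s"task-launch-worker-${counter.getAndIncrement()}")
      t.setDaemon(true) // do not keep the JVM alive because of idle workers
      t
    }
  }

  // One task maps to one Runnable (a TaskRunner in Spark); idle threads are eventually reclaimed.
  val pool = Executors.newCachedThreadPool(factory)
}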

TaskRunner.run

This method is quite long and takes some patience to read through.
I will not go into the various statistics it maintains, such as task run time, CPU time and GC time. One reusable takeaway is the use of MXBeans: CPU and GC times are obtained from the JVM's built-in MXBeans, with ManagementFactory as the entry class; I won't expand on that beyond the small sketch below.
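For reference, here is a standalone illustration of those MXBean calls, using only the standard java.lang.management API (the object name is made up for this sketch):

import java.lang.management.ManagementFactory
import scala.collection.JavaConverters._

// Standalone illustration of the MXBean calls mentioned above.
object JvmTimingSketch {
  private val threadMXBean = ManagementFactory.getThreadMXBean

  // CPU time (in nanoseconds) consumed by the current thread, when supported.
  def currentThreadCpuTimeNs: Long =
    if (threadMXBean.isCurrentThreadCpuTimeSupported) threadMXBean.getCurrentThreadCpuTime
    else 0L

  // Total GC time (in milliseconds) accumulated by all collectors,
  // the same idea as Executor.computeTotalGcTime.
  def totalGcTimeMs: Long =
    ManagementFactory.getGarbageCollectorMXBeans.asScala.map(_.getCollectionTime).sum
}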

To summarize the main steps of this method:

  • First, send a status-update message to the driver to mark this task as RUNNING.
  • Set task properties and update the dependent files and jars, adding the new jars to the classloader's search path; note that this information is shipped from the driver together with the TaskDescription.
  • Deserialize the task into a Task object, which depending on the task type is either a ShuffleMapTask or a ResultTask.
  • Check whether the task has been killed and, if so, throw an exception (the driver may send a kill-task message at any time).
  • Call Task.run to execute the actual task logic.
  • After the task finishes, clean up memory and block locks that were not released properly, logging a resource-leak warning or throwing an exception where required.
  • Check once more whether the task has been killed.
  • Serialize the task's result data.
  • Update some task statistics (accumulators) as well as the related counters in the metrics system.
  • Collect all accumulators related to this task (both the built-in metric accumulators and user-registered ones).
  • Wrap the accumulator data and the result data into one object and serialize it once more.
  • Check the serialized size against two thresholds, maxResultSize and maxDirectResultSize. If it exceeds maxResultSize the result is simply dropped, i.e. nothing is written to the block manager, so when the driver later tries to fetch the data remotely through the block manager it finds nothing and knows that the task's result was too large and the task failed. If the size exceeds maxDirectResultSize, the result is written to local memory and disk through the block manager and the block information is sent to the driver, which then pulls the data from this node. If the size is below maxDirectResultSize, the result is sent to the driver directly over the RPC interface.
  • Finally, the various exceptions thrown by failed tasks are handled.

    override def run(): Unit = {
    threadId = Thread.currentThread.getId
    Thread.currentThread.setName(threadName)
    // MXBean used to monitor the running thread
    val threadMXBean = ManagementFactory.getThreadMXBean
    // per-task memory manager
    val taskMemoryManager = new TaskMemoryManager(env.memoryManager, taskId)
    // record the deserialization start time; recall that the deserialization time shown on the Spark UI is measured here
    val deserializeStartTime = System.currentTimeMillis()
    // CPU time at the start of deserialization
    val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
      threadMXBean.getCurrentThreadCpuTime
    } else 0L
    Thread.currentThread.setContextClassLoader(replClassLoader)
    val ser = env.closureSerializer.newInstance()
    logInfo(s"Running $taskName (TID $taskId)")
    // TODO notify the driver via the executor backend that this task is now RUNNING
    execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
    var taskStart: Long = 0
    var taskStartCpu: Long = 0
    // total GC time so far, again obtained through an MXBean
    startGCTime = computeTotalGcTime()

    try {
      // Must be set before updateDependencies() is called, in case fetching dependencies
      // requires access to properties contained within (e.g. for access control).
      Executor.taskDeserializationProps.set(taskDescription.properties)
    
      // TODO update the dependent files and jars: fetch them from the driver and cache them locally
      updateDependencies(taskDescription.addedFiles, taskDescription.addedJars)
      // deserialize the task itself; note that no timing statistic is recorded right here
      task = ser.deserialize[Task[Any]](
        taskDescription.serializedTask, Thread.currentThread.getContextClassLoader)
      // the property set is also shipped from the driver along with the TaskDescription
      task.localProperties = taskDescription.properties
      // set the task memory manager
      task.setTaskMemoryManager(taskMemoryManager)
    
      // If this task has been killed before we deserialized it, let's quit now. Otherwise,
      // continue executing the task.
      // check whether the task has been killed
      val killReason = reasonIfKilled
      if (killReason.isDefined) {
        // Throw an exception rather than returning, because returning within a try{} block
        // causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl
        // exception will be caught by the catch block, leading to an incorrect ExceptionFailure
        // for the task.
        throw new TaskKilledException(killReason.get)
      }
    
      // The purpose of updating the epoch here is to invalidate executor map output status cache
      // in case FetchFailures have occurred. In local mode `env.mapOutputTracker` will be
      // MapOutputTrackerMaster and its cache invalidation is not based on epoch numbers so
      // we don't need to make any special calls here.
      //
      if (!isLocal) {
        logDebug("Task " + taskId + "'s epoch is " + task.epoch)
        // update the epoch and the map output status
        env.mapOutputTracker.asInstanceOf[MapOutputTrackerWorker].updateEpoch(task.epoch)
      }
    
      // Run the actual task and measure its runtime.
      // run the task and measure its runtime
      taskStart = System.currentTimeMillis()
      // CPU time of the current thread at task start
      taskStartCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
        threadMXBean.getCurrentThreadCpuTime
      } else 0L
      var threwException = true
      val value = try {
        // call task.run to execute the task
        val res = task.run(
          // task id
          taskAttemptId = taskId,
          // attempt number of this task
          attemptNumber = taskDescription.attemptNumber,
          // metrics system
          metricsSystem = env.metricsSystem)
        threwException = false
        res
      } finally {
        // release all locks held by this task, i.e. the read/write locks on the blocks it used
        val releasedLocks = env.blockManager.releaseAllLocksForTask(taskId)
        // free all memory space allocated to this task
        val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory()
    
        // if threwException is false, the task completed normally
        // if memory can still be freed even though the task completed normally,
        // the task did not properly release the memory it used, i.e. a memory leak occurred
        if (freedMemory > 0 && !threwException) {
          val errMsg = s"Managed memory leak detected; size = $freedMemory bytes, TID = $taskId"
          if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) {
            throw new SparkException(errMsg)
          } else {
            logWarning(errMsg)
          }
        }
    
        // the check on lock resources follows the same logic as the memory check above
        // the Spark authors take the view that each task is responsible for releasing the
        // resources it acquires (memory and locks) once it is done with them,
        // rather than relying on after-the-fact cleanup like this;
        // anything not released properly counts as a resource leak, and here the check is for leaked locks
        if (releasedLocks.nonEmpty && !threwException) {
          val errMsg =
            s"${releasedLocks.size} block locks were not released by TID = $taskId:\n" +
              releasedLocks.mkString("[", ", ", "]")
          if (conf.getBoolean("spark.storage.exceptionOnPinLeak", false)) {
            throw new SparkException(errMsg)
          } else {
            logInfo(errMsg)
          }
        }
      }
      // log swallowed fetch failures
      // reaching this point means user code did not propagate a fetch exception,
      // yet the framework detected one, i.e. the user code swallowed it, which is clearly wrong,
      // so an error is logged to alert the user
      task.context.fetchFailed.foreach { fetchFailure =>
        // uh-oh.  it appears the user code has caught the fetch-failure without throwing any
        // other exceptions.  Its *possible* this is what the user meant to do (though highly
        // unlikely).  So we will log an error and keep going.
        logError(s"TID ${taskId} completed successfully though internally it encountered " +
          s"unrecoverable fetch failures!  Most likely this means user code is incorrectly " +
          s"swallowing Spark's internal ${classOf[FetchFailedException]}", fetchFailure)
      }
      // task finish timestamp
      val taskFinish = System.currentTimeMillis()
      // CPU time consumed by the task thread
      val taskFinishCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
        threadMXBean.getCurrentThreadCpuTime
      } else 0L
    
      // If the task has been killed, let's fail it.
      // check once more whether the task has been killed
      task.context.killTaskIfInterrupted()
    
      // serialize the task result
      val resultSer = env.serializer.newInstance()
      val beforeSerialization = System.currentTimeMillis()
      val valueBytes = resultSer.serialize(value)
      val afterSerialization = System.currentTimeMillis()
    
      // Deserialization happens in two parts: first, we deserialize a Task object, which
      // includes the Partition. Second, Task.run() deserializes the RDD and function to be run.
      task.metrics.setExecutorDeserializeTime(
        (taskStart - deserializeStartTime) + task.executorDeserializeTime)
      task.metrics.setExecutorDeserializeCpuTime(
        (taskStartCpu - deserializeStartCpuTime) + task.executorDeserializeCpuTime)
      // We need to subtract Task.run()'s deserialization time to avoid double-counting
      task.metrics.setExecutorRunTime((taskFinish - taskStart) - task.executorDeserializeTime)
      task.metrics.setExecutorCpuTime(
        (taskFinishCpu - taskStartCpu) - task.executorDeserializeCpuTime)
      task.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime)
      task.metrics.setResultSerializationTime(afterSerialization - beforeSerialization)
    
      // Expose task metrics using the Dropwizard metrics system.
      // Update task metrics counters
      executorSource.METRIC_CPU_TIME.inc(task.metrics.executorCpuTime)
      executorSource.METRIC_RUN_TIME.inc(task.metrics.executorRunTime)
      executorSource.METRIC_JVM_GC_TIME.inc(task.metrics.jvmGCTime)
      executorSource.METRIC_DESERIALIZE_TIME.inc(task.metrics.executorDeserializeTime)
      executorSource.METRIC_DESERIALIZE_CPU_TIME.inc(task.metrics.executorDeserializeCpuTime)
      executorSource.METRIC_RESULT_SERIALIZE_TIME.inc(task.metrics.resultSerializationTime)
      executorSource.METRIC_SHUFFLE_FETCH_WAIT_TIME
        .inc(task.metrics.shuffleReadMetrics.fetchWaitTime)
      executorSource.METRIC_SHUFFLE_WRITE_TIME.inc(task.metrics.shuffleWriteMetrics.writeTime)
      executorSource.METRIC_SHUFFLE_TOTAL_BYTES_READ
        .inc(task.metrics.shuffleReadMetrics.totalBytesRead)
      executorSource.METRIC_SHUFFLE_REMOTE_BYTES_READ
        .inc(task.metrics.shuffleReadMetrics.remoteBytesRead)
      executorSource.METRIC_SHUFFLE_REMOTE_BYTES_READ_TO_DISK
        .inc(task.metrics.shuffleReadMetrics.remoteBytesReadToDisk)
      executorSource.METRIC_SHUFFLE_LOCAL_BYTES_READ
        .inc(task.metrics.shuffleReadMetrics.localBytesRead)
      executorSource.METRIC_SHUFFLE_RECORDS_READ
        .inc(task.metrics.shuffleReadMetrics.recordsRead)
      executorSource.METRIC_SHUFFLE_REMOTE_BLOCKS_FETCHED
        .inc(task.metrics.shuffleReadMetrics.remoteBlocksFetched)
      executorSource.METRIC_SHUFFLE_LOCAL_BLOCKS_FETCHED
        .inc(task.metrics.shuffleReadMetrics.localBlocksFetched)
      executorSource.METRIC_SHUFFLE_BYTES_WRITTEN
        .inc(task.metrics.shuffleWriteMetrics.bytesWritten)
      executorSource.METRIC_SHUFFLE_RECORDS_WRITTEN
        .inc(task.metrics.shuffleWriteMetrics.recordsWritten)
      executorSource.METRIC_INPUT_BYTES_READ
        .inc(task.metrics.inputMetrics.bytesRead)
      executorSource.METRIC_INPUT_RECORDS_READ
        .inc(task.metrics.inputMetrics.recordsRead)
      executorSource.METRIC_OUTPUT_BYTES_WRITTEN
        .inc(task.metrics.outputMetrics.bytesWritten)
      executorSource.METRIC_OUTPUT_RECORDS_WRITTEN
        .inc(task.metrics.inputMetrics.recordsRead)
      executorSource.METRIC_RESULT_SIZE.inc(task.metrics.resultSize)
      executorSource.METRIC_DISK_BYTES_SPILLED.inc(task.metrics.diskBytesSpilled)
      executorSource.METRIC_MEMORY_BYTES_SPILLED.inc(task.metrics.memoryBytesSpilled)
    
      // Note: accumulator updates must be collected after TaskMetrics is updated
      // collect the accumulator updates here
      val accumUpdates = task.collectAccumulatorUpdates()
      // TODO: do not serialize value twice
      val directResult = new DirectTaskResult(valueBytes, accumUpdates)
      val serializedDirectResult = ser.serialize(directResult)
      val resultSize = serializedDirectResult.limit()
    
      // directSend = sending directly back to the driver
      val serializedResult: ByteBuffer = {
        if (maxResultSize > 0 && resultSize > maxResultSize) {
          logWarning(s"Finished $taskName (TID $taskId). Result is larger than maxResultSize " +
            s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " +
            s"dropping it.")
          ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), resultSize))
        } else if (resultSize > maxDirectResultSize) {
          val blockId = TaskResultBlockId(taskId)
          env.blockManager.putBytes(
            blockId,
            new ChunkedByteBuffer(serializedDirectResult.duplicate()),
            StorageLevel.MEMORY_AND_DISK_SER)
          logInfo(
            s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)")
          ser.serialize(new IndirectTaskResult[Any](blockId, resultSize))
        } else {
          logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes result sent to driver")
          serializedDirectResult
        }
      }
    
      setTaskFinishedAndClearInterruptStatus()
      execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
    
    } catch {
      case t: Throwable if hasFetchFailure && !Utils.isFatalError(t) =>
        val reason = task.context.fetchFailed.get.toTaskFailedReason
        if (!t.isInstanceOf[FetchFailedException]) {
          // there was a fetch failure in the task, but some user code wrapped that exception
          // and threw something else.  Regardless, we treat it as a fetch failure.
          val fetchFailedCls = classOf[FetchFailedException].getName
          logWarning(s"TID ${taskId} encountered a ${fetchFailedCls} and " +
            s"failed, but the ${fetchFailedCls} was hidden by another " +
            s"exception.  Spark is handling this like a fetch failure and ignoring the " +
            s"other exception: $t")
        }
        setTaskFinishedAndClearInterruptStatus()
        execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
    
      case t: TaskKilledException =>
        logInfo(s"Executor killed $taskName (TID $taskId), reason: ${t.reason}")
        setTaskFinishedAndClearInterruptStatus()
        execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled(t.reason)))
    
      case _: InterruptedException | NonFatal(_) if
          task != null && task.reasonIfKilled.isDefined =>
        val killReason = task.reasonIfKilled.getOrElse("unknown reason")
        logInfo(s"Executor interrupted and killed $taskName (TID $taskId), reason: $killReason")
        setTaskFinishedAndClearInterruptStatus()
        execBackend.statusUpdate(
          taskId, TaskState.KILLED, ser.serialize(TaskKilled(killReason)))
    
      case CausedBy(cDE: CommitDeniedException) =>
        val reason = cDE.toTaskCommitDeniedReason
        setTaskFinishedAndClearInterruptStatus()
        execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(reason))
    
      case t: Throwable =>
        // Attempt to exit cleanly by informing the driver of our failure.
        // If anything goes wrong (or this was a fatal exception), we will delegate to
        // the default uncaught exception handler, which will terminate the Executor.
        logError(s"Exception in $taskName (TID $taskId)", t)
    
        // SPARK-20904: Do not report failure to driver if if happened during shut down. Because
        // libraries may set up shutdown hooks that race with running tasks during shutdown,
        // spurious failures may occur and can result in improper accounting in the driver (e.g.
        // the task failure would not be ignored if the shutdown happened because of premption,
        // instead of an app issue).
        if (!ShutdownHookManager.inShutdown()) {
          // Collect latest accumulator values to report back to the driver
          val accums: Seq[AccumulatorV2[_, _]] =
            if (task != null) {
              task.metrics.setExecutorRunTime(System.currentTimeMillis() - taskStart)
              task.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime)
              task.collectAccumulatorUpdates(taskFailed = true)
            } else {
              Seq.empty
            }
    
          val accUpdates = accums.map(acc => acc.toInfo(Some(acc.value), None))
    
          val serializedTaskEndReason = {
            try {
              ser.serialize(new ExceptionFailure(t, accUpdates).withAccums(accums))
            } catch {
              case _: NotSerializableException =>
                // t is not serializable so just send the stacktrace
                ser.serialize(new ExceptionFailure(t, accUpdates, false).withAccums(accums))
            }
          }
          setTaskFinishedAndClearInterruptStatus()
          execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskEndReason)
        } else {
          logInfo("Not reporting error to driver during JVM shutdown.")
        }
    
        // Don't forcibly exit unless the exception was inherently fatal, to avoid
        // stopping other tasks unnecessarily.
        if (!t.isInstanceOf[SparkOutOfMemoryError] && Utils.isFatalError(t)) {
          uncaughtExceptionHandler.uncaughtException(Thread.currentThread(), t)
        }
    } finally {
      runningTasks.remove(taskId)
    }

    }

Task.run

final def run(
  taskAttemptId: Long,
  attemptNumber: Int,
  metricsSystem: MetricsSystem): T = {
SparkEnv.get.blockManager.registerTask(taskAttemptId)
context = new TaskContextImpl(
  stageId,
  stageAttemptId, // stageAttemptId and stageAttemptNumber are semantically equal
  partitionId,
  taskAttemptId,
  attemptNumber,
  taskMemoryManager,
  localProperties,
  // the metrics system is simply the one held by SparkEnv
  metricsSystem,
  metrics)
TaskContext.setTaskContext(context)
// record the thread that runs this task
taskThread = Thread.currentThread()

// mainly updates the kill-reason flag in the TaskContext;
// interruptThread = false here, so the thread is not actually interrupted
if (_reasonIfKilled != null) {
  kill(interruptThread = false, _reasonIfKilled)
}

new CallerContext(
  "TASK",
  SparkEnv.get.conf.get(APP_CALLER_CONTEXT),
  appId,
  appAttemptId,
  jobId,
  Option(stageId),
  Option(stageAttemptId),
  Option(taskAttemptId),
  Option(attemptNumber)).setCurrentContext()

try {
  runTask(context)
} catch {
  case e: Throwable =>
    // Catch all errors; run task failure callbacks, and rethrow the exception.
    try {
      context.markTaskFailed(e)
    } catch {
      case t: Throwable =>
        e.addSuppressed(t)
    }
    context.markTaskCompleted(Some(e))
    throw e
} finally {
  try {
    // Call the task completion callbacks. If "markTaskCompleted" is called twice, the second
    // one is no-op.
    context.markTaskCompleted(None)
  } finally {
    try {
      Utils.tryLogNonFatalError {
        // Release memory used by this thread for unrolling blocks
        // release the unroll memory used by this task in the memory store; this ultimately goes
        // through the memory manager, which really just updates its internal bookkeeping of
        // memory usage; the actual reclamation is of course done by GC
        SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP)
        SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(
          MemoryMode.OFF_HEAP)
        // Notify any tasks waiting for execution memory to be freed to wake up and try to
        // acquire memory again. This makes impossible the scenario where a task sleeps forever
        // because there are no other tasks left to notify it. Since this is safe to do but may
        // not be strictly necessary, we should revisit whether we can remove this in the
        // future.
        val memoryManager = SparkEnv.get.memoryManager
        // after the memory is released, notify other threads that are waiting for memory
        memoryManager.synchronized { memoryManager.notifyAll() }
      }
    } finally {
      // Though we unset the ThreadLocal here, the context member variable itself is still
      // queried directly in the TaskRunner to check for FetchFailedExceptions.
      TaskContext.unset()
    }
  }
}
}
  • Create a TaskContextImpl and set it into a ThreadLocal variable
  • Check whether the task has been killed
  • Call runTask to execute the actual task logic
  • Finally, release the memory that was acquired during shuffle for unrolling data

So what we need to analyze next is obviously the runTask method, which is abstract here. Since ResultTask is very simple I won't walk through it (a paraphrased sketch is given right below); instead I'll focus on ShuffleMapTask.
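Since ResultTask is skipped, here is a paraphrased sketch of what its runTask boils down to, not the exact Spark source: deserialize the (rdd, user function) pair from the broadcast task binary and apply the function to the iterator of this task's partition. All timing and metrics bookkeeping is omitted, and the parameter names are made up.

import java.nio.ByteBuffer

import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.rdd.RDD

// Paraphrased sketch of ResultTask.runTask's core logic (bookkeeping omitted).
// `taskBinaryBytes` stands in for the broadcast task binary; names are illustrative.
def resultTaskSketch[T, U](taskBinaryBytes: Array[Byte],
                           partitionIndex: Int,
                           context: TaskContext): U = {
  val ser = SparkEnv.get.closureSerializer.newInstance()
  // the task binary contains the serialized (rdd, user function) pair
  val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)](
    ByteBuffer.wrap(taskBinaryBytes), Thread.currentThread.getContextClassLoader)
  // computing the partition's iterator and applying the user function is the whole job
  func(context, rdd.iterator(rdd.partitions(partitionIndex), context))
}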

ShuffleMapTask.runTask

override def runTask(context: TaskContext): MapStatus = {
// Deserialize the RDD using the broadcast variable.
val threadMXBean = ManagementFactory.getThreadMXBean
val deserializeStartTime = System.currentTimeMillis()
val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
  threadMXBean.getCurrentThreadCpuTime
} else 0L
val ser = SparkEnv.get.closureSerializer.newInstance()
// deserialize the RDD and the shuffle dependency, which is the key step
// (worth thinking about how the SparkContext reference inside them is handled during deserialization)
val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
  ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
_executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime
_executorDeserializeCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
  threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime
} else 0L

var writer: ShuffleWriter[Any, Any] = null
try {
  // shuffle manager
  val manager = SparkEnv.get.shuffleManager
  // get a shuffle writer
  writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
  // here we can see that the core method for computing an rdd is its iterator method
  // SortShuffleWriter.write can be broken into a few steps:
  // write the data computed by the upstream rdd (obtained by calling rdd.iterator) into an
  // in-memory buffer; if a memory threshold is exceeded while writing, spill to disk,
  // possibly producing several spill files;
  // finally merge-sort the spill files together with the remaining in-memory data and write
  // them to disk as one big data file, sorted first by partition and then by key;
  // during this final merged write the output is flushed once per partition and the offset of each
  // partition's data in the file is recorded, so when a task finishes writing there are two files
  // on disk: the data file and an index file recording the offset of each reduce partition's data
  writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
  // mainly deletes the intermediate spill files and releases the acquired memory back to the memory manager
  writer.stop(success = true).get
} catch {
  case e: Exception =>
    try {
      if (writer != null) {
        writer.stop(success = false)
      }
    } catch {
      case e: Exception =>
        log.debug("Could not stop writer", e)
    }
    throw e
}
}

The overall logic of this method is fairly simple: obtain the computed result of this task's partition through the rdd's iterator method (the result comes back as an iterator), write it through the shuffleManager and blockManager into file blocks, and then send the block information back to the driver, where it is reported to BlockManagerMaster.
So there are really two important steps: computing the result through the RDD's compute chain, and writing the sorted and partitioned result to a file.
Here I will analyze the second step first; the gist of the first step, rdd.iterator, is sketched right below.
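The gist of rdd.iterator, paraphrased rather than quoted from the source, is a two-way decision: if the RDD has a storage level set, ask the block manager for a cached copy (computing and caching the partition on a miss); otherwise compute the partition, reading from a checkpoint instead of recomputing the lineage when one exists. A conceptual sketch only, with the two private helpers passed in as hypothetical function parameters:

import org.apache.spark.{Partition, TaskContext}
import org.apache.spark.storage.StorageLevel

// Conceptual sketch only: getOrCompute / computeOrReadCheckpoint stand in for RDD's
// private helpers of the same names; only the shape of the decision is shown.
def iteratorSketch[T](
    storageLevel: StorageLevel,
    split: Partition,
    context: TaskContext,
    getOrCompute: (Partition, TaskContext) => Iterator[T],
    computeOrReadCheckpoint: (Partition, TaskContext) => Iterator[T]): Iterator[T] = {
  if (storageLevel != StorageLevel.NONE) {
    // cached (or to-be-cached) RDD: go through the block manager first
    getOrCompute(split, context)
  } else {
    // uncached RDD: compute, or read from a checkpoint if one exists
    computeOrReadCheckpoint(split, context)
  }
}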

SortShuffleWriter.write

Since Spark 2.0 the shuffle manager is the sort-based SortShuffleManager, so the writer obtained from it here is, in the general case, a SortShuffleWriter; when the bypass conditions are met (no map-side combine and the number of partitions does not exceed the bypass threshold, 200 by default) a BypassMergeSortShuffleWriter is used instead, as sketched below.
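A paraphrased sketch of that bypass decision, not the exact Spark source; the 200 default corresponds to spark.shuffle.sort.bypassMergeThreshold:

// Paraphrased sketch of the writer-selection predicate: the sort path is bypassed only
// when there is no map-side combine and the number of reduce partitions does not exceed
// the bypass threshold (spark.shuffle.sort.bypassMergeThreshold, 200 by default).
def shouldBypassMergeSort(
    mapSideCombine: Boolean,
    numReducePartitions: Int,
    bypassMergeThreshold: Int = 200): Boolean = {
  !mapSideCombine && numReducePartitions <= bypassMergeThreshold
}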

override def write(records: Iterator[Product2[K, V]]): Unit = {
sorter = if (dep.mapSideCombine) {
  // map-side combine: in this case the user is expected to provide an aggregator and an ordering
  require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
  new ExternalSorter[K, V, C](
    context, dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
} else {
  // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't
  // care whether the keys get sorted in each partition; that will be done on the reduce side
  // if the operation being run is sortByKey.
  new ExternalSorter[K, V, V](
    context, aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer)
}
// feed all the map output into the sorter;
// this may generate multiple spill files along the way
sorter.insertAll(records)

// Don't bother including the time to open the merged output file in the shuffle write time,
// because it just opens a single file, so is typically too fast to measure accurately
// (see SPARK-3570).
// mapId is simply the partitionId of the map-side RDD
// get the name of the shuffle output file for this map partition
val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
// append a uuid suffix to obtain a temporary file
val tmp = Utils.tempFileWith(output)
try {
  val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
  // merge-sort the data already spilled to disk with the data still in memory,
  // and write it all into a single file, initially under a temporary name
  val partitionLengths = sorter.writePartitionedFile(blockId, tmp)
  // write the index file, then atomically rename (move) the temporary index and data files to their final names
  shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp)
  // return a status object containing the shuffle server id and the offset of each partition in the file
  mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
} finally {
  if (tmp.exists() && !tmp.delete()) {
    logError(s"Error while deleting temp file ${tmp.getAbsolutePath}")
  }
}
}

To summarize the main logic of this method:

  • First obtain a sorter, checking whether a map-side combiner is needed
  • Feed the rdd's computed output into the sorter, which may spill several files to disk along the way
  • Finally merge-sort the small spill files together with the data still in the in-memory buffer and write everything into a single file
  • Write the offset of each partition's data within that file into an index file, which the reduce stage uses when fetching data (see the offset sketch after this list)
  • Return a MapStatus object wrapping the id of this executor's blockManager and the offset of each partition in the data file
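To make the index-file idea concrete, here is a tiny hypothetical helper (not Spark code) showing how per-partition offsets follow from the partition lengths returned by writePartitionedFile: the index is essentially a cumulative sum, so reducer i's bytes live in [offsets(i), offsets(i + 1)) of the data file.

// Hypothetical helper, not Spark code: cumulative offsets from per-partition lengths.
// writeIndexFileAndCommit persists essentially this information next to the data file.
def partitionOffsets(partitionLengths: Array[Long]): Array[Long] =
  partitionLengths.scanLeft(0L)(_ + _)

// Example: lengths Array(10, 0, 25) => offsets Array(0, 10, 10, 35);
// reducer 2 reads bytes [10, 35) from the data file.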

Summary

That is as far as this post goes. The remaining code is the sorter's internal logic for sorting data and spilling files, which deserves an article of its own.
To recap the flow of a task on the executor side:

  • First, the executor-side RPC endpoint receives the LaunchTask message and deserializes the shipped task data into a TaskDescription
  • The task is handed to the Executor object to run
  • Executor creates a TaskRunner from the TaskDescription and submits it to the thread pool; the pool is a cached thread pool, so no worker threads linger when the executor is idle
  • TaskRunner deserializes the task further and calls Task.run to execute the task logic
  • A ShuffleMapTask writes the rdd's computed output, after sorting and merging, into a single data file together with an index file
  • After the task finishes, task statistics and counters in the metrics system are updated
  • Finally, depending on the size of the serialized result, one of several ways is chosen to send the result back to the driver.