Spark source code reading -- shuffle read process source code analysis

Source code analysis of the shuffle read process

In the previous post we analyzed the shuffle write path in the map stage. A quick recap: the data produced by a ShuffleMapTask is sorted in memory by partition and key; because of memory limits, the data may be spilled to several disk files along the way; at the end, all spill files plus the remaining in-memory data are merge-sorted and written into a single file, while the offset of each (reduce-side) partition's data inside that file is recorded, and the partition-to-offset mapping is written to an index file.
With the write path recapped, two questions naturally arise: what exactly does the reduce-stage data read look like, and when does the read happen?

Let's answer the second question first: when does the read happen? We know that the RDD computation chain is split into stages at shuffle boundaries. A stage usually starts by reading the output of the previous stage, and that read is exactly the reduce step; the data then flows through the stage's computation chain, and the results are written to disk for the next stage to read. That write is the map-output step we analyzed previously. In this post we analyze the read path of the reduce stage.

The point of that long-winded setup is simply to locate the entry point of the data read, which brings us back to ShuffleMapTask. Only part of the code is shown here:

// The shuffle manager
  val manager = SparkEnv.get.shuffleManager
  // Get a shuffle writer
  writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
  // Note that the core of RDD computation is the iterator method.
  // SortShuffleWriter.write can be broken down into a few steps:
  // write the data produced by the upstream RDD (obtained via rdd.iterator) into an in-memory buffer;
  // while writing, spill to disk whenever the memory threshold is exceeded, possibly producing several files;
  // finally, merge-sort the spill files together with the remaining in-memory data and write them into one large data file,
  // sorting by partition first and then by key.
  // During this final merged write, the output is flushed once per partition written, and the offset of that
  // partition's data inside the file is recorded.
  // So once a task's data is fully written, there are two files on disk: the data file and an index file
  // recording the offset of each reduce-side partition's data.
  writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
  // Mainly deletes the intermediate spill files and releases the memory acquired from the memory manager
  writer.stop(success = true).get

The data-reading code is simply rdd.iterator(partition, context).
The iterator method mainly deals with RDD caching: if a cached copy exists it is read back (through the BlockManager); otherwise the actual computation is performed, which ultimately ends up in RDD.compute. That method is abstract and each subclass supplies the concrete computation logic, so the transformations applied to an RDD in user code are ultimately reflected in its compute method.
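To make that dispatch concrete, here is a tiny self-contained Scala sketch (not Spark code; the class, cache and method names here are invented) of the pattern: iterator consults a cache first and only falls back to the subclass's compute.

    import scala.collection.mutable

    // Toy stand-in for the RDD iterator/compute split: iterator handles caching,
    // compute is the subclass-specific calculation.
    abstract class ToyRdd[T] {
      private val cache = mutable.Map.empty[Int, Seq[T]]   // stand-in for BlockManager storage
      protected def compute(split: Int): Iterator[T]       // abstract, like RDD.compute
      final def iterator(split: Int): Iterator[T] =
        cache.get(split) match {
          case Some(rows) => rows.iterator                 // cached: read it back
          case None =>
            val rows = compute(split).toList               // not cached: actually compute
            cache(split) = rows
            rows.iterator
        }
    }

    object ToyRddDemo extends App {
      val doubled = new ToyRdd[Int] {
        protected def compute(split: Int): Iterator[Int] = Iterator(split, split * 2)
      }
      println(doubled.iterator(3).toList)  // computed on first access: List(3, 6)
      println(doubled.iterator(3).toList)  // served from the toy cache on the second access
    }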
On the other hand, operators such as map and filter are not shuffle operations and do not introduce stage boundaries, so to see the shuffle read path we need a shuffle-type operation. Let's look at RDD.groupBy, which ultimately calls groupByKey.
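For example, a minimal driver-side snippet (assuming an existing SparkContext named sc) that goes down this path:

    // groupBy goes through groupByKey/combineByKeyWithClassTag and yields a ShuffledRDD,
    // so computing `grouped` exercises exactly the shuffle read path analyzed below.
    val words = sc.parallelize(Seq("spark", "shuffle", "sort", "spill"), numSlices = 2)
    val grouped = words.groupBy(_.length)   // RDD[(Int, Iterable[String])], backed by a ShuffledRDD
    grouped.collect().foreach(println)      // e.g. (5,CompactBuffer(spark, spill)), (4,CompactBuffer(sort)), ...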

RDD.groupByKey

def groupByKey(partitioner: Partitioner): RDD[(K, Iterable[V])] = self.withScope {
// groupByKey shouldn't use map side combine because map side combine does not
// reduce the amount of data shuffled and requires all map side data be inserted
// into a hash table, leading to more objects in the old gen.
val createCombiner = (v: V) => CompactBuffer(v)
val mergeValue = (buf: CompactBuffer[V], v: V) => buf += v
val mergeCombiners = (c1: CompactBuffer[V], c2: CompactBuffer[V]) => c1 ++= c2
val bufs = combineByKeyWithClassTag[CompactBuffer[V]](
  createCombiner, mergeValue, mergeCombiners, partitioner, mapSideCombine = false)
bufs.asInstanceOf[RDD[(K, Iterable[V])]]
}

This in turn calls combineByKeyWithClassTag.

RDD.combineByKeyWithClassTag

It performs a few sanity checks, deals with the partitioner, and finally returns a ShuffledRDD, so next we analyze ShuffledRDD's compute method.

def combineByKeyWithClassTag[C](
  createCombiner: V => C,
  mergeValue: (C, V) => C,
  mergeCombiners: (C, C) => C,
  partitioner: Partitioner,
  mapSideCombine: Boolean = true,
  serializer: Serializer = null)(implicit ct: ClassTag[C]): RDD[(K, C)] = self.withScope {
require(mergeCombiners != null, "mergeCombiners must be defined") // required as of Spark 0.9.0
// If the key is an Array type, map-side combining is not supported,
// and neither is HashPartitioner
if (keyClass.isArray) {
  if (mapSideCombine) {
    throw new SparkException("Cannot use map-side combining with array keys.")
  }
  if (partitioner.isInstanceOf[HashPartitioner]) {
    throw new SparkException("HashPartitioner cannot partition array keys.")
  }
}
// The aggregator used to combine the data
val aggregator = new Aggregator[K, V, C](
  self.context.clean(createCombiner),
  self.context.clean(mergeValue),
  self.context.clean(mergeCombiners))
// If the partitioner is the same, no shuffle is needed
if (self.partitioner == Some(partitioner)) {
  self.mapPartitions(iter => {
    val context = TaskContext.get()
    new InterruptibleIterator(context, aggregator.combineValuesByKey(iter, context))
  }, preservesPartitioning = true)
} else {
  // Return a ShuffledRDD
  new ShuffledRDD[K, V, C](self, partitioner)
    .setSerializer(serializer)
    .setAggregator(aggregator)
    .setMapSideCombine(mapSideCombine)
}
}

ShuffledRDD.compute

It obtains a reader from the shuffleManager; the actual data-reading logic lives in the reader.

override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
// Get a reader from the shuffleManager
SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
  .read()
  .asInstanceOf[Iterator[(K, C)]]
}

SortShuffleManager.getReader

Not much to say here; go straight to BlockStoreShuffleReader.

override def getReader[K, C](
  handle: ShuffleHandle,
  startPartition: Int,
  endPartition: Int,
  context: TaskContext): ShuffleReader[K, C] = {
new BlockStoreShuffleReader(
  handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context)
}

BlockStoreShuffleReader.read

Clearly, this method is where the core logic lives. The main steps are:

  • Build a wrapping iterator, ShuffleBlockFetcherIterator, whose elements are (blockId, input stream for that block) pairs; this class is obviously the key to reading data in the reduce stage
  • Turn the raw streams into an iterator of deserialized records
  • Wrap the iterator so that it updates read metrics; this chain of wrappers is much like the stream decorators in Java
  • Wrap the iterator so that it can respond to interruption: every record read checks whether the task has been killed, so that kill requests (e.g. a kill message from the driver) are honored as promptly as possible
  • Use the aggregator to combine the results. This once again relies on AppendOnlyMap, the same structure used in the shuffle write stage: an open-addressing hash table backed by a single flat array (see the sketch right after this list)
  • Finally, sort the results if required
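Before diving into the class itself, here is a simplified, self-contained sketch of the idea behind AppendOnlyMap referenced above. It is not the real class (the real implementation's probing sequence, growth policy, null-key handling and spill integration all differ); it only shows the shape of the structure: keys and values stored in alternating slots of one flat array, with changeValue inserting or aggregating in a single pass, which is what the aggregator's combineValuesByKey/combineCombinersByKey rely on.

    // Simplified combine-by-key map: open addressing over a flat Array[Any],
    // slot 2*i holds a key and slot 2*i+1 holds its current combined value.
    class CombineMap[K, V](initialCapacity: Int = 64) {
      private var capacity = initialCapacity            // kept as a power of two
      private var data = new Array[Any](2 * capacity)
      private var curSize = 0

      /** Insert `key` or update its value in place using `update`. */
      def changeValue(key: K, update: Option[V] => V): V = {
        var pos = key.hashCode & (capacity - 1)
        while (true) {
          val curKey = data(2 * pos)
          if (curKey == null) {                         // empty slot: first time we see this key
            val newValue = update(None)
            data(2 * pos) = key
            data(2 * pos + 1) = newValue
            curSize += 1
            if (curSize > 0.7 * capacity) grow()
            return newValue
          } else if (curKey == key) {                   // existing key: aggregate in place
            val newValue = update(Some(data(2 * pos + 1).asInstanceOf[V]))
            data(2 * pos + 1) = newValue
            return newValue
          } else {
            pos = (pos + 1) & (capacity - 1)            // probe the next slot (plain linear probing here)
          }
        }
        throw new IllegalStateException("unreachable")
      }

      /** Double the table and re-insert every entry. */
      private def grow(): Unit = {
        val old = data
        capacity *= 2
        data = new Array[Any](2 * capacity)
        curSize = 0
        var i = 0
        while (i < old.length / 2) {
          if (old(2 * i) != null) {
            changeValue(old(2 * i).asInstanceOf[K], _ => old(2 * i + 1).asInstanceOf[V])
          }
          i += 1
        }
      }
    }

Aggregation then amounts to folding every (key, value) record through changeValue, using createCombiner for a first occurrence and mergeValue otherwise.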

So, clearly, the shuffle read logic we are after is hidden inside ShuffleBlockFetcherIterator.

private[spark] class BlockStoreShuffleReader[K, C](
        handle: BaseShuffleHandle[K, _, C],
        startPartition: Int,
        endPartition: Int,
        context: TaskContext,
        serializerManager: SerializerManager = SparkEnv.get.serializerManager,
        blockManager: BlockManager = SparkEnv.get.blockManager,
        mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker)
      extends ShuffleReader[K, C] with Logging {
    
      private val dep = handle.dependency
    
      /** Read the combined key-values for this reduce task */
      override def read(): Iterator[Product2[K, C]] = {
        // Get a wrapping iterator whose elements are (blockId, input stream for that block) pairs
        val wrappedStreams = new ShuffleBlockFetcherIterator(
          context,
          blockManager.shuffleClient,
          blockManager,
          mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
          serializerManager.wrapStream,
          // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
          SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
          SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue),
          SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
          SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
          SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true))
    
        val serializerInstance = dep.serializer.newInstance()
    
        // Create a key/value iterator for each stream
        // Turn the raw streams into an iterator of deserialized records
        val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) =>
          // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
          // NextIterator. The NextIterator makes sure that close() is called on the
          // underlying InputStream when all records have been read.
          serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
        }
    
        // Update the context task metrics for each record read.
        val readMetrics = context.taskMetrics.createTempShuffleReadMetrics()
        // Turn it into an iterator that updates read metrics; this chain of wrappers is much like the stream decorators in Java
        val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
          recordIter.map { record =>
            readMetrics.incRecordsRead(1)
            record
          },
          context.taskMetrics().mergeShuffleReadMetrics())
    
        // An interruptible iterator must be used here in order to support task cancellation
        // Check whether the task has been killed after every record read;
        // this is to respond as promptly as possible to kill requests, e.g. a kill message sent from the driver
        val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)
    
        val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
          // Use the aggregator to combine the results
          if (dep.mapSideCombine) {
            // We are reading values that are already combined
            val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
            dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
          } else {
            // We don't know the value type, but also don't care -- the dependency *should*
            // have made sure its compatible w/ this aggregator, which will convert the value
            // type to the combined type C
            val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
            dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
          }
        } else {
          require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
          interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
        }
    
        // Sort the output if there is a sort ordering defined.
        // Finally, sort the output
        dep.keyOrdering match {
          case Some(keyOrd: Ordering[K]) =>
            // Create an ExternalSorter to sort the data.
            val sorter =
              new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)
            sorter.insertAll(aggregatedIter)
            context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
            context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
            context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
            CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
          case None =>
            aggregatedIter
        }
      }
    }

ShuffleBlockFetcherIterator

This class is fairly involved. First, note that its initialization code calls the initialize method.
Second, pay attention to its constructor parameters:

val wrappedStreams = new ShuffleBlockFetcherIterator(
    context,
    // If the external shuffle service is not enabled, this is the BlockTransferService
    blockManager.shuffleClient,
    blockManager,
    // Get the physical locations of each partition's data blocks from the mapOutputTracker component
    mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
    serializerManager.wrapStream,
    // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
    // A few configuration parameters
    SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
    SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue),
    SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
    SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
    SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true))

ShuffleBlockFetcherIterator.initialize

  • First, split the blocks into local and remote ones
  • Then start sending requests to fetch the remote data. A few constraints limit these requests: the total amount of data being fetched, the number of requests in flight, and the number of blocks being fetched from any single remote address at the same time (although the latter defaults to Integer.MAX_VALUE)
  • Fetch the local block data

Fetching the local data is comparatively simple: it goes through the local BlockManager to get the block data and uses the index file to locate the requested partition's segment inside the data file (a small sketch of that lookup follows).
The remote fetch part is what we analyze in detail.
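As a side note on that index lookup, here is a minimal sketch, not Spark's IndexShuffleBlockResolver and with an invented helper name, assuming the layout described in the shuffle-write analysis: the index file is just a sequence of longs in which consecutive entries bound each reduce partition's bytes in the data file.

    import java.io.{DataInputStream, FileInputStream}

    object ShuffleIndexSketch {
      // Returns (offset, length) of reduce partition `reduceId` inside the map task's data file.
      def segmentForPartition(indexFilePath: String, reduceId: Int): (Long, Long) = {
        val in = new DataInputStream(new FileInputStream(indexFilePath))
        try {
          // Each index entry is one long (8 bytes); a robust reader should loop until
          // all requested bytes are really skipped, which is glossed over here.
          in.skipBytes(reduceId * 8)
          val start = in.readLong()   // where this partition's bytes begin
          val end   = in.readLong()   // where the next partition begins, i.e. this one's end
          (start, end - start)
        } finally {
          in.close()
        }
      }
    }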

private[this] def initialize(): Unit = {
// Add a task completion callback (called in both success case and failure case) to cleanup.
// Register a callback on the TaskContext to do some cleanup when the task completes
context.addTaskCompletionListener(_ => cleanup())

// Split local and remote blocks.
// Split the blocks into local and remote ones
val remoteRequests = splitLocalRemoteBlocks()
// Add the remote requests into our queue in a random order
fetchRequests ++= Utils.randomize(remoteRequests)
assert ((0 == reqsInFlight) == (0 == bytesInFlight),
  "expected reqsInFlight = 0 but found reqsInFlight = " + reqsInFlight +
  ", expected bytesInFlight = 0 but found bytesInFlight = " + bytesInFlight)

// Send out initial requests for blocks, up to our maxBytesInFlight
// Send the remote fetch requests,
// sending as many as possible,
// subject to a few constraints:
// global constraints -- the number of requests in flight and the total amount of data being fetched --
// and a per-remote-address constraint: the number of blocks being fetched from one address cannot exceed a threshold
fetchUpToMaxBytes()

// Record how many requests have been sent; some may still be left unsent
val numFetches = remoteRequests.size - fetchRequests.size
logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime))

// Get Local Blocks
// Fetch the local block data
fetchLocalBlocks()
logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime))
}

ShuffleBlockFetcherIterator.splitLocalRemoteBlocks

Let's first look at how the blocks are split into remote and local ones. To summarize this method:

  • The limit on the amount of data fetched concurrently is divided by 5 to get the per-request size limit. The point is to allow data to be fetched from up to 5 nodes at the same time: individual nodes may have unstable network conditions, so fetching from several nodes in parallel reduces the impact of network jitter, while the overall in-flight limit caps the local machine's network usage. For example, with the default maxBytesInFlight of 48m, each request targets at most roughly 9.6m.
  • Iterate over every node address (a BlockManagerId here)
  • If the address equals the local address, the corresponding blocks are local blocks
  • For remote blocks, the blocks of each node are split into multiple requests (FetchRequest) according to the per-request size limit, so that no single request fetches too much data

    private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
      // Make remote requests at most maxBytesInFlight / 5 in length; the reason to keep them
      // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
      // nodes, rather than blocking on reading output from one node.
      // The request size is capped at maxBytesInFlight / 5
      // so that data can be fetched in parallel from up to 5 nodes at the same time
      val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
      logDebug("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize
        + ", maxBlocksInFlightPerAddress: " + maxBlocksInFlightPerAddress)
    
      // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
      // at most maxBytesInFlight in order to limit the amount of data in flight.
      val remoteRequests = new ArrayBuffer[FetchRequest]
    
      // Tracks total number of blocks (including zero sized blocks)
      // Track the total number of blocks
      var totalBlocks = 0
      for ((address, blockInfos) <- blocksByAddress) {
        totalBlocks += blockInfos.size
        // If the address matches the local BlockManager, these are local blocks
        if (address.executorId == blockManager.blockManagerId.executorId) {
          // Filter out zero-sized blocks
          localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1)
          numBlocksToFetch += localBlocks.size
        } else {
          val iterator = blockInfos.iterator
          var curRequestSize = 0L
          var curBlocks = new ArrayBuffer[(BlockId, Long)]
          while (iterator.hasNext) {
            val (blockId, size) = iterator.next()
            // Skip empty blocks
            if (size > 0) {
              curBlocks += ((blockId, size))
              remoteBlocks += blockId
              numBlocksToFetch += 1
              curRequestSize += size
            } else if (size < 0) {
              throw new BlockException(blockId, "Negative block size " + size)
            }
            // If the per-request size limit or the per-address block limit is reached, create a FetchRequest
            if (curRequestSize >= targetRequestSize ||
                curBlocks.size >= maxBlocksInFlightPerAddress) {
              // Add this FetchRequest
              remoteRequests += new FetchRequest(address, curBlocks)
              logDebug(s"Creating fetch request of $curRequestSize at $address "
                + s"with ${curBlocks.size} blocks")
              curBlocks = new ArrayBuffer[(BlockId, Long)]
              curRequestSize = 0
            }
          }
          // Add in the final request
          // Wrap up: create one final request for whatever blocks remain
          if (curBlocks.nonEmpty) {
            remoteRequests += new FetchRequest(address, curBlocks)
          }
        }
      }
      logInfo(s"Getting $numBlocksToFetch non-empty blocks out of $totalBlocks blocks")
      remoteRequests
      }

ShuffleBlockFetcherIterator.fetchUpToMaxBytes

Back in initialize: after splitting local and remote blocks we have a batch of fetch requests, which are added to a queue. The next step is to send them through the RPC client.

The logic of this method is fairly simple, essentially two loops: first send the requests in the deferred queue, then the regular ones. The deferred queue exists because some requests could not be sent the first time around, either because the amount of in-flight data or the number of in-flight requests had hit its threshold; such requests are parked in the deferred queue and are given priority when requests are sent again. Before a request is sent it must satisfy the following conditions:

  • The amount of data currently being fetched must not exceed the threshold maxBytesInFlight (48m by default). One may wonder what happens if a single block is larger than maxBytesInFlight: in that case the request is only sent once no other fetch is in progress, because the size check first tests bytesInFlight == 0 and skips the size comparison when that holds.
  • The number of requests currently in flight must not exceed the threshold maxReqsInFlight (Int.MaxValue by default)
  • The number of blocks being fetched simultaneously from each remote address is also limited (Int.MaxValue by default)
  • A request that passes these checks is sent. One more point: if a single request is larger than maxReqSizeShuffleToMem, the fetched data is written to a temporary file on disk; this threshold defaults to Long.MaxValue, so by default there is effectively no limit.

    // Send the fetch requests,
      // sending as many as possible,
      // subject to a few constraints:
      // global constraints -- the number of requests in flight and the total amount of data being fetched --
      // and a per-remote-address constraint: the number of blocks being fetched from one address cannot exceed a threshold
      private def fetchUpToMaxBytes(): Unit = {
      // Send fetch requests up to maxBytesInFlight. If you cannot fetch from a remote host
      // immediately, defer the request until the next time it can be processed.
    
      // Process any outstanding deferred fetch requests if possible.
      if (deferredFetchRequests.nonEmpty) {
        for ((remoteAddress, defReqQueue) <- deferredFetchRequests) {
          while (isRemoteBlockFetchable(defReqQueue) &&
              !isRemoteAddressMaxedOut(remoteAddress, defReqQueue.front)) {
            val request = defReqQueue.dequeue()
            logDebug(s"Processing deferred fetch request for $remoteAddress with "
              + s"${request.blocks.length} blocks")
            send(remoteAddress, request)
            if (defReqQueue.isEmpty) {
              deferredFetchRequests -= remoteAddress
            }
          }
        }
      }
    
      // Process any regular fetch requests if possible.
      while (isRemoteBlockFetchable(fetchRequests)) {
        val request = fetchRequests.dequeue()
        val remoteAddress = request.address
        // If this would exceed the per-address limit on blocks in flight, put the request on the deferred queue for later
        if (isRemoteAddressMaxedOut(remoteAddress, request)) {
          logDebug(s"Deferring fetch request for $remoteAddress with ${request.blocks.size} blocks")
          val defReqQueue = deferredFetchRequests.getOrElse(remoteAddress, new Queue[FetchRequest]())
          defReqQueue.enqueue(request)
          deferredFetchRequests(remoteAddress) = defReqQueue
        } else {
          send(remoteAddress, request)
        }
      }
    
      // Send one request and add its block count to the per-address tally,
      // which is checked on later requests to decide whether that address has too many blocks in flight
      def send(remoteAddress: BlockManagerId, request: FetchRequest): Unit = {
        sendRequest(request)
        numBlocksInFlightPerAddress(remoteAddress) =
          numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size
      }
    
      // This limit applies to all requests, regardless of which remote node they target:
      // check whether there is still headroom in the number of requests in flight
      // and in the number of bytes in flight,
      // mainly to bound concurrency and network bandwidth usage
      def isRemoteBlockFetchable(fetchReqQueue: Queue[FetchRequest]): Boolean = {
        fetchReqQueue.nonEmpty &&
          (bytesInFlight == 0 ||
            (reqsInFlight + 1 <= maxReqsInFlight &&
              bytesInFlight + fetchReqQueue.front.size <= maxBytesInFlight))
      }
    
      // Checks if sending a new fetch request will exceed the max no. of blocks being fetched from a
      // given remote address.
      // Check whether the number of blocks being fetched would exceed the threshold;
      // each address has its own limit on the number of blocks fetched simultaneously
      def isRemoteAddressMaxedOut(remoteAddress: BlockManagerId, request: FetchRequest): Boolean = {
        numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size >
          maxBlocksInFlightPerAddress
      }
      }

ShuffleBlockFetcherIterator.next

From the analysis of the previous method we can see that the fetch requests created at initialization are not all sent immediately, and some are parked in the deferred queue because a threshold was exceeded. So when are the remaining requests sent? The answer lies in the next method. ShuffleBlockFetcherIterator is an iterator, so callers access its elements through next, and it is easy to guess that next contains logic for sending further fetch requests.
To summarize:

  • First take one fetched result from the results queue (a blocking queue: the caller blocks if no result is available yet; a minimal sketch of this pattern follows this list)
  • Check whether the result is a success or a failure; on failure an exception is thrown right away (the retry logic lives in the RPC client, not here)
  • For a successful result, first update some task metrics and internal bookkeeping, such as the amount of data in flight
  • Wrap the fetched byte buffer into a byte input stream
  • Wrap the stream once more with the externally supplied function, typically for decompression and decryption
  • If the stream is compressed or encrypted and the block is small, copy the stream once; this actually triggers the decompression/decryption, so that block corruption is surfaced as early as possible
  • The last key statement kicks off another round of fetch requests: after next has processed a result, some data has been fetched successfully, so the in-flight data volume and request count may have dropped, which makes room for new requests
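Before the source, here is a tiny self-contained sketch (not Spark code; the names are invented) of the producer/consumer pattern the first bullet refers to: fetch callbacks push results into a blocking queue, and next() simply blocks on take() until a result, success or failure, shows up.

    import java.util.concurrent.LinkedBlockingQueue

    object FetchQueueSketch {
      sealed trait FetchResult
      final case class Success(blockId: String, bytes: Array[Byte]) extends FetchResult
      final case class Failure(blockId: String, cause: Throwable) extends FetchResult

      def main(args: Array[String]): Unit = {
        val results = new LinkedBlockingQueue[FetchResult]()

        // Producer side: stands in for the async callbacks of the shuffle client threads.
        new Thread(() => {
          Thread.sleep(100)                                        // pretend network latency
          results.put(Success("shuffle_0_1_2", Array[Byte](1, 2, 3)))
        }).start()

        // Consumer side: stands in for next(), which blocks here until a result is available.
        results.take() match {
          case Success(id, bytes) => println(s"got $id with ${bytes.length} bytes")
          case Failure(id, e)     => throw e                       // retries already happened upstream
        }
      }
    }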

    override def next(): (BlockId, InputStream) = {
      if (!hasNext) {
        throw new NoSuchElementException
      }
    
      numBlocksProcessed += 1
    
      var result: FetchResult = null
      var input: InputStream = null
      // Take the next fetched result and try to decompress it to detect data corruption,
      // then fetch it one more time if it's corrupt, throw FailureFetchResult if the second fetch
      // is also corrupt, so the previous stage could be retried.
      // For local shuffle block, throw FailureFetchResult for the first IOException.
      while (result == null) {
        val startFetchWait = System.currentTimeMillis()
        result = results.take()
        val stopFetchWait = System.currentTimeMillis()
        shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait)
    
        result match {
          case r @ SuccessFetchResult(blockId, address, size, buf, isNetworkReqDone) =>
            if (address != blockManager.blockManagerId) {
              numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
              // Mainly update some metrics
              shuffleMetrics.incRemoteBytesRead(buf.size)
              if (buf.isInstanceOf[FileSegmentManagedBuffer]) {
                shuffleMetrics.incRemoteBytesReadToDisk(buf.size)
              }
              shuffleMetrics.incRemoteBlocksFetched(1)
            }
            bytesInFlight -= size
            if (isNetworkReqDone) {
              reqsInFlight -= 1
              logDebug("Number of requests in flight " + reqsInFlight)
            }
    
            // Wrap the byte buffer into a byte input stream
            val in = try {
              buf.createInputStream()
            } catch {
              // The exception could only be throwed by local shuffle block
              case e: IOException =>
                assert(buf.isInstanceOf[FileSegmentManagedBuffer])
                logError("Failed to create input stream from local block", e)
                buf.release()
                throwFetchFailedException(blockId, address, e)
            }
    
            // Wrap it once more with the externally supplied function, typically for decompression and decryption
            input = streamWrapper(blockId, in)
            // Only copy the stream if it's wrapped by compression or encryption, also the size of
            // block is small (the decompressed block is smaller than maxBytesInFlight)
            // If the stream is compressed or encrypted and the block is small, the stream needs to be copied once here
            if (detectCorrupt && !input.eq(in) && size < maxBytesInFlight / 3) {
              val originalInput = input
              val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate)
              try {
                // Decompress the whole block at once to detect any corruption, which could increase
                // the memory usage tne potential increase the chance of OOM.
                // TODO: manage the memory used here, and spill it into disk in case of OOM.
                Utils.copyStream(input, out)
                out.close()
                input = out.toChunkedByteBuffer.toInputStream(dispose = true)
              } catch {
                case e: IOException =>
                  buf.release()
                  if (buf.isInstanceOf[FileSegmentManagedBuffer]
                    || corruptedBlocks.contains(blockId)) {
                    throwFetchFailedException(blockId, address, e)
                  } else {
                    logWarning(s"got an corrupted block $blockId from $address, fetch again", e)
                    corruptedBlocks += blockId
                    fetchRequests += FetchRequest(address, Array((blockId, size)))
                    result = null
                  }
              } finally {
                // TODO: release the buf here to free memory earlier
                originalInput.close()
                in.close()
              }
            }
    
            // Fetch failed: throw an exception.
            // A fair question: block fetching surely has a retry mechanism, so why throw right away on failure here?
            // The answer is that retrying is not implemented here but in the RPC client that sends the fetch requests,
            // so by the time a failure arrives here the retries have already been exhausted, and throwing is the right thing to do
          case FailureFetchResult(blockId, address, e) =>
            throwFetchFailedException(blockId, address, e)
        }
    
        // Send fetch requests up to maxBytesInFlight
        // Send another round of fetch requests: since some data has now been fetched successfully,
        // the amount of data in flight has dropped, which makes room for new requests
        fetchUpToMaxBytes()
      }
    
      currentResult = result.asInstanceOf[SuccessFetchResult]
      (currentResult.blockId, new BufferReleasingInputStream(input, this))
      }

Summary

With that, we have roughly walked through the shuffle read path. Overall the main flow is not terribly complex, but there are many small details, which is why the analysis above is somewhat fragmented. Let's distill the main flow once more to get an overall picture:

  • First, in shuffle-type RDDs, the compute method obtains a block reader, BlockStoreShuffleReader, through the ShuffleManager
  • The actual reading happens in BlockStoreShuffleReader.read. A reduce-side partition generally depends on the output of all map-side partitions, so the data is usually spread across multiple executors (note: executors, uniquely identified by BlockManagerId; one physical node may run several executors), and each executor may hold several blocks. As noted in the shuffle-write analysis, each map task ends up producing one data file plus one index file, i.e. one block, and since one executor can run many map tasks it can hold many blocks
  • This method wraps the complexity of fetching remote data in a ShuffleBlockFetcherIterator, and ultimately exposes the fetched data as an iterator of streams
  • The block streams are decorated layer by layer: deserialization, task metrics (number of records read), per-record interruptibility
  • The data is aggregated
  • The aggregated data is sorted

From this we can also see that under the new shuffle machinery, i.e. SortShuffleManager, the RDD that user code gets after a shuffle holds sorted data (provided a key ordering is specified).
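As a closing illustration, a minimal driver-side sketch (assuming an existing SparkContext named sc) contrasting the two cases: groupByKey sets no keyOrdering on its ShuffleDependency, while sortByKey does, so only the latter goes through the ExternalSorter branch of BlockStoreShuffleReader.read.

    val pairs = sc.parallelize(Seq(("b", 2), ("a", 1), ("c", 3), ("a", 4)))

    // No keyOrdering: the reduce side only aggregates; keys are not guaranteed to be sorted.
    val grouped = pairs.groupByKey()

    // keyOrdering defined: each reduce partition iterates its keys in sorted order.
    val sorted = pairs.sortByKey()

    println(grouped.collect().toList)
    println(sorted.collect().toList)   // e.g. List((a,1), (a,4), (b,2), (c,3))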
