Spark Shuffle Core Component BlockStoreShuffleReader Internals: A Deep Dive (Spark in Commercial Production Series)

This blog series distills and shares cases drawn from real commercial environments, together with Spark source-code walkthroughs and practical guidance; please keep following it. Copyright notice: the Spark source-code analysis and commercial-practice content in this series belongs to the author (Qin Kaixin). Reposting is prohibited; you are welcome to learn from it.

1 Starting from ShuffleManager

I have used the following diagram several times already, so please bear with me; these posts all share a single theme, shuffle. The English comments in the source are already quite detailed, so here is just a brief overview:

  • At present there is only one implementation: SortShuffleManager.
  • SortShuffleManager relies on ShuffleWriter for its write path: through the contract that ShuffleWriter defines, a map task's intermediate results are persisted to disk in the agreed format.
  • ShuffleWriter has three concrete implementations used for the shuffle write path: UnsafeShuffleWriter, SortShuffleWriter, and BypassMergeSortShuffleWriter.
  • SortShuffleManager uses BlockStoreShuffleReader for the shuffle read path.
  • SortShuffleManager also relies on the ShuffleHandle case classes, whose main job is to carry shuffle information to the tasks: one concern is which serializer to use, the other is deciding when the merge-and-sort shuffle path can be bypassed (a hedged sketch of that selection logic follows this list).
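
The selection logic hinted at in the last bullet lives in SortShuffleManager.registerShuffle. The following is a simplified sketch, paraphrased from memory of the Spark 2.x source with the comments condensed, so treat it as an illustration rather than the exact code:

    // Simplified sketch of SortShuffleManager.registerShuffle (paraphrased, not verbatim).
    override def registerShuffle[K, V, C](
        shuffleId: Int,
        numMaps: Int,
        dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
      if (SortShuffleWriter.shouldBypassMergeSort(conf, dependency)) {
        // No map-side combine and few reduce partitions: skip merge-sort entirely and just
        // write one file per partition, concatenating them at the end.
        new BypassMergeSortShuffleHandle[K, V](
          shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]])
      } else if (SortShuffleManager.canUseSerializedShuffle(dependency)) {
        // Serialized (tungsten-sort) path: buffer records in serialized form and sort them
        // by partition id only.
        new SerializedShuffleHandle[K, V](
          shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]])
      } else {
        // General path: buffer records in deserialized form so they can be combined and sorted.
        new BaseShuffleHandle(shuffleId, numMaps, dependency)
      }
    }

getWriter then pattern-matches on the handle type: a SerializedShuffleHandle yields an UnsafeShuffleWriter, a BypassMergeSortShuffleHandle yields a BypassMergeSortShuffleWriter, and a plain BaseShuffleHandle yields a SortShuffleWriter; getReader always returns a BlockStoreShuffleReader.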

The official English description reads as follows:

    * Pluggable interface for shuffle systems. A ShuffleManager is created in SparkEnv on the
    * driver and on each executor, based on the spark.shuffle.manager setting. The driver
    * registers shuffles with it, and executors (or tasks running locally in the driver) can ask
    * to read and write data.
    *
    * NOTE: this will be instantiated by SparkEnv so its constructor can take a SparkConf and
    * boolean isDriver as parameters.

2 The Mighty BlockStoreShuffleReader (Just One read Method)

So what is BlockStoreShuffleReader, the very heart of a reduce task, actually for? It reads, from the final consolidated block files written by the map tasks, the data falling within the range given by the start and end partitions. Before we begin, let's focus on its member fields, starting with a bit of English from the source:

* Fetches and reads the partitions in range [startPartition, endPartition) from a shuffle by
 * requesting them from other nodes' block stores.
  • So, to use it, what does the constructor need to encapsulate?

    private[spark] class BlockStoreShuffleReader[K, C](
          handle: BaseShuffleHandle[K, _, C],
          startPartition: Int,
          endPartition: Int,
          context: TaskContext,
          serializerManager: SerializerManager = SparkEnv.get.serializerManager,
          blockManager: BlockManager = SparkEnv.get.blockManager,
          mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker)
        extends ShuffleReader[K, C] with Logging 
  • dep: handle.dependency carried in by the BaseShuffleHandle case class, i.e. the ShuffleDependency

  • the read() method

  • mapOutputTracker: the MapOutputTracker sub-component of SparkEnv (a hedged sketch of how getReader wires all this together follows this list)
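
To see where these constructor arguments actually come from, here is SortShuffleManager.getReader as I recall it from the Spark 2.x source (reproduced from memory, so read it as a sketch); it is the only place a BlockStoreShuffleReader gets created, and the SparkEnv-backed default parameters fill in serializerManager, blockManager and mapOutputTracker:

    // Sketch of SortShuffleManager.getReader (from memory, not verbatim).
    override def getReader[K, C](
        handle: ShuffleHandle,
        startPartition: Int,
        endPartition: Int,
        context: TaskContext): ShuffleReader[K, C] = {
      // The reduce task supplies the handle, its partition range and its TaskContext;
      // everything else defaults to the components held by SparkEnv.
      new BlockStoreShuffleReader(
        handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context)
    }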

3 The Core Ideas of the read Method

  • ShuffleBlockFetcherIterator: an iterator for fetching multiple blocks. Put plainly, it takes the Seq[(BlockManagerId, Seq[(BlockId, Long)])] returned by mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition), i.e. the block addresses and sizes, and from that address sequence produces an iterator over all of the requested data.

  • getMapSizesByExecutorId:

    * Called from executors to get the server URIs and output sizes for each shuffle block that
     * needs to be read from a given range of map output partitions (startPartition is included but
     * endPartition is excluded from the range).
  • If dep.mapSideCombine is set, the already-combined records are aggregated again in an ExternalAppendOnlyMap; note that this is not a plain AppendOnlyMap. ExternalAppendOnlyMap builds on SizeTrackingAppendOnlyMap and can spill to disk, so at this stage it only aggregates and buffers; it does not by itself hand out a key-sorted output iterator. A simplified sketch of the two aggregation paths follows the read() listing below.

  • wrappedStreams represents the fetched block data of the map tasks. The full read() method:

    override def read(): Iterator[Product2[K, C]] = {
    val wrappedStreams = new ShuffleBlockFetcherIterator(
        context,
        blockManager.shuffleClient,
        blockManager,
        mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
        serializerManager.wrapStream,
        // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
        SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
        SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue),
        SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
        SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
        SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true))
    
      val serializerInstance = dep.serializer.newInstance()
    
      // Create a key/value iterator for each stream
      val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) =>
        // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
        // NextIterator. The NextIterator makes sure that close() is called on the
        // underlying InputStream when all records have been read.
        serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
      }
    
      // Update the context task metrics for each record read.
      val readMetrics = context.taskMetrics.createTempShuffleReadMetrics()
      val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
        recordIter.map { record =>
          readMetrics.incRecordsRead(1)
          record
        },
        context.taskMetrics().mergeShuffleReadMetrics())
    
      // An interruptible iterator must be used here in order to support task cancellation
      val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)
    
    
    
      ----------------------------------- The key step (aggregate or buffer) -----------------------------------
      val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
        if (dep.mapSideCombine) {
          // We are reading values that are already combined
          val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
          dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
        } else {
          // We don't know the value type, but also don't care -- the dependency *should*
          // have made sure its compatible w/ this aggregator, which will convert the value
          // type to the combined type C
          val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
          dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
        }
    --------------------------------------------------------------------------------------------   
        
      } else {
        require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
        interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
      }
    
      // Sort the output if there is a sort ordering defined.
      dep.keyOrdering match {
        case Some(keyOrd: Ordering[K]) =>
          // Create an ExternalSorter to sort the data.
          val sorter =
            new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)
      
      
      =========================== If a global key ordering is required, call sorter.insertAll ===========================
        sorter.insertAll(aggregatedIter)  <= no aggregator is passed to this ExternalSorter's constructor, so a PartitionedPairBuffer is used as the buffer
      ========================================================================================
      
      
      
          context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
          context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
          context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
          
          
          
         ==================== If a global key ordering is required, hand back the sorted iterator =========================
          CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())         <= the key step: sorter.iterator yields the fully sorted iterator
         =========================================================================================
    
          
        case None =>
         ==================== If no global key ordering is required, hand back the aggregated/buffered iterator =====================
          aggregatedIter
         =========================================================================================
          
      }
    }
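
For reference, the two aggregation branches above call into Spark's Aggregator, and both paths funnel records into an ExternalAppendOnlyMap that combines in memory and spills to disk when needed. Below is a condensed sketch written from memory (metric bookkeeping and the TaskContext parameter are trimmed), so treat it as illustrative rather than verbatim source:

    // Condensed sketch of org.apache.spark.Aggregator (simplified, not verbatim).
    case class Aggregator[K, V, C](
        createCombiner: V => C,        // build a combiner from the first value seen for a key
        mergeValue: (C, V) => C,       // fold one more value into an existing combiner
        mergeCombiners: (C, C) => C) { // merge two combiners (across map outputs or spills)

      // dep.mapSideCombine == true: incoming values are already combiners of type C.
      def combineCombinersByKey(iter: Iterator[_ <: Product2[K, C]]): Iterator[(K, C)] = {
        val combiners = new ExternalAppendOnlyMap[K, C, C](identity, mergeCombiners, mergeCombiners)
        combiners.insertAll(iter)
        combiners.iterator
      }

      // dep.mapSideCombine == false: incoming values are raw values of type V.
      def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]]): Iterator[(K, C)] = {
        val combiners = new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners)
        combiners.insertAll(iter)
        combiners.iterator
      }
    }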

4 A Deep Dive into ExternalSorter

  • Constructor parameters (note aggregator and ordering):

    private[spark] class ExternalSorter[K, V, C](
      context: TaskContext,
      aggregator: Option[Aggregator[K, V, C]] = None,
      partitioner: Option[Partitioner] = None,
      ordering: Option[Ordering[K]] = None,
      serializer: Serializer = SparkEnv.get.serializer)
    extends Spillable[WritablePartitionedPairCollection[K, C]](context.taskMemoryManager())
    with Logging 
  • The insertAll method: if aggregator.isDefined is false, a PartitionedPairBuffer is used as the buffer. Note also that insertion at most aggregates; it never sorts, which is why inserts are fast. If a global key ordering is required, the sort only happens when iterator is called to read the data back, at which point the sorted iterator is produced internally by groupByPartition(destructiveIterator(collection.partitionedDestructiveSortedIterator(Some(keyComparator)))).

Note: collection here is in fact either a PartitionedAppendOnlyMap or a PartitionedPairBuffer. A tiny standalone illustration of the combine-on-insert pattern follows the listing below.

def insertAll(records: Iterator[Product2[K, V]]): Unit = {
        // TODO: stop combining if we find that the reduction factor isn't high
        val shouldCombine = aggregator.isDefined
    
        if (shouldCombine) {
          // Combine values in-memory first using our AppendOnlyMap
          val mergeValue = aggregator.get.mergeValue
          val createCombiner = aggregator.get.createCombiner
          var kv: Product2[K, V] = null
          val update = (hadValue: Boolean, oldValue: C) => {
            if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
          }
          while (records.hasNext) {
            addElementsRead()
            kv = records.next()
            map.changeValue((getPartition(kv._1), kv._1), update)
            maybeSpillCollection(usingMap = true)
          }
        } else {
          // Stick values into our buffer
          while (records.hasNext) {
            addElementsRead()
            val kv = records.next()
            buffer.insert(getPartition(kv._1), kv._1, kv._2.asInstanceOf[C])
            maybeSpillCollection(usingMap = false)
          }
        }
      }
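
To make the update closure in insertAll concrete, here is a tiny self-contained illustration of the same createCombiner/mergeValue pattern. It deliberately uses a plain mutable.HashMap instead of Spark's PartitionedAppendOnlyMap, and every name and value in it is made up for the demo, but the combine-on-insert idea is the same:

    import scala.collection.mutable

    object CombineOnInsertDemo {
      def main(args: Array[String]): Unit = {
        val records = Iterator(("a", 1), ("b", 2), ("a", 3))
        val createCombiner: Int => List[Int] = v => List(v)              // first value seen for a key
        val mergeValue: (List[Int], Int) => List[Int] = (c, v) => v :: c // fold in a later value

        val map = mutable.HashMap.empty[String, List[Int]]
        for ((k, v) <- records) {
          map(k) = map.get(k) match {
            case Some(old) => mergeValue(old, v)   // key already seen: merge the new value in
            case None      => createCombiner(v)    // key not seen yet: create its combiner
          }
        }
        println(map)  // e.g. HashMap(a -> List(3, 1), b -> List(2))
      }
    }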
  • sorter.iterator: this in turn calls ExternalSorter.partitionedIterator

    def iterator: Iterator[Product2[K, C]] = {
      isShuffleSort = false
      partitionedIterator.flatMap(pair => pair._2)
    }
  • ExternalSorter.partitionedIterator: depending on ordering.isDefined it ends up in ExternalSorter.groupByPartition, so that when the data is read back partition by partition, the records within each partition come out sorted by key (a small standalone comparison of the two sort orders follows this listing).

    def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = {
          val usingMap = aggregator.isDefined
          val collection: WritablePartitionedPairCollection[K, C] = if (usingMap) map else buffer
          if (spills.isEmpty) {
            // Special case: if we have only in-memory data, we don't need to merge streams, and perhaps
            // we don't even need to sort by anything other than partition ID
            if (!ordering.isDefined) {
            
           ------------------------------ The key step (no key sorting, partition-ID order only) ----------------------------------
              // The user hasn't requested sorted keys, so only sort by partition ID, not key
              groupByPartition(destructiveIterator(collection.partitionedDestructiveSortedIterator(None)))
            } else {
            
             ------------------------------ The key step (sort by key and emit the sorted iterator) --------------------------------
              // We do need to sort by both partition ID and key
              groupByPartition(destructiveIterator(
                collection.partitionedDestructiveSortedIterator(Some(keyComparator))))
            }
          } else {
            // Merge spilled and in-memory data
            merge(spills, destructiveIterator(
              collection.partitionedDestructiveSortedIterator(comparator)))
          }
        }
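
The only difference between the two branches above is the comparator handed to the destructive sorted iterator. The snippet below is a self-contained, made-up illustration using plain Scala collections rather than Spark's WritablePartitionedPairCollection: it contrasts sorting by partition ID alone with sorting by (partition ID, key):

    object PartitionSortDemo {
      def main(args: Array[String]): Unit = {
        // ((partitionId, key), value) records, the shape buffered by ExternalSorter.
        val records = Seq(((1, "b"), 10), ((0, "z"), 20), ((1, "a"), 30), ((0, "c"), 40))

        // ordering == None: sort by partition ID only; key order within a partition is left as-is.
        val byPartition = records.sortBy { case ((pid, _), _) => pid }

        // ordering == Some(keyOrd): sort by partition ID first, then by key within each partition.
        val byPartitionAndKey = records.sortBy { case ((pid, key), _) => (pid, key) }

        println(byPartition.map(_._1))        // List((0,z), (0,c), (1,b), (1,a))
        println(byPartitionAndKey.map(_._1))  // List((0,c), (0,z), (1,a), (1,b))
      }
    }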

5 The Core Ideas of ShuffleBlockFetcherIterator

ShuffleBlockFetcherIterator uses splitLocalRemoteBlocks to decide the read strategy for the data: if a block is available locally, it can be fetched directly from the local BlockManager; if it has to come from another node, the fetch goes over the network.

5.1 Heavyweight Members

localBlocks holds the local blocks to fetch, remoteBlocks the remote ones, and results the outcome of each request (SuccessFetchResult on success, FailureFetchResult on failure); a minimal illustration of how this queue turns async callbacks into a synchronous iterator follows the listing below.

/** Local blocks to fetch, excluding zero-sized blocks. */
      private[this] val localBlocks = new ArrayBuffer[BlockId]()
    
      /** Remote blocks to fetch, excluding zero-sized blocks. */
      private[this] val remoteBlocks = new HashSet[BlockId]()
    
      /**
       * A queue to hold our results. This turns the asynchronous model provided by
       * [[org.apache.spark.network.BlockTransferService]] into a synchronous model (iterator).
       */
      private[this] val results = new LinkedBlockingQueue[FetchResult]
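
The results queue above is what turns the asynchronous fetch callbacks into a synchronous iterator: the network layer's callbacks put fetch results in, and the iterator's next() simply blocks on take() until something arrives. Here is a minimal, self-contained sketch of that pattern; the types and values are invented for the demo and are not Spark's:

    import java.util.concurrent.LinkedBlockingQueue

    object AsyncToSyncDemo {
      sealed trait FetchResult
      final case class Success(blockId: String, bytes: Long) extends FetchResult
      final case class Failure(blockId: String, error: Throwable) extends FetchResult

      def main(args: Array[String]): Unit = {
        val results = new LinkedBlockingQueue[FetchResult]()

        // Simulated asynchronous fetch callbacks firing on another thread.
        new Thread(new Runnable {
          override def run(): Unit = {
            results.put(Success("shuffle_0_0_1", 1024L))
            results.put(Success("shuffle_0_1_1", 2048L))
          }
        }).start()

        // Synchronous consumption: take() blocks until the next result is available.
        for (_ <- 1 to 2) {
          results.take() match {
            case Success(id, size) => println(s"fetched $id ($size bytes)")
            case Failure(id, e)    => println(s"failed to fetch $id: ${e.getMessage}")
          }
        }
      }
    }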

5.2 initialize(), executed when the ShuffleBlockFetcherIterator is initialized

The core calls it makes are:

  • splitLocalRemoteBlocks: splits the block requests into local and remote ones; the local block IDs go into localBlocks.

  • fetchUpToMaxBytes: issues sendRequest calls to pull data from remote nodes.

  • fetchLocalBlocks: reads data through the local BlockManager.

    private[this] def initialize(): Unit = {
      // Add a task completion callback (called in both success case and failure case) to cleanup.
      context.addTaskCompletionListener(_ => cleanup())
    
      // Split local and remote blocks.
      val remoteRequests = splitLocalRemoteBlocks()
      // Add the remote requests into our queue in a random order
      fetchRequests ++= Utils.randomize(remoteRequests)
      assert ((0 == reqsInFlight) == (0 == bytesInFlight),
        "expected reqsInFlight = 0 but found reqsInFlight = " + reqsInFlight +
        ", expected bytesInFlight = 0 but found bytesInFlight = " + bytesInFlight)
    
      // Send out initial requests for blocks, up to our maxBytesInFlight
      fetchUpToMaxBytes()
    
      val numFetches = remoteRequests.size - fetchRequests.size
      logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime))
    
      // Get Local Blocks
      fetchLocalBlocks()
      logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime))
    }

5.3 splitLocalRemoteBlocks

private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
    // Make remote requests at most maxBytesInFlight / 5 in length; the reason to keep them
    // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
    // nodes, rather than blocking on reading output from one node.
    val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
    logDebug("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize
      + ", maxBlocksInFlightPerAddress: " + maxBlocksInFlightPerAddress)

    // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
    // at most maxBytesInFlight in order to limit the amount of data in flight.
    val remoteRequests = new ArrayBuffer[FetchRequest]

    // Tracks total number of blocks (including zero sized blocks)
    var totalBlocks = 0
    for ((address, blockInfos) <- blocksByAddress) {
      totalBlocks += blockInfos.size
      if (address.executorId == blockManager.blockManagerId.executorId) {
        // Filter out zero-sized blocks
        localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1)
        numBlocksToFetch += localBlocks.size
      } else {
        val iterator = blockInfos.iterator
        var curRequestSize = 0L
        var curBlocks = new ArrayBuffer[(BlockId, Long)]
        while (iterator.hasNext) {
          val (blockId, size) = iterator.next()
          // Skip empty blocks
          if (size > 0) {
            curBlocks += ((blockId, size))
            remoteBlocks += blockId
            numBlocksToFetch += 1
            curRequestSize += size
          } else if (size < 0) {
            throw new BlockException(blockId, "Negative block size " + size)
          }
          if (curRequestSize >= targetRequestSize ||
              curBlocks.size >= maxBlocksInFlightPerAddress) {
            // Add this FetchRequest
            remoteRequests += new FetchRequest(address, curBlocks)
            logDebug(s"Creating fetch request of $curRequestSize at $address "
              + s"with ${curBlocks.size} blocks")
            curBlocks = new ArrayBuffer[(BlockId, Long)]
            curRequestSize = 0
          }
        }
        // Add in the final request
        if (curBlocks.nonEmpty) {
          remoteRequests += new FetchRequest(address, curBlocks)
        }
      }
    }
    logInfo(s"Getting $numBlocksToFetch non-empty blocks out of $totalBlocks blocks")
    remoteRequests
  }
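
The loop above packs remote blocks into FetchRequests of at most roughly maxBytesInFlight / 5 bytes (and at most maxBlocksInFlightPerAddress blocks each), so that fetches can proceed from up to five nodes in parallel. Below is a small self-contained restatement of just that chunking logic; Block and Request are simplified stand-ins invented for the demo, not Spark's BlockId and FetchRequest:

    import scala.collection.mutable.ArrayBuffer

    object FetchChunkingDemo {
      final case class Block(id: String, size: Long)
      final case class Request(blocks: Seq[Block]) { def size: Long = blocks.map(_.size).sum }

      def chunk(blocks: Seq[Block], targetRequestSize: Long, maxBlocksPerRequest: Int): Seq[Request] = {
        val requests = ArrayBuffer.empty[Request]
        var cur = ArrayBuffer.empty[Block]
        var curSize = 0L
        for (b <- blocks if b.size > 0) {            // zero-sized blocks are skipped
          cur += b
          curSize += b.size
          if (curSize >= targetRequestSize || cur.size >= maxBlocksPerRequest) {
            requests += Request(cur.toSeq)           // close off the current request
            cur = ArrayBuffer.empty[Block]
            curSize = 0L
          }
        }
        if (cur.nonEmpty) requests += Request(cur.toSeq)  // the final, possibly undersized request
        requests.toSeq
      }

      def main(args: Array[String]): Unit = {
        val maxBytesInFlight = 48L * 1024 * 1024     // the spark.reducer.maxSizeInFlight default
        val blocks = (0 until 6).map(i => Block(s"shuffle_0_${i}_2", 4L * 1024 * 1024))
        val requests = chunk(blocks, math.max(maxBytesInFlight / 5, 1L), maxBlocksPerRequest = 3)
        requests.foreach(r => println(s"${r.blocks.size} blocks, ${r.size / (1024 * 1024)} MB"))
      }
    }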

5.4 fetchUpToMaxBytes

private def fetchUpToMaxBytes(): Unit = {
    // Send fetch requests up to maxBytesInFlight. If you cannot fetch from a remote host
    // immediately, defer the request until the next time it can be processed.

    // Process any outstanding deferred fetch requests if possible.
    if (deferredFetchRequests.nonEmpty) {
      for ((remoteAddress, defReqQueue) <- deferredFetchRequests) {
        while (isRemoteBlockFetchable(defReqQueue) &&
            !isRemoteAddressMaxedOut(remoteAddress, defReqQueue.front)) {
          val request = defReqQueue.dequeue()
          logDebug(s"Processing deferred fetch request for $remoteAddress with "
            + s"${request.blocks.length} blocks")
          send(remoteAddress, request)
          if (defReqQueue.isEmpty) {
            deferredFetchRequests -= remoteAddress
          }
        }
      }
    }

    // Remainder of the method (reconstructed from the Spark 2.3 source, so treat it as a sketch):
    // drain the regular fetchRequests queue, deferring any request whose target host already has
    // the maximum number of blocks in flight.
    while (isRemoteBlockFetchable(fetchRequests)) {
      val request = fetchRequests.dequeue()
      val remoteAddress = request.address
      if (isRemoteAddressMaxedOut(remoteAddress, request)) {
        logDebug(s"Deferring fetch request for $remoteAddress with ${request.blocks.size} blocks")
        val defReqQueue = deferredFetchRequests.getOrElse(remoteAddress, new Queue[FetchRequest]())
        defReqQueue.enqueue(request)
        deferredFetchRequests(remoteAddress) = defReqQueue
      } else {
        send(remoteAddress, request)
      }
    }
  }
  • sendRequest issues a shuffleClient.fetchBlocks call to read the requested blocks from the remote node:

    private[this] def sendRequest(req: FetchRequest) {
         logDebug("Sending request for %d blocks (%s) from %s".format(
           req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort))
         bytesInFlight += req.size
         reqsInFlight += 1
     
         // so we can look up the size of each blockID
         val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
         val remainingBlocks = new HashSet[String]() ++= sizeMap.keys
         val blockIds = req.blocks.map(_._1.toString)
         val address = req.address
     
         val blockFetchingListener = new BlockFetchingListener {
           override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
             // Only add the buffer to results queue if the iterator is not zombie,
             // i.e. cleanup() has not been called yet.
             ShuffleBlockFetcherIterator.this.synchronized {
               if (!isZombie) {
                 // Increment the ref count because we need to pass this to a different thread.
                 // This needs to be released after use.
                 buf.retain()
                 remainingBlocks -= blockId
                 results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf,
                   remainingBlocks.isEmpty))
                 logDebug("remainingBlocks: " + remainingBlocks)
               }
             }
             logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
           }
     
           override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
             logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
             results.put(new FailureFetchResult(BlockId(blockId), address, e))
           }
         }
     
         // Fetch remote shuffle blocks to disk when the request is too large. Since the shuffle data is
         // already encrypted and compressed over the wire(w.r.t. the related configs), we can just fetch
         // the data and write it to file directly.
         if (req.size > maxReqSizeShuffleToMem) {
           shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
             blockFetchingListener, this)
         } else {
           shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
             blockFetchingListener, null)
         }
       }

5.5 fetchLocalBlocks()

private[this] def fetchLocalBlocks() {
    val iter = localBlocks.iterator
    while (iter.hasNext) {
      val blockId = iter.next()
      try {
        val buf = blockManager.getBlockData(blockId)
        shuffleMetrics.incLocalBlocksFetched(1)
        shuffleMetrics.incLocalBytesRead(buf.size)
        buf.retain()
        results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf, false))
      } catch {
        case e: Exception =>
          // If we see an exception, stop immediately.
          logError(s"Error occurred while fetching local blocks", e)
          results.put(new FailureFetchResult(blockId, blockManager.blockManagerId, e))
          return
      }
    }
  }

6 Summary

  • Backed by BlockStoreShuffleReader, ShuffleBlockFetcherIterator yields an iterator over all of the map tasks' block data.
  • If an aggregator is defined, the records are fed into an ExternalAppendOnlyMap, where they are buffered and aggregated (spilling to disk if needed).
  • The records are then written into ExternalSorter's map or buffer; if a global key ordering is required, sorter.iterator produces the final ordered data.
  • If no global key ordering is required, the aggregated iterator is returned directly.

Qin Kaixin, Shenzhen
