Spark BlockManager Source Code Analysis

Prerequisites

  • Hadoop version: Hadoop 2.6.0-cdh5.15.0
  • Spark version: Spark 1.6.0-cdh5.15.0
  • JDK 1.8.0_191
  • Scala 2.10.7

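Key points

  • Spark splits the blocks to be fetched into local and remote blocks
  • Reading a local shuffle block returns a FileSegmentManagedBuffer, which describes the block as (data file, offset within the data file, length)
  • For each remote block fetched successfully, new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf) is put into the results variable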

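Description

How ShuffledRDD reads data

  • A ShuffledRDD reads the output of the previous RDD, that is, the output of the ShuffleMapTasks (each task writes one data file holding all partitions plus one index file holding the offset of each partition)
  • The data file is ordered: first by partition ID in ascending order and then, if a key ordering is applied on the map side, by the key of each (key, value) pair in ascending order
  • The block data is read through BlockStoreShuffleReader.read()
  • MapOutputTracker returns [(BlockManagerId, [(ShuffleBlockId, length)])]. There is one BlockManagerId (that is, one BlockManager) per executor that holds map output, and each BlockManager manages several data partitions (that is, several ShuffleBlockIds)
  • ShuffleBlockFetcherIterator handles both local and remote blocks and yields an iterator of (BlockId, InputStream) elements. A flatMap() then deserializes each stream and turns it into a key/value iterator, so the two iterators together traverse all input data of the current partition. Blocks are consumed one at a time and the remote data in flight is capped by spark.reducer.maxSizeInFlight (48m by default), so memory use at this step normally stays bounded
  • Each (key, value) element of the iterator is then inserted into an ExternalAppendOnlyMap. When the spill condition is met, the in-memory data is written to a temporary file and a DiskMapIterator is saved in the spilledMaps variable, which keeps large inputs from overflowing memory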
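
To make the "two iterators" idea concrete, here is a minimal sketch of flattening an iterator of (block id, stream) pairs into one lazy record iterator. It is not Spark's code: the object name BlockStreamSketch, the encode helper, and the toy record format (a UTF string key followed by an Int value) are assumptions made for illustration only.

```scala
import java.io.{ByteArrayOutputStream, DataInputStream, DataOutputStream, EOFException, InputStream}

object BlockStreamSketch {
  // Hypothetical helper: serialize (key, value) records into the bytes of one block.
  def encode(records: Seq[(String, Int)]): Array[Byte] = {
    val bytes = new ByteArrayOutputStream()
    val out = new DataOutputStream(bytes)
    records.foreach { case (k, v) => out.writeUTF(k); out.writeInt(v) }
    out.flush()
    bytes.toByteArray
  }

  // Decode one stream record by record; the stream is closed once it is exhausted.
  def asKeyValueIterator(in: InputStream): Iterator[(String, Int)] = {
    val data = new DataInputStream(in)
    new Iterator[(String, Int)] {
      private var nextRecord: Option[(String, Int)] = read()
      private def read(): Option[(String, Int)] =
        try Some((data.readUTF(), data.readInt()))
        catch { case _: EOFException => data.close(); None }
      def hasNext: Boolean = nextRecord.isDefined
      def next(): (String, Int) = {
        val record = nextRecord.getOrElse(throw new NoSuchElementException("end of block"))
        nextRecord = read()
        record
      }
    }
  }

  // Outer iterator: one element per block. Inner iterator: the records of that block.
  // flatMap on an Iterator is lazy, so a block is only opened when its records are needed.
  def records(blocks: Iterator[(String, InputStream)]): Iterator[(String, Int)] =
    blocks.flatMap { case (_, in) => asKeyValueIterator(in) }
}
```

The point of the design is that the caller only ever sees a flat stream of records, while at most one block is being decoded at any moment.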

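Local blocks

  • Local blocks are fetched by the BlockManager
  • This applies when the ShuffleMapTask and the ResultTask ran on the same machine (the code compares executorId); the data then only needs to be read through the local BlockManager

Remote blocks

  • Remote blocks are requested through the ShuffleClient and fetched by the BlockTransferService

  • A ResultTask reads the data written by ShuffleMapTasks, and the partition data it reads comes in two kinds. One kind sits in data files on the local machine (the ResultTask and the ShuffleMapTask ran on the same machine) and can be read through the BlockManager. The other kind sits on other machines and has to be fetched as remote blocks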

ShuffledRDD

  • ShuffledRDD.compute() method
override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
    val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
    // SortShuffleManager.getReader() returns a BlockStoreShuffleReader,
    // whose read() method is then called
    SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
      .read()
      .asInstanceOf[Iterator[(K, C)]]
  }
  • BlockStoreShuffleReader.read() method
  • Reads the key-values to be combined for the current partition of this reduce task
/** Read the combined key-values for this reduce task */
  override def read(): Iterator[Product2[K, C]] = {
    val streamWrapper: (BlockId, InputStream) => InputStream = { (blockId, in) =>
      blockManager.wrapForCompression(blockId,
        CryptoStreamUtils.wrapForEncryption(in, blockManager.conf))
    }

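    // ShuffleBlockFetcherIterator handles both local and remote blocks and yields an iterator of
    // (BlockId, InputStream) elements. The flatMap() below deserializes each stream into a
    // key/value iterator, so the two iterators together traverse all input of this partition.
    // Blocks are consumed one at a time and remote data in flight is capped by
    // spark.reducer.maxSizeInFlight (48m by default), so memory use here normally stays bounded.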
    val wrappedStreams = new ShuffleBlockFetcherIterator(
      context,
      blockManager.shuffleClient,
      blockManager,
      // MapOutputTracker returns [(BlockManagerId, [(ShuffleBlockId, length)])]: one BlockManagerId
      // per executor that holds map output, each with several ShuffleBlockIds
      mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
      streamWrapper,
      // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
      SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
      SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true))

    val ser = Serializer.getSerializer(dep.serializer)
    val serializerInstance = ser.newInstance()

    // Create a key/value iterator for each stream
    val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) =>
      // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
      // NextIterator. The NextIterator makes sure that close() is called on the
      // underlying InputStream when all records have been read.
      serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
    }

    // Update the context task metrics for each record read.
    val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
    val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
      recordIter.map(record => {
        readMetrics.incRecordsRead(1)
        record
      }),
      context.taskMetrics().updateShuffleReadMetrics())

    // An interruptible iterator must be used here in order to support task cancellation
    val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)

    val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
      if (dep.mapSideCombine) {
        // We are reading values that are already combined
        val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
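        // Each (key, value) element is inserted into an ExternalAppendOnlyMap. When the spill
        // condition is met, the in-memory data is written to a temporary file and a
        // DiskMapIterator is saved in spilledMaps, which keeps large inputs from overflowing memory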
        dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
      } else {
        // We don't know the value type, but also don't care -- the dependency *should*
        // have made sure its compatible w/ this aggregator, which will convert the value
        // type to the combined type C
        val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
        dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
      }
    } else {
      require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
      interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
    }

    // Sort the output if there is a sort ordering defined.
    // If a key ordering is defined, sort the output; otherwise return it unsorted
    dep.keyOrdering match {
      case Some(keyOrd: Ordering[K]) =>
        // Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled,
        // the ExternalSorter won't spill to disk.
        val sorter =
          new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = Some(ser))
        sorter.insertAll(aggregatedIter)
        context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
        context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
        context.internalMetricsToAccumulators(
          InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes)
        CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
      case None =>
        aggregatedIter
    }
  }
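
The two aggregation branches of read() can be illustrated with a small in-memory sketch. It assumes a simplified aggregator built from the same three functions Spark's Aggregator carries (createCombiner, mergeValue, mergeCombiners), but it keeps everything in an immutable Map and never spills. The names AggregatorSketch and SimpleAggregator are hypothetical.

```scala
object AggregatorSketch {
  // Simplified stand-in for an aggregator: no spilling, no task context.
  final case class SimpleAggregator[K, V, C](
      createCombiner: V => C,
      mergeValue: (C, V) => C,
      mergeCombiners: (C, C) => C) {

    // Used when the map side did NOT combine: the input still holds raw values.
    def combineValuesByKey(iter: Iterator[(K, V)]): Map[K, C] =
      iter.foldLeft(Map.empty[K, C]) { case (acc, (k, v)) =>
        acc.updated(k, acc.get(k).map(c => mergeValue(c, v)).getOrElse(createCombiner(v)))
      }

    // Used when the map side already combined: the input holds partial combiners.
    def combineCombinersByKey(iter: Iterator[(K, C)]): Map[K, C] =
      iter.foldLeft(Map.empty[K, C]) { case (acc, (k, c)) =>
        acc.updated(k, acc.get(k).map(prev => mergeCombiners(prev, c)).getOrElse(c))
      }
  }
}
```

For a sum, SimpleAggregator[String, Int, Int](v => v, _ + _, _ + _) gives the same totals on either path; the only difference is whether the reduce side receives raw values or partial sums.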
  • ShuffleBlockFetcherIterator.initialize()
private[this] def initialize(): Unit = {
    // Add a task completion callback (called in both success case and failure case) to cleanup.
    context.addTaskCompletionListener(_ => cleanup())

    // Split local and remote blocks.
    // Local block ids go into localBlocks; remote blocks come back grouped into FetchRequests
    val remoteRequests = splitLocalRemoteBlocks()
    // Add the remote requests into our queue in a random order
    
    // i.e. the remote requests are issued in random order
    fetchRequests ++= Utils.randomize(remoteRequests)

    // Send out initial requests for blocks, up to our maxBytesInFlight
    fetchUpToMaxBytes()

    val numFetches = remoteRequests.size - fetchRequests.size
    logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime))

    // Get Local Blocks
    // Local blocks are read directly through the BlockManager
    fetchLocalBlocks()
    logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime))
  }

Local block handling

  • ShuffleBlockFetcherIterator.fetchLocalBlocks()
  • Fetches the local blocks
  • Iterates over localBlocks and puts one result per block into results
/**
   * Fetch the local blocks while we are fetching remote blocks. This is ok because
   * [[ManagedBuffer]]'s memory is allocated lazily when we create the input stream, so all we
   * track in-memory are the ManagedBuffer references themselves.
   */
  private[this] def fetchLocalBlocks() {
    val iter = localBlocks.iterator
    while (iter.hasNext) {
      val blockId = iter.next()
      try {
        // BlockManager.getBlockData returns a ManagedBuffer. For a shuffle block this is a
        // FileSegmentManagedBuffer describing (data file, offset within the data file, length)
        val buf = blockManager.getBlockData(blockId)
        shuffleMetrics.incLocalBlocksFetched(1)
        shuffleMetrics.incLocalBytesRead(buf.size)
        buf.retain()
        results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf))
      } catch {
        case e: Exception =>
          // If we see an exception, stop immediately.
          logError(s"Error occurred while fetching local blocks", e)
          results.put(new FailureFetchResult(blockId, blockManager.blockManagerId, e))
          return
      }
    }
  }
  • BlockManager.getBlockData
/**
   * Interface to get local block data. Throws an exception if the block cannot be found or
   * cannot be read successfully.
   */
  override def getBlockData(blockId: BlockId): ManagedBuffer = {
    if (blockId.isShuffle) {
      // shuffleManager.shuffleBlockResolver is an IndexShuffleBlockResolver here, so this calls
      // IndexShuffleBlockResolver.getBlockData, which returns a FileSegmentManagedBuffer
      // describing the block as (data file, offset within the data file, length)
      shuffleManager.shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId])
    } else {
      val blockBytesOpt = doGetLocal(blockId, asBlockResult = false)
        .asInstanceOf[Option[ByteBuffer]]
      if (blockBytesOpt.isDefined) {
        val buffer = blockBytesOpt.get
        new NioManagedBuffer(buffer)
      } else {
        throw new BlockNotFoundException(blockId.toString)
      }
    }
  }
  • IndexShuffleBlockResolver.getBlockData
override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = {
    // The block is actually going to be a range of a single map output file for this map, so
    // find out the consolidated file, then the offset within that from our index
    val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId)

    val in = new DataInputStream(new FileInputStream(indexFile))
    try {
      ByteStreams.skipFully(in, blockId.reduceId * 8)
      val offset = in.readLong()
      val nextOffset = in.readLong()
      // The buffer records (data file, offset within the data file, length)
      new FileSegmentManagedBuffer(
        transportConf,
        getDataFile(blockId.shuffleId, blockId.mapId),
        offset,
        nextOffset - offset)
    } finally {
      in.close()
    }
  }
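
The index lookup above is easier to follow with a toy index file. The sketch below assumes the layout that the skip of reduceId * 8 bytes implies: a sequence of 8-byte cumulative offsets, one per partition plus a final end offset. IndexLookupSketch, writeIndex and segment are hypothetical names, not Spark APIs.

```scala
import java.io.{DataInputStream, DataOutputStream, File, FileInputStream, FileOutputStream}

object IndexLookupSketch {
  // Write cumulative offsets: numPartitions + 1 longs, starting at 0.
  def writeIndex(indexFile: File, partitionLengths: Seq[Long]): Unit = {
    val out = new DataOutputStream(new FileOutputStream(indexFile))
    try {
      var offset = 0L
      out.writeLong(offset)
      partitionLengths.foreach { len =>
        offset += len
        out.writeLong(offset)
      }
    } finally out.close()
  }

  // Return (offset, length) of the data-file segment that belongs to reduceId.
  def segment(indexFile: File, reduceId: Int): (Long, Long) = {
    val in = new DataInputStream(new FileInputStream(indexFile))
    try {
      // Discard the entries of the preceding partitions (8 bytes each).
      (0 until reduceId).foreach(_ => in.readLong())
      val offset = in.readLong()
      val nextOffset = in.readLong()
      (offset, nextOffset - offset)
    } finally in.close()
  }
}
```

With partition lengths 10, 0 and 25, partition 2 starts at offset 10 and is 25 bytes long; the empty partition 1 shows up as a zero-length segment, which is why zero-sized blocks can be filtered out before fetching.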

Remote block handling

  • ShuffleBlockFetcherIterator.initialize
private[this] def initialize(): Unit = {
    // Add a task completion callback (called in both success case and failure case) to cleanup.
    context.addTaskCompletionListener(_ => cleanup())

    // Split local and remote blocks.
    // Local block ids go into localBlocks; remote blocks come back grouped into FetchRequests
    val remoteRequests = splitLocalRemoteBlocks()
    // Add the remote requests into our queue in a random order
    fetchRequests ++= Utils.randomize(remoteRequests)

    // Send out initial requests for blocks, up to our maxBytesInFlight
    fetchUpToMaxBytes()

    val numFetches = remoteRequests.size - fetchRequests.size
    logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime))

    // Get Local Blocks
    fetchLocalBlocks()
    logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime))
  }
  • ShuffleBlockFetcherIterator.splitLocalRemoteBlocks
private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
    // Make remote requests at most maxBytesInFlight / 5 in length; the reason to keep them
    // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
    // nodes, rather than blocking on reading output from one node.
    val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
    logDebug("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize)

    // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
    // at most maxBytesInFlight in order to limit the amount of data in flight.
    val remoteRequests = new ArrayBuffer[FetchRequest]

    // Tracks total number of blocks (including zero sized blocks)
    var totalBlocks = 0
    // blocksByAddress is the value returned by mapOutputTracker.getMapSizesByExecutorId(
    // handle.shuffleId, startPartition, endPartition): [(BlockManagerId, [(ShuffleBlockId, length)])]
    
    for ((address, blockInfos) <- blocksByAddress) {
      totalBlocks += blockInfos.size
      if (address.executorId == blockManager.blockManagerId.executorId) {
        // Filter out zero-sized blocks
        // Keep only the local blocks that contain data (length != 0) and add their
        // ShuffleBlockIds to localBlocks; each blockInfos element is (ShuffleBlockId, length)
        localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1)
        numBlocksToFetch += localBlocks.size
      } else {
        val iterator = blockInfos.iterator
        var curRequestSize = 0L
        var curBlocks = new ArrayBuffer[(BlockId, Long)]
        while (iterator.hasNext) {
          val (blockId, size) = iterator.next()
          // Skip empty blocks
          if (size > 0) {
            curBlocks += ((blockId, size))
            remoteBlocks += blockId
            numBlocksToFetch += 1
            curRequestSize += size
          } else if (size < 0) {
            throw new BlockException(blockId, "Negative block size " + size)
          }
          if (curRequestSize >= targetRequestSize) {
            // Add this FetchRequest
            remoteRequests += new FetchRequest(address, curBlocks)
            curBlocks = new ArrayBuffer[(BlockId, Long)]
            logDebug(s"Creating fetch request of $curRequestSize at $address")
            curRequestSize = 0
          }
        }
        // Add in the final request
        if (curBlocks.nonEmpty) {
          // Add the remaining remote blocks to remoteRequests;
          // curBlocks is an ArrayBuffer[(BlockId, Long)]
          remoteRequests += new FetchRequest(address, curBlocks)
        }
      }
    }
    logInfo(s"Getting $numBlocksToFetch non-empty blocks out of $totalBlocks blocks")
    remoteRequests
  }
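
Stripped of logging and metrics, the splitting logic comes down to a few lines. The following is a simplified sketch under stated assumptions: executors and blocks are identified by plain strings, and FetchRequestSketch, Request and split are hypothetical names rather than Spark classes.

```scala
object FetchRequestSketch {
  final case class Request(address: String, blocks: Vector[(String, Long)]) {
    def size: Long = blocks.map(_._2).sum
  }

  // Returns (ids of non-empty local blocks, remote requests).
  def split(
      localExecutor: String,
      blocksByAddress: Seq[(String, Seq[(String, Long)])],
      maxBytesInFlight: Long): (Vector[String], Vector[Request]) = {
    // Aim for requests of about 1/5 of the limit so several nodes can be fetched in parallel.
    val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
    var local = Vector.empty[String]
    var remote = Vector.empty[Request]
    for ((address, blockInfos) <- blocksByAddress) {
      if (address == localExecutor) {
        local ++= blockInfos.filter(_._2 != 0).map(_._1)
      } else {
        var cur = Vector.empty[(String, Long)]
        var curSize = 0L
        for ((blockId, size) <- blockInfos) {
          require(size >= 0, s"Negative block size $size for $blockId")
          if (size > 0) { // skip empty blocks
            cur :+= ((blockId, size))
            curSize += size
          }
          if (curSize >= targetRequestSize) { // the request is big enough, close it
            remote :+= Request(address, cur)
            cur = Vector.empty
            curSize = 0L
          }
        }
        if (cur.nonEmpty) remote :+= Request(address, cur) // final, smaller request
      }
    }
    (local, remote)
  }
}
```

Note that a request is closed only after it reaches the target size, so a single request can exceed targetRequestSize by up to one block.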
  • ShuffleBlockFetcherIterator.fetchUpToMaxBytes
private def fetchUpToMaxBytes(): Unit = {
    // Send fetch requests up to maxBytesInFlight
    while (fetchRequests.nonEmpty &&
      (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
      // Dequeue the first element of fetchRequests and pass it to sendRequest().
      // Each element is a FetchRequest(address, Array((blockId, size)))
      sendRequest(fetchRequests.dequeue())
    }
  }
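
The loop condition deserves a closer look: a request is sent if nothing is in flight (so an oversized request can never block forever) or if it still fits under the limit. Below is a minimal sketch of that throttle, tracking only request sizes; InFlightSketch and Throttle are hypothetical names.

```scala
import scala.collection.mutable

object InFlightSketch {
  final class Throttle(maxBytesInFlight: Long) {
    private val pending = mutable.Queue.empty[Long] // sizes of the queued requests
    private var bytesInFlight = 0L

    def enqueue(size: Long): Unit = pending.enqueue(size)

    // Send as many queued requests as the limit allows; returns the sizes that were sent.
    def fetchUpToMaxBytes(): List[Long] = {
      val sent = List.newBuilder[Long]
      while (pending.nonEmpty &&
          (bytesInFlight == 0 || bytesInFlight + pending.front <= maxBytesInFlight)) {
        val size = pending.dequeue()
        bytesInFlight += size
        sent += size
      }
      sent.result()
    }

    // Called when a result has been consumed, which frees its share of the budget.
    def completed(size: Long): Unit = bytesInFlight -= size

    def inFlight: Long = bytesInFlight
  }
}
```

With a limit of 48 and queued sizes 60, 20, 20, 20, the first call sends only the 60 (nothing was in flight), and the 20s go out two at a time once it completes.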
  • ShuffleBlockFetcherIterator.sendRequest(): on success the onBlockFetchSuccess() callback is invoked and puts the result into results
private[this] def sendRequest(req: FetchRequest) {
    logDebug("Sending request for %d blocks (%s) from %s".format(
      req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort))
    bytesInFlight += req.size

    // so we can look up the size of each blockID
    val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
    val blockIds = req.blocks.map(_._1.toString)

    val address = req.address
    shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
      new BlockFetchingListener {
        override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
          // Only add the buffer to results queue if the iterator is not zombie,
          // i.e. cleanup() has not been called yet.
          if (!isZombie) {
            // Increment the ref count because we need to pass this to a different thread.
            // This needs to be released after use.
            buf.retain()
            results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf))
            shuffleMetrics.incRemoteBytesRead(buf.size)
            shuffleMetrics.incRemoteBlocksFetched(1)
          }
          logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
        }

        override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
          logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
          results.put(new FailureFetchResult(BlockId(blockId), address, e))
        }
      }
    )
  }
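
The hand-off between the fetching threads and the task thread is a plain blocking queue: callbacks put results in, the consumer takes them out. Here is a minimal sketch of that pattern, where fetchAsync is a hypothetical stand-in for the network client (it "fails" any block id starting with "bad") and none of the names come from Spark.

```scala
import java.util.concurrent.LinkedBlockingQueue

object ResultsQueueSketch {
  sealed trait FetchResult
  final case class Success(blockId: String, bytes: Array[Byte]) extends FetchResult
  final case class Failure(blockId: String, error: Throwable) extends FetchResult

  trait Listener {
    def onSuccess(blockId: String, bytes: Array[Byte]): Unit
    def onFailure(blockId: String, error: Throwable): Unit
  }

  // Hypothetical client: reports every block to the listener from a background thread.
  def fetchAsync(blockIds: Seq[String], listener: Listener): Thread = {
    val worker = new Thread(new Runnable {
      def run(): Unit = blockIds.foreach { id =>
        if (id.startsWith("bad")) listener.onFailure(id, new java.io.IOException("fetch failed: " + id))
        else listener.onSuccess(id, id.getBytes("UTF-8"))
      }
    })
    worker.start()
    worker
  }

  // The callbacks only enqueue; the caller blocks on take() until each result arrives.
  def fetchAll(blockIds: Seq[String]): Seq[FetchResult] = {
    val results = new LinkedBlockingQueue[FetchResult]()
    fetchAsync(blockIds, new Listener {
      def onSuccess(blockId: String, bytes: Array[Byte]): Unit = results.put(Success(blockId, bytes))
      def onFailure(blockId: String, error: Throwable): Unit = results.put(Failure(blockId, error))
    })
    blockIds.map(_ => results.take())
  }
}
```

Keeping the callbacks this thin means the network threads never do deserialization work; failures travel through the same queue and are raised on the consuming side.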
  • ShuffleBlockFetcherIterator.next() method, which returns (currentResult.blockId, new BufferReleasingInputStream(input, this))
override def next(): (BlockId, InputStream) = {
    numBlocksProcessed += 1

    var result: FetchResult = null
    var input: InputStream = null
    // Take the next fetched result and try to decompress it to detect data corruption,
    // then fetch it one more time if it's corrupt, throw FailureFetchResult if the second fetch
    // is also corrupt, so the previous stage could be retried.
    // For local shuffle block, throw FailureFetchResult for the first IOException.
    while (result == null) {
      val startFetchWait = System.currentTimeMillis()
      result = results.take()
      val stopFetchWait = System.currentTimeMillis()
      shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait)

      result match {
        case r @ SuccessFetchResult(blockId, address, size, buf) =>
          if (address != blockManager.blockManagerId) {
            shuffleMetrics.incRemoteBytesRead(buf.size)
            shuffleMetrics.incRemoteBlocksFetched(1)
          }
          bytesInFlight -= size

          val in = try {
            buf.createInputStream()
          } catch {
            // The exception can only be thrown for a local shuffle block
            case e: IOException =>
              assert(buf.isInstanceOf[FileSegmentManagedBuffer])
              logError("Failed to create input stream from local block", e)
              buf.release()
              throwFetchFailedException(blockId, address, e)
          }

          input = streamWrapper(blockId, in)
          // Only copy the stream if it's wrapped by compression or encryption, also the size of
          // block is small (the decompressed block is smaller than maxBytesInFlight)
          if (detectCorrupt && !input.eq(in) && size < maxBytesInFlight / 3) {
            val originalInput = input
            val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate)
            try {
              // Decompress the whole block at once to detect any corruption, which could increase
              // the memory usage and potentially increase the chance of OOM.
              // TODO: manage the memory used here, and spill it into disk in case of OOM.
              Utils.copyStream(input, out)
              out.close()
              input = out.toChunkedByteBuffer.toInputStream(dispose = true)
            } catch {
              case e: IOException =>
                buf.release()
                if (buf.isInstanceOf[FileSegmentManagedBuffer]
                  || corruptedBlocks.contains(blockId)) {
                  throwFetchFailedException(blockId, address, e)
                } else {
                  logWarning(s"got an corrupted block $blockId from $address, fetch again", e)
                  corruptedBlocks += blockId
                  fetchRequests += FetchRequest(address, Array((blockId, size)))
                  result = null
                }
            } finally {
              // TODO: release the buf here to free memory earlier
              originalInput.close()
              in.close()
            }
          }

        case FailureFetchResult(blockId, address, e) =>
          throwFetchFailedException(blockId, address, e)
      }

      // Send fetch requests up to maxBytesInFlight
      fetchUpToMaxBytes()
    }

    currentResult = result.asInstanceOf[SuccessFetchResult]
    (currentResult.blockId, new BufferReleasingInputStream(input, this))
  }
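
The corruption handling in next() follows a "retry once" rule: a corrupt remote block is fetched again the first time, while a corrupt local block or a second corruption of the same block fails the fetch. The decision on its own can be sketched as follows; CorruptRetrySketch, Tracker and the Outcome values are hypothetical names.

```scala
import scala.collection.mutable

object CorruptRetrySketch {
  sealed trait Outcome
  case object Refetch extends Outcome // put the block back on the request queue
  case object Fail extends Outcome    // give up so the previous stage can be retried

  final class Tracker {
    private val corruptedBlocks = mutable.HashSet.empty[String]

    // Decide what to do after a block failed to decompress.
    def onCorrupt(blockId: String, isLocal: Boolean): Outcome =
      if (isLocal || corruptedBlocks.contains(blockId)) Fail
      else {
        corruptedBlocks += blockId // remember it, so a second corruption is fatal
        Refetch
      }
  }
}
```

Retrying a local block would read the same file again, which is presumably why the source fails it immediately instead of refetching.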
