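This blog series distills and shares cases drawn from real commercial environments, and offers Spark source-code walkthroughs along with practical guidance for production use. Please keep following the series. Copyright notice: this Spark source-code walkthrough and commercial practice series belongs to the author (Qin Kaixin). Reposting is prohibited; you are welcome to study it.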
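I have used this diagram many times already, so please bear with me; after all, these posts share one topic: shuffle. The English comments in the source are already quite detailed, so I will only give a brief introduction here.
The official English description of ShuffleManager is as follows: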
* Pluggable interface for shuffle systems. A ShuffleManager is created in SparkEnv on the
* driver and on each executor, based on the spark.shuffle.manager setting. The driver
* registers shuffles with it, and executors (or tasks running locally in the driver) can ask
* to read and write data.
* NOTE: this will be instantiated by SparkEnv so its constructor can take a SparkConf and
* boolean isDriver as parameters.
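The core class on the reduce-task side is BlockStoreShuffleReader. What does it do? It reads the data for the partitions in the range given by a start partition and an end partition from the final, unique block files written by the map tasks. Before we start, let's look at its member variables, beginning with the English comment from the source: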
* Fetches and reads the partitions in range [startPartition, endPartition) from a shuffle by
* requesting them from other nodes' block stores.
So what does the constructor need to be given before the reader can be used?
private[spark] class BlockStoreShuffleReader[K, C](
    handle: BaseShuffleHandle[K, _, C],
    startPartition: Int,
    endPartition: Int,
    context: TaskContext,
    serializerManager: SerializerManager = SparkEnv.get.serializerManager,
    blockManager: BlockManager = SparkEnv.get.blockManager,
    mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker)
  extends ShuffleReader[K, C] with Logging
dep: the handle.dependency carried by the BaseShuffleHandle, i.e. the ShuffleDependency.
read() method:
mapOutputTracker: the MapOutputTracker, a sub-component of SparkEnv.
ShuffleBlockFetcherIterator: an iterator that fetches multiple blocks. Put simply, read() hands it the return value of mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition), a Seq[(BlockManagerId, Seq[(BlockId, Long)])]. Given this sequence of block locations, it yields an iterator over all the requested data.
getMapSizesByExecutorId:
* Called from executors to get the server URIs and output sizes for each shuffle block that
* needs to be read from a given range of map output partitions (startPartition is included but
* endPartition is excluded from the range).
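To make the half-open range [startPartition, endPartition) concrete, here is a minimal sketch. The object and the helper partitionsInRange are hypothetical names for illustration only and are not part of Spark.

```scala
object PartitionRangeSketch {
  // Hypothetical helper (not a Spark API): lists the reduce partitions covered by
  // the half-open range [startPartition, endPartition).
  def partitionsInRange(startPartition: Int, endPartition: Int): List[Int] =
    (startPartition until endPartition).toList

  def main(args: Array[String]): Unit = {
    // A range of the form [p, p + 1) covers exactly one partition.
    println(partitionsInRange(3, 4))
    // A wider range covers several partitions; the end is excluded.
    println(partitionsInRange(0, 3))
  }
}
```

Because the end is excluded, a reader created with startPartition = p and endPartition = p + 1 reads exactly partition p.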
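If dep.aggregator is defined, the records are aggregated in an ExternalAppendOnlyMap (note: this is not a plain AppendOnlyMap). dep.mapSideCombine only decides which entry point is used: combineCombinersByKey when the map side has already combined, combineValuesByKey otherwise. ExternalAppendOnlyMap keeps its in-memory data in a SizeTrackingAppendOnlyMap and can spill to disk, so it provides aggregation and buffering but no key-sorted output iterator.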
wrappedStreams holds the block data fetched from the map tasks:
override def read(): Iterator[Product2[K, C]] = {
  val wrappedStreams = new ShuffleBlockFetcherIterator(
    context,
    blockManager.shuffleClient,
    blockManager,
    mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
    serializerManager.wrapStream,
    // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
    SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
    SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue),
    SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
    SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
    SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true))
  val serializerInstance = dep.serializer.newInstance()
  // Create a key/value iterator for each stream
  val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) =>
    // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
    // NextIterator. The NextIterator makes sure that close() is called on the
    // underlying InputStream when all records have been read.
    serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
  }
  // Update the context task metrics for each record read.
  val readMetrics = context.taskMetrics.createTempShuffleReadMetrics()
  val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
    recordIter.map { record =>
      readMetrics.incRecordsRead(1)
      record
    },
    context.taskMetrics().mergeShuffleReadMetrics())
  // An interruptible iterator must be used here in order to support task cancellation
  val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)
  // ---------- Key step: aggregate, or just pass the records through ----------
  val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
    if (dep.mapSideCombine) {
      // We are reading values that are already combined
      val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
      dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
    } else {
      // We don't know the value type, but also don't care -- the dependency *should*
      // have made sure its compatible w/ this aggregator, which will convert the value
      // type to the combined type C
      val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
      dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
    }
  } else {
    require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
    interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
  }
  // Sort the output if there is a sort ordering defined.
  dep.keyOrdering match {
    case Some(keyOrd: Ordering[K]) =>
      // Create an ExternalSorter to sort the data.
      val sorter =
        new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)
      // ---------- If key ordering is required, call sorter.insertAll ----------
      // No aggregator is passed to the constructor, so PartitionedPairBuffer is used as the buffer.
      sorter.insertAll(aggregatedIter)
      context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
      context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
      context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
      // ---------- With ordering: sorter.iterator returns the sorted iterator directly ----------
      CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
    case None =>
      // ---------- Without ordering: return the aggregated (or pass-through) iterator as is ----------
      aggregatedIter
  }
}
}
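The two aggregation branches in read() differ only in the record type they receive. The sketch below illustrates that difference with plain in-memory collections. SimpleAggregator and both helper functions are hypothetical stand-ins written for this post; Spark's real Aggregator works on an ExternalAppendOnlyMap and can spill to disk, which this sketch does not attempt.

```scala
import scala.collection.mutable

object ReduceSideAggregationSketch {
  // Simplified stand-in for an aggregator: only the three user-supplied functions.
  final case class SimpleAggregator[K, V, C](
      createCombiner: V => C,
      mergeValue: (C, V) => C,
      mergeCombiners: (C, C) => C)

  // No map-side combine: the reducer receives raw (K, V) records.
  def combineValuesByKey[K, V, C](
      records: Iterator[(K, V)],
      agg: SimpleAggregator[K, V, C]): Map[K, C] = {
    val combined = mutable.HashMap.empty[K, C]
    for ((k, v) <- records) {
      val updated = combined.get(k) match {
        case Some(c) => agg.mergeValue(c, v)
        case None => agg.createCombiner(v)
      }
      combined.update(k, updated)
    }
    combined.toMap
  }

  // Map-side combine already ran: the reducer receives partial results (K, C).
  def combineCombinersByKey[K, V, C](
      records: Iterator[(K, C)],
      agg: SimpleAggregator[K, V, C]): Map[K, C] = {
    val combined = mutable.HashMap.empty[K, C]
    for ((k, c) <- records) {
      val updated = combined.get(k) match {
        case Some(existing) => agg.mergeCombiners(existing, c)
        case None => c
      }
      combined.update(k, updated)
    }
    combined.toMap
  }

  // Example aggregator: sum Int values per key.
  val sumAgg: SimpleAggregator[String, Int, Int] =
    SimpleAggregator[String, Int, Int](v => v, (c, v) => c + v, (a, b) => a + b)
}
```

With sumAgg, raw values ("a", 1), ("a", 3) and partial sums ("a", 4), ("a", 6) both end up as one entry per key; only the merge function applied differs.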
Constructor parameters of ExternalSorter (note aggregator and ordering):
private[spark] class ExternalSorter[K, V, C](
    context: TaskContext,
    aggregator: Option[Aggregator[K, V, C]] = None,
    partitioner: Option[Partitioner] = None,
    ordering: Option[Ordering[K]] = None,
    serializer: Serializer = SparkEnv.get.serializer)
  extends Spillable[WritablePartitionedPairCollection[K, C]](context.taskMemoryManager())
  with Logging
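The insertAll method: when aggregator.isDefined is false, a PartitionedPairBuffer is used as the buffer. Also note that in-memory insertion at most aggregates and does not sort, which is why inserting is fast. When key ordering is required, the sort only happens when the data is read through iterator, which then returns the sorted iterator. Internally this is groupByPartition(destructiveIterator(collection.partitionedDestructiveSortedIterator(Some(keyComparator)))).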
def insertAll(records: Iterator[Product2[K, V]]): Unit = {
  // TODO: stop combining if we find that the reduction factor isn't high
  val shouldCombine = aggregator.isDefined
  if (shouldCombine) {
    // Combine values in-memory first using our AppendOnlyMap
    val mergeValue = aggregator.get.mergeValue
    val createCombiner = aggregator.get.createCombiner
    var kv: Product2[K, V] = null
    val update = (hadValue: Boolean, oldValue: C) => {
      if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
    }
    while (records.hasNext) {
      addElementsRead()
      kv = records.next()
      map.changeValue((getPartition(kv._1), kv._1), update)
      maybeSpillCollection(usingMap = true)
    }
  } else {
    // Stick values into our buffer
    while (records.hasNext) {
      addElementsRead()
      val kv = records.next()
      buffer.insert(getPartition(kv._1), kv._1, kv._2.asInstanceOf[C])
      maybeSpillCollection(usingMap = false)
    }
  }
}
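The two branches of insertAll can be sketched with ordinary Scala collections. This is only an illustration under simplifying assumptions: keys are Strings, values are Ints summed per key, getPartition is a hypothetical hash partitioner, and there is no spilling. None of these names are Spark APIs.

```scala
import scala.collection.mutable

object InsertAllSketch {
  // Hypothetical hash partitioner; always returns a non-negative partition ID.
  def getPartition(key: String, numPartitions: Int): Int = {
    val mod = key.hashCode % numPartitions
    if (mod < 0) mod + numPartitions else mod
  }

  // Aggregator defined: values are combined per (partitionId, key) while inserting.
  def insertWithAggregation(
      records: Iterator[(String, Int)],
      numPartitions: Int): Map[(Int, String), Int] = {
    val map = mutable.HashMap.empty[(Int, String), Int]
    for ((k, v) <- records) {
      val key = (getPartition(k, numPartitions), k)
      val combined = map.get(key) match {
        case Some(old) => old + v // plays the role of mergeValue
        case None => v            // plays the role of createCombiner
      }
      map.update(key, combined)
    }
    map.toMap
  }

  // No aggregator: records are appended in arrival order; nothing is sorted here.
  def insertIntoBuffer(
      records: Iterator[(String, Int)],
      numPartitions: Int): List[(Int, String, Int)] = {
    val buffer = mutable.ArrayBuffer.empty[(Int, String, Int)]
    for ((k, v) <- records) {
      buffer += ((getPartition(k, numPartitions), k, v))
    }
    buffer.toList
  }
}
```

In both paths the insert loop does no sorting, which matches the point above: ordering is deferred until the data is read back.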
sorter.iterator: delegates to ExternalSorter.partitionedIterator:
def iterator: Iterator[Product2[K, C]] = {
  isShuffleSort = false
  partitionedIterator.flatMap(pair => pair._2)
}
ExternalSorter.partitionedIterator: depending on ordering.isDefined, it ends up calling ExternalSorter.groupByPartition, so that data read partition by partition comes out sorted by key.
def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = {
  val usingMap = aggregator.isDefined
  val collection: WritablePartitionedPairCollection[K, C] = if (usingMap) map else buffer
  if (spills.isEmpty) {
    // Special case: if we have only in-memory data, we don't need to merge streams, and perhaps
    // we don't even need to sort by anything other than partition ID
    if (!ordering.isDefined) {
      // ---------- Key step: iterator without key ordering ----------
      // The user hasn't requested sorted keys, so only sort by partition ID, not key
      groupByPartition(destructiveIterator(collection.partitionedDestructiveSortedIterator(None)))
    } else {
      // ---------- Key step: sort and return the iterator ----------
      // We do need to sort by both partition ID and key
      groupByPartition(destructiveIterator(
        collection.partitionedDestructiveSortedIterator(Some(keyComparator))))
    }
  } else {
    // Merge spilled and in-memory data
    merge(spills, destructiveIterator(
      collection.partitionedDestructiveSortedIterator(comparator)))
  }
}
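The in-memory case of partitionedIterator boils down to "sort by partition ID, optionally by key as well, then group by partition". Below is a rough sketch of that idea on a plain Seq. The Record alias, sortedRecords and this groupByPartition are hypothetical helpers for illustration and do not reproduce Spark's destructive, in-place sort.

```scala
object PartitionedSortSketch {
  // ((partitionId, key), value), the same shape of key that insertAll uses.
  type Record[K, C] = ((Int, K), C)

  // Sorts by partition ID only, or by (partition ID, key) when a key ordering is given.
  def sortedRecords[K, C](
      records: Seq[Record[K, C]],
      keyOrdering: Option[Ordering[K]]): Seq[Record[K, C]] = {
    val comparator = new Ordering[Record[K, C]] {
      override def compare(a: Record[K, C], b: Record[K, C]): Int = {
        val byPartition = Integer.compare(a._1._1, b._1._1)
        if (byPartition != 0) byPartition
        else keyOrdering.map(_.compare(a._1._2, b._1._2)).getOrElse(0)
      }
    }
    records.sorted(comparator)
  }

  // Groups an already sorted sequence into one (partitionId, records) pair per partition.
  def groupByPartition[K, C](
      numPartitions: Int,
      sorted: Seq[Record[K, C]]): Seq[(Int, Seq[(K, C)])] =
    (0 until numPartitions).map { p =>
      (p, sorted.collect { case ((pid, k), c) if pid == p => (k, c) })
    }
}
```

With a key ordering, every partition's records come out sorted by key; without one, only the grouping by partition is guaranteed.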
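ShuffleBlockFetcherIterator uses splitLocalRemoteBlocks to decide how each block is read: blocks held locally are read directly from the BlockManager, while blocks on other nodes have to be fetched over the network.
localBlocks: local blocks; remoteBlocks: remote blocks; results: fetch outcomes, either success (SuccessFetchResult) or failure (FailureFetchResult).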
/** Local blocks to fetch, excluding zero-sized blocks. */
private[this] val localBlocks = new ArrayBuffer[BlockId]()
/** Remote blocks to fetch, excluding zero-sized blocks. */
private[this] val remoteBlocks = new HashSet[BlockId]()
/**
 * A queue to hold our results. This turns the asynchronous model provided by
 * [[org.apache.spark.network.BlockTransferService]] into a synchronous model (iterator).
 */
private[this] val results = new LinkedBlockingQueue[FetchResult]
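The comment on results says the queue turns an asynchronous model into a synchronous iterator. A minimal sketch of that pattern follows, using only java.util.concurrent.LinkedBlockingQueue. The FetchSuccess and FetchFailure classes, ResultIterator and fetchAll are simplified, hypothetical stand-ins; a background thread plays the role of the network callbacks.

```scala
import java.util.concurrent.LinkedBlockingQueue

object BlockingResultsSketch {
  sealed trait FetchResult
  final case class FetchSuccess(blockId: String, bytes: Int) extends FetchResult
  final case class FetchFailure(blockId: String, error: Throwable) extends FetchResult

  // Synchronous view over results that arrive asynchronously on the queue.
  final class ResultIterator(results: LinkedBlockingQueue[FetchResult], expected: Int)
      extends Iterator[FetchResult] {
    private var processed = 0
    override def hasNext: Boolean = processed < expected
    override def next(): FetchResult = {
      if (!hasNext) throw new NoSuchElementException("no more results")
      processed += 1
      results.take() // blocks until some callback has put a result
    }
  }

  // Simulates callbacks on another thread and consumes them through the iterator.
  def fetchAll(blockIds: Seq[String]): List[FetchResult] = {
    val results = new LinkedBlockingQueue[FetchResult]()
    val producer = new Thread(new Runnable {
      override def run(): Unit =
        blockIds.foreach(id => results.put(FetchSuccess(id, id.length)))
    })
    producer.start()
    val fetched = new ResultIterator(results, blockIds.size).toList
    producer.join()
    fetched
  }
}
```

The consumer never polls or registers callbacks itself; it simply calls next() and waits, which is what lets the reader treat remote fetches as an ordinary iterator.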
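The core methods are:
splitLocalRemoteBlocks: splits the block reads into local and remote requests; local blocks go into localBlocks.
fetchUpToMaxBytes: sends requests via sendRequest to pull data from remote nodes.
fetchLocalBlocks: reads data through the local BlockManager.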
private[this] def initialize(): Unit = {
  // Add a task completion callback (called in both success case and failure case) to cleanup.
  context.addTaskCompletionListener(_ => cleanup())
  // Split local and remote blocks.
  val remoteRequests = splitLocalRemoteBlocks()
  // Add the remote requests into our queue in a random order
  fetchRequests ++= Utils.randomize(remoteRequests)
  assert ((0 == reqsInFlight) == (0 == bytesInFlight),
    "expected reqsInFlight = 0 but found reqsInFlight = " + reqsInFlight +
    ", expected bytesInFlight = 0 but found bytesInFlight = " + bytesInFlight)
  // Send out initial requests for blocks, up to our maxBytesInFlight
  fetchUpToMaxBytes()
  val numFetches = remoteRequests.size - fetchRequests.size
  logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime))
  // Get Local Blocks
  fetchLocalBlocks()
  logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime))
}
private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
  // Make remote requests at most maxBytesInFlight / 5 in length; the reason to keep them
  // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
  // nodes, rather than blocking on reading output from one node.
  val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
  logDebug("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize
    + ", maxBlocksInFlightPerAddress: " + maxBlocksInFlightPerAddress)
  // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
  // at most maxBytesInFlight in order to limit the amount of data in flight.
  val remoteRequests = new ArrayBuffer[FetchRequest]
  // Tracks total number of blocks (including zero sized blocks)
  var totalBlocks = 0
  for ((address, blockInfos) <- blocksByAddress) {
    totalBlocks += blockInfos.size
    if (address.executorId == blockManager.blockManagerId.executorId) {
      // Filter out zero-sized blocks
      localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1)
      numBlocksToFetch += localBlocks.size
    } else {
      val iterator = blockInfos.iterator
      var curRequestSize = 0L
      var curBlocks = new ArrayBuffer[(BlockId, Long)]
      while (iterator.hasNext) {
        val (blockId, size) = iterator.next()
        // Skip empty blocks
        if (size > 0) {
          curBlocks += ((blockId, size))
          remoteBlocks += blockId
          numBlocksToFetch += 1
          curRequestSize += size
        } else if (size < 0) {
          throw new BlockException(blockId, "Negative block size " + size)
        }
        if (curRequestSize >= targetRequestSize ||
            curBlocks.size >= maxBlocksInFlightPerAddress) {
          // Add this FetchRequest
          remoteRequests += new FetchRequest(address, curBlocks)
          logDebug(s"Creating fetch request of $curRequestSize at $address "
            + s"with ${curBlocks.size} blocks")
          curBlocks = new ArrayBuffer[(BlockId, Long)]
          curRequestSize = 0
        }
      }
      // Add in the final request
      if (curBlocks.nonEmpty) {
        remoteRequests += new FetchRequest(address, curBlocks)
      }
    }
  }
  logInfo(s"Getting $numBlocksToFetch non-empty blocks out of $totalBlocks blocks")
  remoteRequests
}
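The batching rule for one remote address can be isolated into a small, self-contained sketch. It follows the loop above (target size of maxBytesInFlight / 5, zero-sized blocks skipped, negative sizes rejected), but splitIntoRequests and its String block IDs are hypothetical simplifications and not Spark's FetchRequest type.

```scala
import scala.collection.mutable.ArrayBuffer

object FetchRequestSplitSketch {
  // Groups the blocks of one remote address into requests of roughly
  // maxBytesInFlight / 5 bytes, or at most maxBlocksPerRequest blocks each.
  def splitIntoRequests(
      blocks: Seq[(String, Long)],
      maxBytesInFlight: Long,
      maxBlocksPerRequest: Int): List[List[(String, Long)]] = {
    val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
    val requests = ArrayBuffer.empty[List[(String, Long)]]
    var current = ArrayBuffer.empty[(String, Long)]
    var currentSize = 0L
    for ((blockId, size) <- blocks) {
      require(size >= 0, s"Negative block size $size for $blockId")
      if (size > 0) { // zero-sized blocks are skipped
        current += ((blockId, size))
        currentSize += size
      }
      if (currentSize >= targetRequestSize || current.size >= maxBlocksPerRequest) {
        requests += current.toList
        current = ArrayBuffer.empty[(String, Long)]
        currentSize = 0L
      }
    }
    // Whatever is left over becomes the final request.
    if (current.nonEmpty) requests += current.toList
    requests.toList
  }
}
```

A request is closed only after its size reaches the target, so a single request can exceed the target by up to one block; the target is a batching hint, not a hard cap.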
private def fetchUpToMaxBytes(): Unit = {
  // Send fetch requests up to maxBytesInFlight. If you cannot fetch from a remote host
  // immediately, defer the request until the next time it can be processed.
  // Process any outstanding deferred fetch requests if possible.
  if (deferredFetchRequests.nonEmpty) {
    for ((remoteAddress, defReqQueue) <- deferredFetchRequests) {
      while (isRemoteBlockFetchable(defReqQueue) &&
          !isRemoteAddressMaxedOut(remoteAddress, defReqQueue.front)) {
        val request = defReqQueue.dequeue()
        logDebug(s"Processing deferred fetch request for $remoteAddress with "
          + s"${request.blocks.length} blocks")
        send(remoteAddress, request)
        if (defReqQueue.isEmpty) {
          deferredFetchRequests -= remoteAddress
        }
      }
    }
  }
  // (The excerpt ends here; the remainder of the method is not shown.)
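The idea behind fetchUpToMaxBytes is to keep sending queued requests while the in-flight limits allow it. The sketch below models only that throttling idea. The Throttle class, its method names and the exact fetchable condition are assumptions made for illustration; the per-address deferral shown in the excerpt above is left out entirely.

```scala
import scala.collection.mutable

object InFlightThrottleSketch {
  final case class Request(address: String, size: Long)

  final class Throttle(maxBytesInFlight: Long, maxReqsInFlight: Int) {
    private var bytesInFlight = 0L
    private var reqsInFlight = 0
    private val pending = mutable.Queue.empty[Request]
    private val sent = mutable.ArrayBuffer.empty[Request]

    def submit(req: Request): Unit = pending.enqueue(req)
    def sentSoFar: List[Request] = sent.toList

    // Assumed rule: always allow one request when nothing is in flight,
    // otherwise respect both the byte limit and the request-count limit.
    private def fetchable: Boolean =
      pending.nonEmpty && (bytesInFlight == 0 ||
        (reqsInFlight + 1 <= maxReqsInFlight &&
          bytesInFlight + pending.front.size <= maxBytesInFlight))

    def fetchUpToMaxBytes(): Unit =
      while (fetchable) {
        val req = pending.dequeue()
        bytesInFlight += req.size
        reqsInFlight += 1
        sent += req // stands in for actually sending the request
      }

    // Called when a request has completed: frees capacity, then tries to send more.
    def onCompleted(req: Request): Unit = {
      bytesInFlight -= req.size
      reqsInFlight -= 1
      fetchUpToMaxBytes()
    }
  }
}
```

Allowing a request through when nothing is in flight is a deliberate choice in this sketch: without it, a single request larger than the byte limit could never be sent.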
sendRequest issues the remote read by calling shuffleClient.fetchBlocks:
private[this] def sendRequest(req: FetchRequest) {
  logDebug("Sending request for %d blocks (%s) from %s".format(
    req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort))
  bytesInFlight += req.size
  reqsInFlight += 1
  // so we can look up the size of each blockID
  val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
  val remainingBlocks = new HashSet[String]() ++= sizeMap.keys
  val blockIds = req.blocks.map(_._1.toString)
  val address = req.address
  val blockFetchingListener = new BlockFetchingListener {
    override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
      // Only add the buffer to results queue if the iterator is not zombie,
      // i.e. cleanup() has not been called yet.
      ShuffleBlockFetcherIterator.this.synchronized {
        if (!isZombie) {
          // Increment the ref count because we need to pass this to a different thread.
          // This needs to be released after use.
          buf.retain()
          remainingBlocks -= blockId
          results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf,
            remainingBlocks.isEmpty))
          logDebug("remainingBlocks: " + remainingBlocks)
        }
      }
      logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
    }
    override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
      logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
      results.put(new FailureFetchResult(BlockId(blockId), address, e))
    }
  }
  // Fetch remote shuffle blocks to disk when the request is too large. Since the shuffle data is
  // already encrypted and compressed over the wire(w.r.t. the related configs), we can just fetch
  // the data and write it to file directly.
  if (req.size > maxReqSizeShuffleToMem) {
    shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
      blockFetchingListener, this)
  } else {
    shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
      blockFetchingListener, null)
  }
}
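Two small pieces of bookkeeping in sendRequest are easy to miss: the remainingBlocks set, which marks the last block of a request, and the size check that decides whether the data is fetched to disk. The sketch below isolates both. The simplified FetchRequest with String IDs, shouldFetchToDisk and simulateCallbacks are hypothetical helpers and perform no real network calls.

```scala
import scala.collection.mutable

object SendRequestSketch {
  // Simplified request: String block IDs and a plain String address.
  final case class FetchRequest(address: String, blocks: Seq[(String, Long)]) {
    val size: Long = blocks.map(_._2).sum
  }

  // Requests larger than the threshold are written to disk instead of held in memory.
  def shouldFetchToDisk(req: FetchRequest, maxReqSizeShuffleToMem: Long): Boolean =
    req.size > maxReqSizeShuffleToMem

  // Replays success callbacks in the given arrival order and returns
  // (blockId, size, isLastBlockOfRequest) for each of them.
  def simulateCallbacks(
      req: FetchRequest,
      arrivalOrder: Seq[String]): List[(String, Long, Boolean)] = {
    val sizeMap = req.blocks.toMap
    val remainingBlocks = mutable.HashSet.empty[String]
    remainingBlocks ++= sizeMap.keys
    arrivalOrder.toList.map { blockId =>
      remainingBlocks -= blockId
      (blockId, sizeMap(blockId), remainingBlocks.isEmpty)
    }
  }
}
```

Only the callback that empties remainingBlocks carries the "last block" flag, regardless of the order in which the blocks arrive.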
private[this] def fetchLocalBlocks() {
  val iter = localBlocks.iterator
  while (iter.hasNext) {
    val blockId = iter.next()
    try {
      val buf = blockManager.getBlockData(blockId)
      shuffleMetrics.incLocalBlocksFetched(1)
      shuffleMetrics.incLocalBytesRead(buf.size)
      buf.retain()
      results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf, false))
    } catch {
      case e: Exception =>
        // If we see an exception, stop immediately.
        logError(s"Error occurred while fetching local blocks", e)
        results.put(new FailureFetchResult(blockId, blockManager.blockManagerId, e))
        return
    }
  }
}
Qin Kaixin, in Shenzhen