Spark 中的 join 有兩種:一種是 RDD 的 join,一種是 SQL 中的 join,分別來看:
org.apache.spark.rdd.PairRDDFunctions
/**
 * Inner-joins `this` with `other` by key, using the partitioner chosen by
 * `defaultPartitioner(self, other)`.
 *
 * For every key present in both RDDs, one (k, (v1, v2)) tuple is produced per combination of a
 * value v1 from `this` and a value v2 from `other`.
 */
def join[W](other: RDD[(K, W)]): RDD[(K, (V, W))] = self.withScope {
  val partitioner = defaultPartitioner(self, other)
  join(other, partitioner)
}

/**
 * Inner-joins `this` with `other` by key; the output RDD is partitioned with `partitioner`.
 *
 * Implemented on top of `cogroup`: for each key, the values of both sides are paired up, so a key
 * that is missing on either side contributes nothing to the result.
 */
def join[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, W))] = self.withScope {
  val grouped = this.cogroup(other, partitioner)
  grouped.flatMapValues { pair =>
    // Cross product of the two value groups that share this key.
    pair._1.iterator.flatMap(v => pair._2.iterator.map(w => (v, w)))
  }
}

/**
 * Groups the values of `this` and `other` by key: for each key k found in either RDD, the result
 * holds a tuple of (values for k in `this`, values for k in `other`).
 *
 * Throws a SparkException when a HashPartitioner is combined with array keys.
 */
def cogroup[W](other: RDD[(K, W)], partitioner: Partitioner)
    : RDD[(K, (Iterable[V], Iterable[W]))] = self.withScope {
  partitioner match {
    case _: HashPartitioner if keyClass.isArray =>
      throw new SparkException("HashPartitioner cannot partition array keys.")
    case _ =>
  }
  val grouped = new CoGroupedRDD[K](Seq(self, other), partitioner)
  grouped.mapValues { case Array(vs, w1s) =>
    // Slot 0 holds the values from `self`, slot 1 the values from `other`.
    (vs.asInstanceOf[Iterable[V]], w1s.asInstanceOf[Iterable[W]])
  }
}
join 操作底層會通過 cogroup 構造 CoGroupedRDD(再經 mapValues/flatMapValues 轉換成最終結果),CoGroupedRDD 的構造參數爲 rdd 序列(Seq),即多個需要 join 的 rdd,下面看 CoGroupedRDD:
org.apache.spark.rdd.CoGroupedRDD
/**
 * An RDD that cogroups its parent RDDs: each output element is (key, array), where the array has
 * one slot per parent RDD holding the values that parent contributed for the key.
 *
 * @param rdds parent key/value RDDs to cogroup
 * @param part partitioner used for the output; parents already partitioned by it are read
 *             through a one-to-one dependency, all others through a shuffle dependency
 */
class CoGroupedRDD[K: ClassTag](
    @transient var rdds: Seq[RDD[_ <: Product2[K, _]]],
    part: Partitioner)
  extends RDD[(K, Array[Iterable[_]])](rdds.head.context, Nil) {

  /**
   * Builds one dependency per parent RDD: a OneToOneDependency when the parent is already
   * partitioned by `part`, otherwise a ShuffleDependency.
   */
  override def getDependencies: Seq[Dependency[_]] = {
    rdds.map { rdd: RDD[_] =>
      if (rdd.partitioner == Some(part)) {
        logDebug("Adding one-to-one dependency with " + rdd)
        new OneToOneDependency(rdd)
      } else {
        logDebug("Adding shuffle dependency with " + rdd)
        new ShuffleDependency[K, Any, CoGroupCombiner](
          rdd.asInstanceOf[RDD[_ <: Product2[K, _]]], part, serializer)
      }
    }
  }

  /**
   * Computes one output partition: collects an iterator per dependency, inserts every record
   * (tagged with its dependency number) into an ExternalAppendOnlyMap, records the map's
   * spill/peak-memory figures in the task metrics and returns the map's iterator.
   */
  override def compute(s: Partition, context: TaskContext): Iterator[(K, Array[Iterable[_]])] = {
    val split = s.asInstanceOf[CoGroupPartition]
    val numRdds = dependencies.length

    // A list of (rdd iterator, dependency number) pairs
    val rddIterators = new ArrayBuffer[(Iterator[Product2[K, Any]], Int)]
    for ((dep, depNum) <- dependencies.zipWithIndex) dep match {
      case oneToOneDependency: OneToOneDependency[Product2[K, Any]] @unchecked =>
        val dependencyPartition = split.narrowDeps(depNum).get.split
        // Read them from the parent
        val it = oneToOneDependency.rdd.iterator(dependencyPartition, context)
        rddIterators += ((it, depNum))

      case shuffleDependency: ShuffleDependency[_, _, _] =>
        // Read map outputs of shuffle
        val it = SparkEnv.get.shuffleManager
          .getReader(shuffleDependency.shuffleHandle, split.index, split.index + 1, context)
          .read()
        rddIterators += ((it, depNum))
    }

    val map = createExternalMap(numRdds)
    for ((it, depNum) <- rddIterators) {
      // Tag every value with the index of the dependency it came from.
      map.insertAll(it.map(pair => (pair._1, new CoGroupValue(pair._2, depNum))))
    }
    context.taskMetrics().incMemoryBytesSpilled(map.memoryBytesSpilled)
    context.taskMetrics().incDiskBytesSpilled(map.diskBytesSpilled)
    context.taskMetrics().incPeakExecutionMemory(map.peakMemoryUsedBytes)
    new InterruptibleIterator(context,
      map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]])
  }

  /**
   * Creates the ExternalAppendOnlyMap used by `compute`, wiring in the three combiner functions.
   *
   * @param numRdds number of slots in each combiner (one per dependency)
   */
  private def createExternalMap(numRdds: Int)
    : ExternalAppendOnlyMap[K, CoGroupValue, CoGroupCombiner] = {

    // First value for a key: allocate one CoGroup per dependency and file the value in its slot.
    val createCombiner: (CoGroupValue => CoGroupCombiner) = value => {
      val newCombiner = Array.fill(numRdds)(new CoGroup)
      newCombiner(value._2) += value._1
      newCombiner
    }
    // Further value for a key: append it to the slot of the dependency it came from.
    val mergeValue: (CoGroupCombiner, CoGroupValue) => CoGroupCombiner =
      (combiner, value) => {
        combiner(value._2) += value._1
        combiner
      }
    // Two combiners for the same key: concatenate them slot by slot into the first one.
    val mergeCombiners: (CoGroupCombiner, CoGroupCombiner) => CoGroupCombiner =
      (combiner1, combiner2) => {
        var depNum = 0
        while (depNum < numRdds) {
          combiner1(depNum) ++= combiner2(depNum)
          depNum += 1
        }
        combiner1
      }
    new ExternalAppendOnlyMap[K, CoGroupValue, CoGroupCombiner](
      createCombiner, mergeValue, mergeCombiners)
  }

  // NOTE: this excerpt does not show the remaining members of the class (for example the
  // `serializer` referenced in getDependencies is defined outside the quoted code).
}
CoGroupedRDD 首先將 rdds 逐個轉化爲 dependency,然後將所有的 dependency 轉化爲 rddIterators,最後通過 ExternalAppendOnlyMap 來實現合併;
如果 rdd 需要 shuffle,是通過 ShuffleManager 實現的,ShuffleManager 的實現類爲 SortShuffleManager,shuffle 過程詳見:http://www.javashuo.com/article/p-ehbhecmo-bo.html
附:Spark 中 dependency 的結構,即常說的寬依賴、窄依賴:
org.apache.spark.Dependency
Dependency
  NarrowDependency
    OneToOneDependency
    RangeDependency
  ShuffleDependency
區別就是 shuffle:不需要 shuffle 就是 NarrowDependency(窄依賴),需要 shuffle 就是 ShuffleDependency(寬依賴);
sql中的join有一個選擇策略:
org.apache.spark.sql.execution.SparkStrategies.JoinSelection
/**
 * Picks the physical join operator for an equi-join logical plan. Cases are tried top to bottom,
 * so the order below is the selection priority: broadcast hash join, then shuffled hash join,
 * then sort-merge join.
 */
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {

  // --- BroadcastHashJoin --------------------------------------------------------------------

  // Broadcast the right side when the join type allows building on the right and the right
  // side passes `canBroadcast`.
  case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
    if canBuildRight(joinType) && canBroadcast(right) =>
    Seq(joins.BroadcastHashJoinExec(
      leftKeys, rightKeys, joinType, BuildRight, condition, planLater(left), planLater(right)))

  // Otherwise try broadcasting the left side under the mirrored conditions.
  case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
    if canBuildLeft(joinType) && canBroadcast(left) =>
    Seq(joins.BroadcastHashJoinExec(
      leftKeys, rightKeys, joinType, BuildLeft, condition, planLater(left), planLater(right)))

  // --- ShuffledHashJoin ---------------------------------------------------------------------

  // NOTE: `&&` binds tighter than `||`, so the guard reads
  //   (!preferSortMergeJoin && canBuildRight && canBuildLocalHashMap && muchSmaller)
  //     || !RowOrdering.isOrderable(leftKeys)
  // i.e. non-orderable keys select this case regardless of the other conditions.
  case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
     if !conf.preferSortMergeJoin && canBuildRight(joinType) && canBuildLocalHashMap(right)
       && muchSmaller(right, left) ||
       !RowOrdering.isOrderable(leftKeys) =>
    Seq(joins.ShuffledHashJoinExec(
      leftKeys, rightKeys, joinType, BuildRight, condition, planLater(left), planLater(right)))

  // Same guard shape as above, mirrored for building the hash map on the left side.
  case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
     if !conf.preferSortMergeJoin && canBuildLeft(joinType) && canBuildLocalHashMap(left)
       && muchSmaller(left, right) ||
       !RowOrdering.isOrderable(leftKeys) =>
    Seq(joins.ShuffledHashJoinExec(
      leftKeys, rightKeys, joinType, BuildLeft, condition, planLater(left), planLater(right)))

  // --- SortMergeJoin ------------------------------------------------------------

  // Fallback for equi-joins whose left keys are orderable.
  case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
    if RowOrdering.isOrderable(leftKeys) =>
    joins.SortMergeJoinExec(
      leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil

  // ... (the remaining cases are elided in this excerpt)
}
其中conf.preferSortMergeJoin
org.apache.spark.sql.internal.SQLConf
/**
 * Internal boolean option `spark.sql.join.preferSortMergeJoin`.
 * Created with a default of `true`.
 */
val PREFER_SORTMERGEJOIN =
  SQLConfigBuilder("spark.sql.join.preferSortMergeJoin")
    .internal()
    .doc("When true, prefer sort merge join over shuffle hash join.")
    .booleanConf
    .createWithDefault(true)
配置 spark.sql.join.preferSortMergeJoin,默認爲 true,即是否優先使用 SortMergeJoin;
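可以看到 join 實現主要有 3 種,即 BroadcastHashJoinExec、ShuffledHashJoinExec 和 SortMergeJoinExec,優先級爲:BroadcastHashJoinExec > ShuffledHashJoinExec > SortMergeJoinExec(按上面 case 的匹配順序;其中 ShuffledHashJoinExec 只有在 preferSortMergeJoin 爲 false 且滿足相應條件,或 join key 不可排序時才會被選中);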
其中BroadcastHashJoinExec和ShuffledHashJoinExec都會用到HashJoin,先看HashJoin:
org.apache.spark.sql.execution.joins.HashJoin
/**
 * Joins the streamed rows against an already built hashed relation.
 *
 * The concrete algorithm is chosen by `joinType`; each joined row is then passed through the
 * result projection, and `numOutputRows` is incremented once per emitted row.
 *
 * Throws IllegalArgumentException for a join type that has no case below.
 */
protected def join(
    streamedIter: Iterator[InternalRow],
    hashed: HashedRelation,
    numOutputRows: SQLMetric): Iterator[InternalRow] = {

  val joinedRows = joinType match {
    case _: InnerLike =>
      innerJoin(streamedIter, hashed)
    case LeftOuter | RightOuter =>
      outerJoin(streamedIter, hashed)
    case LeftSemi =>
      semiJoin(streamedIter, hashed)
    case LeftAnti =>
      antiJoin(streamedIter, hashed)
    case j: ExistenceJoin =>
      existenceJoin(streamedIter, hashed)
    case x =>
      throw new IllegalArgumentException(
        s"BroadcastHashJoin should not take $x as the JoinType")
  }

  val project = createResultProjection
  joinedRows.map { row =>
    numOutputRows += 1
    project(row)
  }
}

/**
 * Inner join: for each streamed row, looks up the rows with the same join key in
 * `hashedRelation` and emits every combination that passes `boundCondition`.
 * A streamed row whose key has no entry (`get` returns null) produces no output.
 */
private def innerJoin(
    streamIter: Iterator[InternalRow],
    hashedRelation: HashedRelation): Iterator[InternalRow] = {
  // A single mutable JoinedRow is reused for all output rows.
  val joined = new JoinedRow
  val keyOf = streamSideKeyGenerator()
  streamIter.flatMap { streamedRow =>
    joined.withLeft(streamedRow)
    val candidates = hashedRelation.get(keyOf(streamedRow))
    if (candidates == null) {
      Seq.empty
    } else {
      candidates.map(joined.withRight(_)).filter(boundCondition)
    }
  }
}
這裏只貼出內關聯,即 innerJoin,代碼比較簡單,注意這裏是內存操作,會在單個 partition 內部進行;
org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
/**
 * Broadcast hash join: the build side is executed as a broadcast of a HashedRelation, and every
 * partition of the streamed side is joined against a read-only copy of that relation.
 */
protected override def doExecute(): RDD[InternalRow] = {
  val outputRowCount = longMetric("numOutputRows")

  val broadcastHandle = buildPlan.executeBroadcast[HashedRelation]()
  val streamedRdd = streamedPlan.execute()
  streamedRdd.mapPartitions { streamedRows =>
    val localRelation = broadcastHandle.value.asReadOnlyCopy()
    // Report the estimated size of the relation as peak execution memory of this task.
    TaskContext.get().taskMetrics().incPeakExecutionMemory(localRelation.estimatedSize)
    join(streamedRows, localRelation, outputRowCount)
  }
}
這裏會將 buildPlan 廣播出去,然後在 streamedPlan 上通過 mapPartitions 在 1 個分區內部進行 join,join 方法見 HashJoin;
org.apache.spark.sql.execution.joins.ShuffledHashJoinExec
/**
 * Shuffled hash join: zips the partitions of the streamed and build sides, builds a hashed
 * relation from the build-side rows of each partition pair and joins the streamed rows
 * against it.
 */
protected override def doExecute(): RDD[InternalRow] = {
  val outputRowCount = longMetric("numOutputRows")

  val streamedRdd = streamedPlan.execute()
  val buildRdd = buildPlan.execute()
  streamedRdd.zipPartitions(buildRdd) { (streamRows, buildRows) =>
    val relation = buildHashedRelation(buildRows)
    join(streamRows, relation, outputRowCount)
  }
}
join 過程爲先將兩個 rdd(streamedPlan 和 buildPlan)進行 zipPartitions,然後在 1 個 partition 內部 join,join 方法見 HashJoin;
org.apache.spark.sql.execution.joins.SortMergeJoinExec
/**
 * Sort-merge join: zips the partitions of the left and right children and, for each partition
 * pair, returns a RowIterator implementation chosen by `joinType`.
 * Only the inner-join case is shown in this excerpt.
 */
protected override def doExecute(): RDD[InternalRow] = {
  val numOutputRows = longMetric("numOutputRows")

  left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
    // Extra (non-equi) join condition; accepts every row when no condition is defined.
    val boundCondition: (InternalRow) => Boolean = {
      condition.map { cond =>
        newPredicate(cond, left.output ++ right.output).eval _
      }.getOrElse {
        (r: InternalRow) => true
      }
    }

    // An ordering that can be used to compare keys from both sides.
    val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType))
    val resultProj: InternalRow => InternalRow = UnsafeProjection.create(output, output)

    joinType match {
      case _: InnerLike =>
        new RowIterator {
          private[this] var currentLeftRow: InternalRow = _
          private[this] var currentRightMatches: ArrayBuffer[InternalRow] = _
          // Index of the next right-side match to emit; -1 means there is nothing to emit.
          private[this] var currentMatchIdx: Int = -1
          private[this] val smjScanner = new SortMergeJoinScanner(
            createLeftKeyGenerator(),
            createRightKeyGenerator(),
            keyOrdering,
            RowIterator.fromScala(leftIter),
            RowIterator.fromScala(rightIter)
          )
          private[this] val joinRow = new JoinedRow

          // Prime the iterator with the first group of matching rows, if there is one.
          if (smjScanner.findNextInnerJoinRows()) {
            currentRightMatches = smjScanner.getBufferedMatches
            currentLeftRow = smjScanner.getStreamedRow
            currentMatchIdx = 0
          }

          override def advanceNext(): Boolean = {
            while (currentMatchIdx >= 0) {
              if (currentMatchIdx == currentRightMatches.length) {
                // All matches of the current left row are consumed: ask the scanner for the
                // next left row that has matches, or finish when there is none.
                if (smjScanner.findNextInnerJoinRows()) {
                  currentRightMatches = smjScanner.getBufferedMatches
                  currentLeftRow = smjScanner.getStreamedRow
                  currentMatchIdx = 0
                } else {
                  currentRightMatches = null
                  currentLeftRow = null
                  currentMatchIdx = -1
                  return false
                }
              }
              joinRow(currentLeftRow, currentRightMatches(currentMatchIdx))
              currentMatchIdx += 1
              // Only rows passing the extra join condition are emitted and counted.
              if (boundCondition(joinRow)) {
                numOutputRows += 1
                return true
              }
            }
            false
          }

          override def getRow: InternalRow = resultProj(joinRow)
        }.toScala

      // ... (the cases for the other join types are elided in this excerpt)
    }
  }
}
和 ShuffledHashJoinExec 一樣,同樣先 zipPartitions,然後在 1 個 partition 內部根據 joinType 返回不同的 RowIterator 實現類,上邊代碼包含內關聯實現,大部分工作通過 SortMergeJoinScanner 實現:
org.apache.spark.sql.execution.joins.SortMergeJoinScanner
/**
 * Advances the streamed side to its next row with a null-free join key and tries to find the
 * buffered rows matching that key.
 *
 * When this returns true, `bufferedMatches` holds the buffered rows whose key equals the key of
 * the current streamed row; when it returns false the match state has been cleared.
 *
 * @return true if the current streamed row has matching buffered rows, false otherwise.
 */
final def findNextInnerJoinRows(): Boolean = {
  while (advancedStreamed() && streamedRowKey.anyNull) {
    // Advance the streamed side of the join until we find the next row whose join key contains
    // no nulls or we hit the end of the streamed iterator.
  }
  if (streamedRow == null) {
    // We have consumed the entire streamed iterator, so there can be no more matches.
    matchJoinKey = null
    bufferedMatches.clear()
    false
  } else if (matchJoinKey != null && keyOrdering.compare(streamedRowKey, matchJoinKey) == 0) {
    // The new streamed row has the same join key as the previous row, so return the same matches.
    true
  } else if (bufferedRow == null) {
    // The streamed row's join key does not match the current batch of buffered rows and there are
    // no more rows to read from the buffered iterator, so there can be no more matches.
    matchJoinKey = null
    bufferedMatches.clear()
    false
  } else {
    // Advance both the streamed and buffered iterators to find the next pair of matching rows.
    var comp = keyOrdering.compare(streamedRowKey, bufferedRowKey)
    do {
      if (streamedRowKey.anyNull) {
        advancedStreamed()
      } else {
        assert(!bufferedRowKey.anyNull)
        comp = keyOrdering.compare(streamedRowKey, bufferedRowKey)
        // Move whichever side currently has the smaller key.
        if (comp > 0) advancedBufferedToRowWithNullFreeJoinKey()
        else if (comp < 0) advancedStreamed()
      }
    } while (streamedRow != null && bufferedRow != null && comp != 0)
    if (streamedRow == null || bufferedRow == null) {
      // We have either hit the end of one of the iterators, so there can be no more matches.
      matchJoinKey = null
      bufferedMatches.clear()
      false
    } else {
      // The streamed row's join key matches the current buffered row's join, so walk through the
      // buffered iterator to buffer the rest of the matching rows.
      assert(comp == 0)
      bufferMatchingRows()
      true
    }
  }
}

/**
 * Advance the streamed iterator and compute the new row's join key.
 * @return true if the streamed iterator returned a row and false otherwise.
 */
private def advancedStreamed(): Boolean = {
  if (streamedIter.advanceNext()) {
    streamedRow = streamedIter.getRow
    streamedRowKey = streamedKeyGenerator(streamedRow)
    true
  } else {
    streamedRow = null
    streamedRowKey = null
    false
  }
}

/**
 * Advance the buffered iterator until we find a row with join key that does not contain nulls.
 * @return true if the buffered iterator returned a row and false otherwise.
 */
private def advancedBufferedToRowWithNullFreeJoinKey(): Boolean = {
  var foundRow: Boolean = false
  while (!foundRow && bufferedIter.advanceNext()) {
    bufferedRow = bufferedIter.getRow
    bufferedRowKey = bufferedKeyGenerator(bufferedRow)
    foundRow = !bufferedRowKey.anyNull
  }
  if (!foundRow) {
    bufferedRow = null
    bufferedRowKey = null
    false
  } else {
    true
  }
}

/**
 * Called when the streamed and buffered join keys match in order to buffer the matching rows.
 */
private def bufferMatchingRows(): Unit = {
  assert(streamedRowKey != null)
  assert(!streamedRowKey.anyNull)
  assert(bufferedRowKey != null)
  assert(!bufferedRowKey.anyNull)
  assert(keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0)
  // This join key may have been produced by a mutable projection, so we need to make a copy:
  matchJoinKey = streamedRowKey.copy()
  bufferedMatches.clear()
  do {
    bufferedMatches += bufferedRow.copy() // need to copy mutable rows before buffering them
    advancedBufferedToRowWithNullFreeJoinKey()
  } while (bufferedRow != null && keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0)
}
可以看到整個過程和二路歸併排序(Binary Merge Sort)中的歸併步驟差不多;
附:RowIterator 是一個抽象類,本質是一個接口,是一個常見的 Iterator 定義,如下:
org.apache.spark.sql.execution.RowIterator
/**
 * Minimal row-at-a-time iterator contract: callers move forward with [[advanceNext()]] and read
 * the current row with [[getRow]].
 */
abstract class RowIterator {

  /**
   * Moves this iterator forward by one row.
   *
   * @return `true` if a new row is available (fetch it with [[getRow]]), `false` if this
   *         iterator has no more rows.
   */
  def advanceNext(): Boolean

  /**
   * Returns the current row. Idempotent: repeated calls return the same row until the iterator
   * is advanced. Must not be called once [[advanceNext()]] has returned `false`.
   */
  def getRow: InternalRow

  /**
   * Wraps this RowIterator as a [[scala.collection.Iterator]].
   */
  def toScala: Iterator[InternalRow] = {
    new RowIteratorToScala(this)
  }
}