Spark partitionBy

partitionBy 用於重新分區，repartition 默認採用 HashPartitioner 分區；也可以自己設計合理的分區方法（比如給數量比較大的 key 加個隨機數，隨機分到更多的分區，這樣處理數據傾斜更徹底一些）。

/**
 * An object that defines how the elements in a key-value pair RDD are partitioned by key.
 * Maps each key to a partition ID, from 0 to `numPartitions - 1`.
 */
abstract class Partitioner extends Serializable {

  /** Total number of partitions produced by this partitioner. */
  def numPartitions: Int

  /**
   * Partition ID for `key`.
   *
   * @param key the record key (may be any value, including null, depending on the RDD)
   * @return an id that implementations are expected to keep in `[0, numPartitions)`
   */
  def getPartition(key: Any): Int
}
import org.apache.spark.HashPartitioner
import org.apache.spark.sql.SparkSession
/**
 * Demo: prints how the elements of a small pair RDD are laid out across partitions, first as
 * created by `parallelize(..., 2)` and then after `partitionBy(new HashPartitioner(2))`.
 *
 * Every non-empty partition is printed as `part_<index>:List(<elements>)`.
 */
object PartitionBy_Test {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local")
      .appName(this.getClass.getSimpleName)
      .getOrCreate()
    try {
      val rdd = spark.sparkContext.parallelize(
        Array(("a", 1), ("a", 2), ("b", 1), ("b", 3), ("c", 1), ("e", 1)), 2)

      // Re-distribute the records by the hash of their key into 2 partitions.
      val hashed = rdd.partitionBy(new HashPartitioner(2))

      Seq(
        "parallelize(..., 2)" -> rdd,
        "partitionBy(new HashPartitioner(2))" -> hashed
      ).foreach { case (label, data) =>
        println(s"== $label ==")
        val result = data.mapPartitionsWithIndex { (partIdx, iter) =>
          // toList drains the partition iterator; empty partitions produce no output line.
          if (iter.hasNext) Iterator.single(("part_" + partIdx, iter.toList))
          else Iterator.empty
        }.collect()
        result.foreach(x => println(x._1 + ":" + x._2.toString()))
      }
    } finally {
      // Release the local Spark context even if the job fails.
      spark.stop()
    }
  }
}

這裏的分區方法可以選擇，默認的分區就是 HashPartitioner 分區。
注意：如果多次使用該 RDD 或者要進行 join 操作，分區後應進行 persist 持久化操作，否則每次使用都會重新計算分區（重新 shuffle）。

/**
 * A [[org.apache.spark.Partitioner]] that implements hash-based partitioning using
 * Java's `Object.hashCode`.
 *
 * Java arrays have hashCodes that are based on the arrays' identities rather than their contents,
 * so attempting to partition an RDD[Array[_]] or RDD[(Array[_], _)] using a HashPartitioner will
 * produce an unexpected or incorrect result.
 *
 * @param partitions number of partitions; must not be negative (0 is accepted)
 */
class HashPartitioner(partitions: Int) extends Partitioner {
  require(partitions >= 0, s"Number of partitions ($partitions) cannot be negative.")

  def numPartitions: Int = partitions

  // null keys always go to partition 0. Other keys are mapped via Utils.nonNegativeMod on their
  // hashCode — presumably a modulo folded into [0, numPartitions); Utils is not shown here.
  // NOTE(review): with partitions == 0, a non-null key would take a modulo by zero — confirm
  // callers never ask a 0-partition HashPartitioner for a partition id.
  def getPartition(key: Any): Int = key match {
    case null => 0
    case _ => Utils.nonNegativeMod(key.hashCode, numPartitions)
  }

  // Equality and hashCode depend only on the partition count.
  override def equals(other: Any): Boolean = other match {
    case h: HashPartitioner =>
      h.numPartitions == numPartitions
    case _ =>
      false
  }

  override def hashCode: Int = numPartitions
}

範圍分區 RangePartitioner：先確定樣本大小，對鍵進行不放回的隨機採樣，再根據排序後的樣本確定各分區的鍵值範圍邊界並據此分配分區；通過樣本採樣使各分區數據量大致均衡，避免數據傾斜。

/**
 * A Partitioner that assigns keys to partitions by key range.
 *
 * The upper bounds of the first `numPartitions - 1` ranges (`rangeBounds`) are estimated from a
 * weighted random sample of the keys of `rdd`. Because the bound count is capped by the number of
 * sampled candidates, `numPartitions` can be smaller than the requested `partitions`.
 *
 * @param partitions                   requested number of partitions (0 is allowed, e.g. when
 *                                     sorting an empty RDD)
 * @param rdd                          key-value RDD whose keys are sampled to find the bounds
 * @param ascending                    if false, partition ids are assigned in reverse key order
 * @param samplePointsPerPartitionHint target number of sample points per output partition
 */
class RangePartitioner[K : Ordering : ClassTag, V](
    partitions: Int,
    rdd: RDD[_ <: Product2[K, V]],
    private var ascending: Boolean = true,
    val samplePointsPerPartitionHint: Int = 20)
  extends Partitioner {

  // A constructor declared in order to maintain backward compatibility for Java, when we add the
  // 4th constructor parameter samplePointsPerPartitionHint. See SPARK-22160.
  // This is added to make sure from a bytecode point of view, there is still a 3-arg ctor.
  def this(partitions: Int, rdd: RDD[_ <: Product2[K, V]], ascending: Boolean) = {
    this(partitions, rdd, ascending, samplePointsPerPartitionHint = 20)
  }

  // We allow partitions = 0, which happens when sorting an empty RDD under the default settings.
  require(partitions >= 0, s"Number of partitions cannot be negative but found $partitions.")
  require(samplePointsPerPartitionHint > 0,
    s"Sample points per partition must be greater than 0 but found $samplePointsPerPartitionHint")

  private var ordering = implicitly[Ordering[K]]

  // An array of upper bounds for the first (partitions - 1) partitions
  private var rangeBounds: Array[K] = {
    if (partitions <= 1) {
      Array.empty
    } else {
      // This is the sample size we need to have roughly balanced output partitions, capped at 1M.
      // Cast to double to avoid overflowing ints or longs
      val sampleSize = math.min(samplePointsPerPartitionHint.toDouble * partitions, 1e6)
      // Assume the input partitions are roughly balanced and over-sample a little bit.
      val sampleSizePerPartition = math.ceil(3.0 * sampleSize / rdd.partitions.length).toInt
      val (numItems, sketched) = RangePartitioner.sketch(rdd.map(_._1), sampleSizePerPartition)
      if (numItems == 0L) {
        Array.empty
      } else {
        // If a partition contains much more than the average number of items, we re-sample from it
        // to ensure that enough items are collected from that partition.
        val fraction = math.min(sampleSize / math.max(numItems, 1L), 1.0)
        val candidates = ArrayBuffer.empty[(K, Float)]
        val imbalancedPartitions = mutable.Set.empty[Int]
        sketched.foreach { case (idx, n, sample) =>
          if (fraction * n > sampleSizePerPartition) {
            imbalancedPartitions += idx
          } else {
            // The weight is 1 over the sampling probability.
            val weight = (n.toDouble / sample.length).toFloat
            for (key <- sample) {
              candidates += ((key, weight))
            }
          }
        }
        if (imbalancedPartitions.nonEmpty) {
          // Re-sample imbalanced partitions with the desired sampling probability.
          val imbalanced = new PartitionPruningRDD(rdd.map(_._1), imbalancedPartitions.contains)
          // Seed derived from the RDD id so re-sampling the same RDD is reproducible.
          val seed = byteswap32(-rdd.id - 1)
          val reSampled = imbalanced.sample(withReplacement = false, fraction, seed).collect()
          val weight = (1.0 / fraction).toFloat
          candidates ++= reSampled.map(x => (x, weight))
        }
        RangePartitioner.determineBounds(candidates, math.min(partitions, candidates.size))
      }
    }
  }

  def numPartitions: Int = rangeBounds.length + 1

  private var binarySearch: ((Array[K], K) => Int) = CollectionsUtils.makeBinarySearch[K]

  /**
   * Index of the first range whose upper bound is >= `key`, or `rangeBounds.length` if `key` is
   * above every bound; mirrored when `ascending` is false.
   */
  def getPartition(key: Any): Int = {
    val k = key.asInstanceOf[K]
    var partition = 0
    if (rangeBounds.length <= 128) {
      // If we have less than 128 partitions naive search
      while (partition < rangeBounds.length && ordering.gt(k, rangeBounds(partition))) {
        partition += 1
      }
    } else {
      // Determine which binary search method to use only once.
      partition = binarySearch(rangeBounds, k)
      // binarySearch either returns the match location or -[insertion point]-1
      if (partition < 0) {
        partition = -partition - 1
      }
      if (partition > rangeBounds.length) {
        partition = rangeBounds.length
      }
    }
    if (ascending) {
      partition
    } else {
      rangeBounds.length - partition
    }
  }

  // Two RangePartitioners are equal when they have the same bounds and the same direction.
  override def equals(other: Any): Boolean = other match {
    case r: RangePartitioner[_, _] =>
      r.rangeBounds.sameElements(rangeBounds) && r.ascending == ascending
    case _ =>
      false
  }

  override def hashCode(): Int = {
    val prime = 31
    var result = 1
    var i = 0
    while (i < rangeBounds.length) {
      result = prime * result + rangeBounds(i).hashCode
      i += 1
    }
    result = prime * result + ascending.hashCode
    result
  }

  // With the Java serializer the default field serialization is used; with any other serializer
  // the fields are written manually and rangeBounds goes through that serializer's nested stream.
  @throws(classOf[IOException])
  private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
    val sfactory = SparkEnv.get.serializer
    sfactory match {
      case js: JavaSerializer => out.defaultWriteObject()
      case _ =>
        out.writeBoolean(ascending)
        out.writeObject(ordering)
        out.writeObject(binarySearch)

        val ser = sfactory.newInstance()
        Utils.serializeViaNestedStream(out, ser) { stream =>
          stream.writeObject(scala.reflect.classTag[Array[K]])
          stream.writeObject(rangeBounds)
        }
    }
  }

  // Mirror of writeObject: must read the fields in exactly the order they were written.
  @throws(classOf[IOException])
  private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
    val sfactory = SparkEnv.get.serializer
    sfactory match {
      case js: JavaSerializer => in.defaultReadObject()
      case _ =>
        ascending = in.readBoolean()
        ordering = in.readObject().asInstanceOf[Ordering[K]]
        binarySearch = in.readObject().asInstanceOf[(Array[K], K) => Int]

        val ser = sfactory.newInstance()
        Utils.deserializeViaNestedStream(in, ser) { ds =>
          implicit val classTag = ds.readObject[ClassTag[Array[K]]]()
          rangeBounds = ds.readObject[Array[K]]()
        }
    }
  }
}

自定義分區函數：自己根據業務數據緩解數據傾斜問題。
要實現自定義的分區器，你需要繼承 org.apache.spark.Partitioner 類並實現下面三個方法：

  • numPartitions: Int：返回創建出來的分區數。
  • getPartition(key: Any): Int：返回給定鍵的分區編號（0 到 numPartitions-1）。
  • equals()：Java 判斷相等性的標準方法。Spark 用它來檢查兩個分區器對象是否相同，從而判斷兩個 RDD 的分區方式是否一致（重寫 equals 時也應同時重寫 hashCode）。
/**
 * Example custom partitioner (the name suggests keys are user ids — confirm with callers).
 *
 * The hot key "A" is folded into up to 10 buckets and every other key into up to 5 buckets.
 * The bucket is then folded into `[0, numParts)`, so the result is always a valid partition id.
 * Numeric keys use their parsed Int value; non-numeric keys (including "A") use `hashCode`.
 *
 * @param numParts number of partitions; must be positive
 */
class UsridPartitioner(numParts: Int) extends Partitioner {
  require(numParts > 0, s"Number of partitions ($numParts) must be positive.")

  // Number of partitions this partitioner produces.
  override def numPartitions: Int = numParts

  // Partition id for `key`; always in [0, numPartitions).
  override def getPartition(key: Any): Int = {
    val text = key match {
      case null => ""
      case k => k.toString
    }
    // Non-numeric keys such as "A" cannot be parsed with toInt; fall back to hashCode instead of
    // throwing NumberFormatException.
    val code =
      try text.toInt
      catch { case _: NumberFormatException => text.hashCode }
    val buckets = if (text == "A") 10 else 5
    // Fold twice so the id never exceeds numPartitions - 1, even when numParts < buckets.
    nonNegativeMod(nonNegativeMod(code, buckets), numParts)
  }

  // Partitioners with the same partition count assign keys identically, so Spark can treat
  // RDDs partitioned by them as co-partitioned.
  override def equals(other: Any): Boolean = other match {
    case p: UsridPartitioner => p.numPartitions == numPartitions
    case _ => false
  }

  override def hashCode: Int = numPartitions

  // x mod m mapped into [0, m); unlike `%`, never negative for negative x.
  private def nonNegativeMod(x: Int, m: Int): Int = {
    val r = x % m
    if (r < 0) r + m else r
  }
}