Spark 實現本身的RDD,讓代碼更優雅

你是否在最初書寫spark的代碼時老是使用object 是否在爲代碼的重複而憂心,接下來的博客中,我會專一於spark代碼簡潔性。java

1,什麼事RDD,官網上有很全面的解釋,在此再也不贅述,不過咱們須要從代碼層面上理解什麼事RDD,若是他是一個類,他又有哪些重要的屬性和方法,如今列出如下幾點:mysql

    1)partitions():Get the array of partitions of this RDD, taking into account whether the
RDD is checkpointed or not. Partition是一個特質,分佈在每個excutor上的分區,都會有一個Partition實現類去作惟一標識。sql

    2)iterator():Internal method to this RDD; will read from cache if applicable, or otherwise compute it. This should not be called by users directly, but is available for implementors of custom subclasses of RDD. 這是一個RDD的迭代器,傳入的參數是Partition和TaskContext,這樣就能夠在每個Partition上執行相應的邏輯了。數據庫

    3)dependencies():Get the list of dependencies of this RDD,在1.6中,Dependency共有以下幾個繼承類,後續博文會詳解它,感興趣的讀者能夠直接閱讀源碼進一步瞭解apache

            

    4)partitioner():此函數返回一個Option[Partitioner],若是RDD不是key-value pair RDD類型的數據,那麼爲None,咱們和以本身實現這個抽象類。當時看到這裏,我就在想爲何不能實現一個特質,而要用app

抽象類,我的理解這是屬於面向對象的東西了,類是實體的抽象愛,而接口則定義一些行爲。分佈式

    5)preferredLocations():Optionally overridden by subclasses to specify placement preferences.ide

 

下面咱們本身實現一個和Mysql交互的RDD,只涉及到上面說的部分函數,固然在生產環境中不建議這樣作,除非你本身想把本身的mysql搞掛,此處只是演示,對於像Hbase之類的分佈式數據庫,邏輯相似。函數

package com.hypers.rdd

import java.sql.{Connection, ResultSet}

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.{Logging, Partition, SparkContext, TaskContext}

import scala.reflect.ClassTag

//TODO 去重
class HFAJdbcRDD[T: ClassTag]
    (sc: SparkContext,
     connection: () => Connection, //method
     sql: String,
     numPartittions: Int,
     mapRow: (ResultSet) => T
) extends RDD[T](sc, Nil) with Logging {

    /**
      * 如果這個Rdd是有父RDD 那麼 compute通常會調用到iterator方法 將taskContext傳遞出去
      * @param thePart
      * @param context
      * @return
      */
    @DeveloperApi
    override def compute(thePart: Partition, context: TaskContext): Iterator[T] = new Iterator[T] {

        val part = thePart.asInstanceOf[HFAJdbcPartition]
        val conn = connection()
        //若是直接執行sql會使數據重複,所以此處使用分頁
        val stmt = conn.prepareStatement(String.format("%s limit %s,1",sql,thePart.index.toString), ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
        logInfo("Get sql data size is " + stmt.getFetchSize)
        val rs: ResultSet = stmt.executeQuery()

        override def hasNext: Boolean = {
            if(rs.next()){
                true
            }else{
                conn.close()
                false
            }
        }

        override def next(): T = {
            mapRow(rs)
        }
    }


    /**
      * 將一些信息傳遞到compute方法 例如sql limit 的參數
      * @return
      */
    override protected def getPartitions: Array[Partition] = {
        (0 until numPartittions).map { inx =>
            new HFAJdbcPartition(inx)
        }.toArray
    }
}

private class HFAJdbcPartition(inx: Int) extends Partition {
    override def index: Int = inx
}

 

package com.hypers.rdd.execute

import java.sql.{DriverManager, ResultSet}

import com.hypers.commons.spark.BaseJob
import com.hypers.rdd.HFAJdbcRDD

//BaseJob裏面作了sc的初始化,在此不作演示,您也能夠本身new出sparkContext
object HFAJdbcTest extends BaseJob {

    def main(args: Array[String]) {
        HFAJdbcTest(args)
    }

    override def apply(args: Array[String]): Unit = {

        val jdbcRdd = new HFAJdbcRDD[Tuple2[Int, String]](sc,
            getConnection,
            "select id,name from user where id<10",
            3,
            reseultHandler
        )

        logger.info("count is " + jdbcRdd.count())
        logger.info("count keys " + jdbcRdd.keys.collect().toList)

    }

    def getConnection() = {
        Class.forName("com.mysql.jdbc.Driver").newInstance()
        DriverManager.getConnection("jdbc:mysql://localhost:3306/db", "root", "123456")
    }

    def reseultHandler(rs: ResultSet): Tuple2[Int, String] = {
        rs.getInt("id") -> rs.getString("name")

    }
}
相關文章
相關標籤/搜索