Near-Real-Time Classification with Spark MLlib and Spark Streaming


Environment: · Spark 2.0 · Scala 2.11.8

Searching online turns up almost no examples that combine Spark MLlib with Spark Streaming, which puzzles me: is there a more reasonable way to do near-real-time prediction? If so, please point it out in the comments. The idea in this post is simple: train and save a model with Spark MLlib, then write a Spark Streaming program that loads and uses that model. Note that before touching Spark MLlib I used Python to explore and analyse the data, clean it, do feature engineering, build the dataset, and train a first model; this post uses the dataset built in Python directly.

1 Train and Save the Model

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.tree.model.RandomForestModel
import org.apache.spark.rdd.RDD


object RandomForestM {
  def main(args: Array[String]) {
    val sparkConf = new SparkConf()
      // Run locally for development; change the master when submitting to a cluster.
      .setMaster("local[*]")
      .setAppName("rf")
    val sc = new SparkContext(sparkConf)
    
    // Load the dataset built earlier in Python (CSV, one record per line).
    val rawData = sc.textFile("hdfs://xx:8020/model/data/xx.csv")
    val data = rawData.map { line =>
      val values = line.split(",").map(_.toDouble)
      // All columns except the last are features; the last column is the label.
      val feature = Vectors.dense(values.init)
      val label = values.last
      LabeledPoint(label, feature)
    }

    // Split into 80% training, 10% cross-validation and 10% test sets.
    val Array(trainData, cvData, testData) = data.randomSplit(Array(0.8, 0.1, 0.1))
    trainData.cache()
    cvData.cache()
    testData.cache()

    
    // Train a random forest: numClasses = 2, no categorical features, 20 trees,
    // featureSubsetStrategy = "auto", gini impurity, maxDepth = 4, maxBins = 32.
    val model = RandomForest.trainClassifier(trainData, 2, Map[Int, Int](), 20, "auto", "gini", 4, 32)
    val metrics = getMetrics(model, cvData)

    // Confusion matrix and overall accuracy on the cross-validation set.
    println(metrics.confusionMatrix)
    println(metrics.accuracy)

    // Precision and recall for each of the two classes.
    (0 until 2).map(target => (metrics.precision(target), metrics.recall(target))).foreach(println)

    // Persist the trained model to HDFS so the streaming job can load it later.
    model.save(sc, "hdfs://xx:8020/model/xxModel")

  }

  
  // Evaluate the model on a labelled dataset and return multiclass metrics.
  def getMetrics(model: RandomForestModel, data: RDD[LabeledPoint]): MulticlassMetrics = {
    val predictionsAndLabels = data.map { d =>
      (model.predict(d.features), d.label)
    }
    new MulticlassMetrics(predictionsAndLabels)
  }

  
  // Simple grid search over impurity, depth and bins, scored by accuracy on the
  // cross-validation set, to help pick hyperparameters.
  def getBestParam(trainData: RDD[LabeledPoint], cvData: RDD[LabeledPoint]): Unit = {
    val evaluations = for (impurity <- Array("gini", "entropy");
                           depth <- Array(1, 20);
                           bins <- Array(10, 300)) yield {
      val model = RandomForest.trainClassifier(trainData, 2, Map[Int, Int](), 20, "auto", impurity, depth, bins)
      val metrics = getMetrics(model, cvData)
      ((impurity, depth, bins), metrics.accuracy)
    }
    evaluations.sortBy(_._2).reverse.foreach(println)
  }

}
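Note that testData is cached but never used above, and getBestParam is defined but never called (it could be run on trainData and cvData before the final training). A minimal sketch, my addition rather than part of the original code, of reloading the saved model and sanity-checking it on the held-out test split before wiring up the streaming job:

    // Assumption: placed inside the same main(), after model.save(...).
    // Reload the persisted model and evaluate it on the held-out test set.
    val savedModel = RandomForestModel.load(sc, "hdfs://xx:8020/model/xxModel")
    val testMetrics = getMetrics(savedModel, testData)
    println(testMetrics.confusionMatrix)
    println(testMetrics.accuracy)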

2 Load and Use the Model

import org.apache.kafka.clients.consumer.ConsumerRecord
import org.apache.kafka.common.serialization.StringDeserializer
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.tree.model.RandomForestModel
import org.apache.spark.{SparkConf, SparkContext, TaskContext}
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.dstream.InputDStream
import org.apache.spark.streaming.kafka010.ConsumerStrategies.Subscribe
import org.apache.spark.streaming.kafka010.LocationStrategies.PreferConsistent
import org.apache.spark.streaming.kafka010.{HasOffsetRanges, KafkaUtils, OffsetRange}


object ModelTest {
  private val brokers = "xx1:6667,xx2:6667,xx3:6667"

  def main(args: Array[String]): Unit = {

    val sparkConf = new SparkConf()
      // Run locally for development; change the master when submitting to a cluster.
      .setMaster("local[*]")
      .setAppName("ModelTest")
    sparkConf.set("spark.sql.warehouse.dir", "file:///")
    val sc = new SparkContext(sparkConf)
    // Load the random forest model saved by the training job.
    val rfModel = RandomForestModel.load(sc, "hdfs://xx:8020/model/xxModel")
    
    // 6-second micro-batches.
    val ssc = new StreamingContext(sc, Seconds(6))
    val topics = Array("xx1", "xx2")
    val kafkaParams = Map[String, Object](
      "bootstrap.servers" -> brokers,
      "key.deserializer" -> classOf[StringDeserializer],
      "value.deserializer" -> classOf[StringDeserializer],
      "group.id" -> "hqc",
      "auto.offset.reset" -> "latest",
      "enable.auto.commit" -> (false: java.lang.Boolean)
    )
    val messages: InputDStream[ConsumerRecord[String, String]] = KafkaUtils.createDirectStream[String, String](
      ssc,
      PreferConsistent,
      Subscribe[String, String](topics, kafkaParams)
    )

    
    messages.foreachRDD(rdd => {
      // Offset ranges for this batch; used here to tell which topic a record came from.
      val offsetRanges: Array[OffsetRange] = rdd.asInstanceOf[HasOffsetRanges].offsetRanges

      rdd.foreach((msg: ConsumerRecord[String, String]) => {
        val o = offsetRanges(TaskContext.get.partitionId)
        val topic: String = o.topic

        // Route each record by topic.
        topic match {
          case "xx1" =>
            // Clean the raw message; an empty string means the record is unusable.
            val line = KxxDataClean(msg.value)
            if (line != "") {
              val values = line.split(",").map(_.toDouble)
              val feature = Vectors.dense(values)
              // Score the feature vector with the pre-trained random forest.
              val preLabel = rfModel.predict(feature)
              println(preLabel)
            }

          case "xx2" =>
            // Handling for the second topic is omitted here.

        }
      })
    })

    ssc.start()
    ssc.awaitTermination()
  }
}
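Two things the code above leaves open. First, with "enable.auto.commit" set to false, the offsets obtained via HasOffsetRanges are never committed back to Kafka. A minimal sketch, assuming the CanCommitOffsets interface from spark-streaming-kafka-0-10, of committing them after each batch is processed (it would extend the foreachRDD above):

    // Requires an extra import at the top of the file:
    // import org.apache.spark.streaming.kafka010.CanCommitOffsets
    messages.foreachRDD { rdd =>
      val offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges
      // ... process the batch as shown above ...
      // Commit this batch's offsets asynchronously once processing is done.
      messages.asInstanceOf[CanCommitOffsets].commitAsync(offsetRanges)
    }

Second, KxxDataClean is not shown in the original post. Judging from how it is used, it takes a raw Kafka message value and returns a comma-separated string of numeric features, or an empty string when the record cannot be scored. A purely hypothetical placeholder for illustration:

    // Hypothetical stand-in for the author's cleaning / feature-extraction logic.
    def KxxDataClean(raw: String): String = {
      val fields = raw.trim.split(",")
      // Keep the record only if every field parses as a number; otherwise drop it.
      if (fields.nonEmpty && fields.forall(f => scala.util.Try(f.toDouble).isSuccess)) fields.mkString(",")
      else ""
    }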

pom.xml

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>com.xx</groupId>
    <artifactId>spark-xx-model</artifactId>
    <version>1.0-SNAPSHOT</version>
    <properties>
        <spark.version>2.0.0</spark.version>
    </properties>

    <dependencies>

        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-mllib_2.11</artifactId>
            <version>${spark.version}</version>
            <scope>compile</scope>
        </dependency>

        <!--spark streaming + kafka-->
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-streaming_2.11</artifactId>
            <version>${spark.version}</version>
            <scope>compile</scope>
        </dependency>

        <dependency>
            <groupId>org.apache.kafka</groupId>
            <artifactId>kafka_2.11</artifactId>
            <version>0.10.0.0</version>
        </dependency>

        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-streaming-kafka-0-10_2.11</artifactId>
            <version>${spark.version}</version>
        </dependency>

        <!--mysql-->
        <dependency>
            <groupId>mysql</groupId>
            <artifactId>mysql-connector-java</artifactId>
            <version>5.1.39</version>
        </dependency>

        <!--日誌-->
        <dependency>
            <groupId>com.typesafe.scala-logging</groupId>
            <artifactId>scala-logging_2.11</artifactId>
            <version>3.7.2</version>
        </dependency>
        <dependency>
            <groupId>ch.qos.logback</groupId>
            <artifactId>logback-classic</artifactId>
            <version>1.2.3</version>
        </dependency>

    </dependencies>

    <build>
        <plugins>
            <plugin>
                <artifactId>maven-assembly-plugin</artifactId>
                <version>2.3</version>
                <configuration>
                    <classifier>dist</classifier>
                    <appendAssemblyId>true</appendAssemblyId>
                    <descriptorRefs>
                        <descriptorRef>jar-with-dependencies</descriptorRef>
                    </descriptorRefs>
                </configuration>
                <executions>
                    <execution>
                        <id>make-assembly</id>
                        <phase>package</phase>
                        <goals>
                            <goal>single</goal>
                        </goals>
                    </execution>
                </executions>
            </plugin>
        </plugins>
    </build>


</project>