使用基於Apache Spark的隨機森林方法預測貸款風險

 

轉載 2016年07月20日 14:19:56
http://blog.csdn.net/mr__fang/article/details/51967852

原文:Predicting Loan Credit Risk using Apache Spark Machine Learning Random Forests 
做者:Carol McDonald,MapR解決方案架構師 
翻譯:KK4SBB 
責編:周建丁(zhoujd@csdn.net)javascript

在本文中,我將向你們介紹如何使用Apache Spark的spark.ml庫中的隨機森林算法來對銀行信用貸款的風險作分類預測。Spark的spark.ml庫基於DataFrame,它提供了大量的接口,幫助用戶建立和調優機器學習工做流。結合dataframe使用spark.ml,可以實現模型的智能優化,從而提高模型效果。html

分類算法

分類算法是一類監督式機器學習算法,它根據已知標籤的樣本(如已經明確交易是否存在欺詐)來預測其它樣本所屬的類別(如是否屬於欺詐性的交易)。分類問題須要一個已經標記過的數據集和預先設計好的特徵,而後基於這些信息來學習給新樣本打標籤。所謂的特徵便是一些「是與否」的問題。標籤就是這些問題的答案。在下面這個例子裏,若是某個動物的行走姿態、游泳姿式和叫聲都像鴨子,那麼就給它打上「鴨子」的標籤。java

咱們來看一個銀行信貸的信用風險例子:git

  • 咱們須要預測什麼? 
    • 某我的是否會按時還款
    • 這就是標籤:此人的信用度
  • 你用來預測的「是與否」問題或者屬性是什麼? 
    • 申請人的基本信息和社會身份信息:職業,年齡,存款儲蓄,婚姻狀態等等……
    • 這些就是特徵,用來構建一個分類模型,你從中提取出對分類有幫助的特徵信息。

決策樹模型

決策樹是一種基於輸入特徵來預測類別或是標籤的分類模型。決策樹的工做原理是這樣的,它在每一個節點都須要計算特徵在該節點的表達式值,而後基於運算結果選擇一個分支通往下一個節點。下圖展現了一種用來預測信用風險的決策樹模型。每一個決策問題就是模型的一個節點,「是」或者「否」的答案是通往子節點的分支。github

  • 問題1:帳戶餘額是否大於200元? 
    • 問題2:當前就任時間是否超過1年? 
      • 不可信賴

圖片描述

隨機森林模型

融合學習算法結合了多個機器學習的算法,從而獲得了效果更好的模型。隨機森林是分類和迴歸問題中一類經常使用的融合學習方法。此算法基於訓練數據的不一樣子集構建多棵決策樹,組合成一個新的模型。預測結果是全部決策樹輸出的組合,這樣可以減小波動,而且提升預測的準確度。對於隨機森林分類模型,每棵樹的預測結果都視爲一張投票。得到投票數最多的類別就是預測的類別。算法

圖片描述

基於Spark機器學習工具來分析信用風險問題

咱們使用德國人信用度數據集,它按照一系列特徵屬性將人分爲信用風險好和壞兩類。咱們能夠得到每一個銀行貸款申請者的如下信息:sql

圖片描述

存放德國人信用數據的csv文件格式以下:shell

1,1,18,4,2,1049,1,2,4,2,1,4,2,21,3,1,1,3,1,1,1
1,1,9,4,0,2799,1,3,2,3,1,2,1,36,3,1,2,3,2,1,1
1,2,12,2,9,841,2,4,2,2,1,4,1,23,3,1,1,2,1,1,1

在這個背景下,咱們會構建一個由決策樹組成的隨機森林模型來預測是否守信用的標籤/類別,基於如下特徵:apache

  • 標籤 -> 守信用或者不守信用(1或者0)
  • 特徵 -> {存款餘額,信用歷史,貸款目的等等}

軟件

本教程將使用Spark 1.6.1數組

按照教程指示,登陸MapR沙箱,用戶名爲user01,密碼爲mapr。將樣本數據文件複製到你的沙箱主目錄下/user/user01 using scp。(注意,你可能須要先更新Spark的版本)打開spark shell:

$spark-shell --master local[1]

加載並解析csv數據文件

首先,咱們須要引入機器學習相關的包。

import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.ml.feature.VectorAssembler
import sqlContext.implicits._
import sqlContext._
import org.apache.spark.ml.tuning.{ ParamGridBuilder, CrossValidator }
import org.apache.spark.ml.{ Pipeline, PipelineStage }

咱們用一個Scala的case類來定義Credit的屬性,對應於csv文件中的一行。

// define the Credit Schema case class Credit( creditability: Double, balance: Double, duration: Double, history: Double, purpose: Double, amount: Double, savings: Double, employment: Double, instPercent: Double, sexMarried: Double, guarantors: Double, residenceDuration: Double, assets: Double, age: Double, concCredit: Double, apartment: Double, credits: Double, occupation: Double, dependents: Double, hasPhone: Double, foreign: Double )

下面的函數解析一行數據文件,將值存入Credit類中。類別的索引值減去了1,所以起始索引值爲0.

 // function to create a Credit class from an Array of Double def parseCredit(line: Array[Double]): Credit = { Credit( line(0), line(1) - 1, line(2), line(3), line(4) , line(5), line(6) - 1, line(7) - 1, line(8), line(9) - 1, line(10) - 1, line(11) - 1, line(12) - 1, line(13), line(14) - 1, line(15) - 1, line(16) - 1, line(17) - 1, line(18) - 1, line(19) - 1, line(20) - 1 ) }  // function to transform an RDD of Strings into an RDD of Double def parseRDD(rdd: RDD[String]): RDD[Array[Double]] = { rdd.map(_.split(",")).map(_.map(_.toDouble)) }

接下去,咱們導入germancredit.csv文件中的數據,存爲一個String類型的RDD。而後咱們對RDD作map操做,將RDD中的每一個字符串通過ParseRDDR函數的映射,轉換爲一個Double類型的數組。緊接着是另外一個map操做,使用ParseCredit函數,將每一個Double類型的RDD轉換爲Credit對象。toDF()函數將Array[[Credit]]類型的RDD轉爲一個Credit類的Dataframe。

// load the data into a RDD val creditDF= parseRDD(sc.textFile("germancredit.csv")).map(parseCredit).toDF().cache() creditDF.registerTempTable("credit") DataFrame的printSchema()函數將各個字段含義以樹狀的形式打印到控制檯輸出。 // Return the schema of this DataFrame creditDF.printSchema root |-- creditability: double (nullable = false) |-- balance: double (nullable = false) |-- duration: double (nullable = false) |-- history: double (nullable = false) |-- purpose: double (nullable = false) |-- amount: double (nullable = false) |-- savings: double (nullable = false) |-- employment: double (nullable = false) |-- instPercent: double (nullable = false) |-- sexMarried: double (nullable = false) |-- guarantors: double (nullable = false) |-- residenceDuration: double (nullable = false) |-- assets: double (nullable = false) |-- age: double (nullable = false) |-- concCredit: double (nullable = false) |-- apartment: double (nullable = false) |-- credits: double (nullable = false) |-- occupation: double (nullable = false) |-- dependents: double (nullable = false) |-- hasPhone: double (nullable = false) |-- foreign: double (nullable = false) // Display the top 20 rows of DataFrame creditDF.show +-------------+-------+--------+-------+-------+------+-------+----------+-----------+----------+----------+-----------------+------+----+----------+---------+-------+----------+----------+--------+-------+ |creditability|balance|duration|history|purpose|amount|savings|employment|instPercent|sexMarried|guarantors|residenceDuration|assets| age|concCredit|apartment|credits|occupation|dependents|hasPhone|foreign| +-------------+-------+--------+-------+-------+------+-------+----------+-----------+----------+----------+-----------------+------+----+----------+---------+-------+----------+----------+--------+-------+ | 1.0| 0.0| 18.0| 4.0| 2.0|1049.0| 0.0| 1.0| 4.0| 1.0| 0.0| 3.0| 1.0|21.0| 2.0| 0.0| 0.0| 2.0| 0.0| 0.0| 0.0| | 1.0| 0.0| 9.0| 4.0| 0.0|2799.0| 0.0| 2.0| 2.0| 2.0| 0.0| 1.0| 0.0|36.0| 2.0| 0.0| 1.0| 2.0| 1.0| 0.0| 0.0| | 1.0| 1.0| 12.0| 2.0| 9.0| 841.0| 1.0| 3.0| 2.0| 1.0| 0.0| 3.0| 0.0|23.0| 2.0| 0.0| 0.0| 1.0| 0.0| 0.0| 0.0| | 1.0| 0.0| 12.0| 4.0| 0.0|2122.0| 0.0| 2.0| 3.0| 2.0| 0.0| 1.0| 0.0|39.0| 2.0| 0.0| 1.0| 1.0| 1.0| 0.0| 1.0| | 1.0| 0.0| 12.0| 4.0| 0.0|2171.0| 0.0| 2.0| 4.0| 2.0| 0.0| 3.0| 1.0|38.0| 0.0| 1.0| 1.0| 1.0| 0.0| 0.0| 1.0| | 1.0| 0.0| 10.0| 4.0| 0.0|2241.0| 0.0| 1.0| 1.0| 2.0| 0.0| 2.0| 0.0|48.0| 2.0| 0.0| 1.0| 1.0| 1.0| 0.0| 1.0| | 1.0| 0.0| 8.0| 4.0| 0.0|3398.0| 0.0| 3.0| 1.0| 2.0| 0.0| 3.0| 0.0|39.0| 2.0| 1.0| 1.0| 1.0| 0.0| 0.0| 1.0| | 1.0| 0.0| 6.0| 4.0| 0.0|1361.0| 0.0| 1.0| 2.0| 2.0| 0.0| 3.0| 0.0|40.0| 2.0| 1.0| 0.0| 1.0| 1.0| 0.0| 1.0| | 1.0| 3.0| 18.0| 4.0| 3.0|1098.0| 0.0| 0.0| 4.0| 1.0| 0.0| 3.0| 2.0|65.0| 2.0| 1.0| 1.0| 0.0| 0.0| 0.0| 0.0| | 1.0| 1.0| 24.0| 2.0| 3.0|3758.0| 2.0| 0.0| 1.0| 1.0| 0.0| 3.0| 3.0|23.0| 2.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| | 1.0| 0.0| 11.0| 4.0| 0.0|3905.0| 0.0| 2.0| 2.0| 2.0| 0.0| 1.0| 0.0|36.0| 2.0| 0.0| 1.0| 2.0| 1.0| 0.0| 0.0| | 1.0| 0.0| 30.0| 4.0| 1.0|6187.0| 1.0| 3.0| 1.0| 3.0| 0.0| 3.0| 2.0|24.0| 2.0| 0.0| 1.0| 2.0| 0.0| 0.0| 0.0| | 1.0| 0.0| 6.0| 4.0| 3.0|1957.0| 0.0| 3.0| 1.0| 1.0| 0.0| 3.0| 2.0|31.0| 2.0| 1.0| 0.0| 2.0| 0.0| 0.0| 0.0| | 1.0| 1.0| 48.0| 3.0| 10.0|7582.0| 1.0| 0.0| 2.0| 2.0| 0.0| 3.0| 3.0|31.0| 2.0| 1.0| 0.0| 3.0| 0.0| 1.0| 0.0| | 1.0| 0.0| 18.0| 2.0| 3.0|1936.0| 4.0| 3.0| 2.0| 3.0| 0.0| 3.0| 2.0|23.0| 2.0| 0.0| 1.0| 1.0| 0.0| 0.0| 0.0| | 1.0| 0.0| 6.0| 2.0| 3.0|2647.0| 2.0| 2.0| 2.0| 2.0| 0.0| 2.0| 0.0|44.0| 2.0| 0.0| 0.0| 2.0| 1.0| 0.0| 0.0| | 1.0| 0.0| 11.0| 4.0| 0.0|3939.0| 0.0| 2.0| 1.0| 2.0| 0.0| 1.0| 0.0|40.0| 2.0| 1.0| 1.0| 1.0| 1.0| 0.0| 0.0| | 1.0| 1.0| 18.0| 2.0| 3.0|3213.0| 2.0| 1.0| 1.0| 3.0| 0.0| 2.0| 0.0|25.0| 2.0| 0.0| 0.0| 2.0| 0.0| 0.0| 0.0| | 1.0| 1.0| 36.0| 4.0| 3.0|2337.0| 0.0| 4.0| 4.0| 2.0| 0.0| 3.0| 0.0|36.0| 2.0| 1.0| 0.0| 2.0| 0.0| 0.0| 0.0| | 1.0| 3.0| 11.0| 4.0| 0.0|7228.0| 0.0| 2.0| 1.0| 2.0| 0.0| 3.0| 1.0|39.0| 2.0| 1.0| 1.0| 1.0| 0.0| 0.0| 0.0| +-------------+-------+--------+-------+-------+------+-------+----------+-----------+----------+----------+-----------------+------+----+----------+---------+-------+----------+----------+--------+-------+ 

dataframe初始化以後,你能夠用SQL命令查詢數據了。下面是一些使用Scala DataFrame接口查詢數據的例子:

計算數值型數據的統計信息,包括計數、均值、標準差、最小值和最大值。

 // computes statistics for balance creditDF.describe("balance").show +-------+-----------------+ |summary| balance| +-------+-----------------+ | count| 1000| | mean| 1.577| | stddev|1.257637727110893| | min| 0.0| | max| 3.0| +-------+-----------------+  // compute the avg balance by creditability (the label) creditDF.groupBy("creditability").avg("balance").show +-------------+------------------+ |creditability| avg(balance)| +-------------+------------------+ | 1.0|1.8657142857142857| | 0.0|0.9033333333333333| +-------------+------------------+

你能夠用某個表名將DataFrame註冊爲一張臨時表,而後用SQLContext提供的sql方法執行SQL命令。下面是幾個用sqlContext查詢的例子:

sqlContext.sql("SELECT creditability, avg(balance) as avgbalance, avg(amount) as avgamt, avg(duration) as avgdur FROM credit GROUP BY creditability ").show +-------------+------------------+------------------+------------------+ |creditability| avgbalance| avgamt| avgdur| +-------------+------------------+------------------+------------------+ | 1.0|1.8657142857142857| 2985.442857142857|19.207142857142856| | 0.0|0.9033333333333333|3938.1266666666666| 24.86| +-------------+------------------+------------------+------------------+

提取特徵

爲了構建一個分類模型,你首先須要提取對分類最有幫助的特徵。在德國人信用度的數據集裏,每條樣本用兩個類別來標記——1(可信)和0(不可信)。

每一個樣本的特徵包括如下的字段:

  • 標籤 -> 是否可信:0或者1
  • 特徵 -> {「存款」,「期限」,「歷史記錄」,「目的」,「數額」,「儲蓄」,「是否在職」,「婚姻」,「擔保人」,「居住時間」,「資產」,「年齡」,「歷史信用」,「居住公寓」,「貸款」,「職業」,「監護人」,「是否有電話」,「外籍」}

定義特徵數組

圖片描述

圖片來自:學習Spark

 

爲了在機器學習算法中使用這些特徵,這些特徵通過了變換,存入特徵向量中,即一組表示各個維度特徵值的數值向量。

下圖中,用VectorAssembler方法將每一個維度的特徵都作變換,返回一個新的dataframe。

//define the feature columns to put in the feature vector
    val featureCols = Array("balance", "duration", "history", "purpose", "amount", "savings", "employment", "instPercent", "sexMarried", "guarantors", "residenceDuration", "assets", "age", "concCredit", "apartment", "credits", "occupation", "dependents", "hasPhone", "foreign" ) //set the input and output column names val assembler = new VectorAssembler().setInputCols(featureCols).setOutputCol("features") //return a dataframe with all of the feature columns in a vector column val df2 = assembler.transform( creditDF) // the transform method produced a new column: features. df2.show +-------------+-------+--------+-------+-------+------+-------+----------+-----------+----------+----------+-----------------+------+----+----------+---------+-------+----------+----------+--------+-------+--------------------+ |creditability|balance|duration|history|purpose|amount|savings|employment|instPercent|sexMarried|guarantors|residenceDuration|assets| age|concCredit|apartment|credits|occupation|dependents|hasPhone|foreign| features| +-------------+-------+--------+-------+-------+------+-------+----------+-----------+----------+----------+-----------------+------+----+----------+---------+-------+----------+----------+--------+-------+--------------------+ | 1.0| 0.0| 18.0| 4.0| 2.0|1049.0| 0.0| 1.0| 4.0| 1.0| 0.0| 3.0| 1.0|21.0| 2.0| 0.0| 0.0| 2.0| 0.0| 0.0| 0.0|(20,[1,2,3,4,6,7,...|

接着,咱們使用StringIndexer方法返回一個Dataframe,增長了信用度這一列做爲標籤。

//  Create a label column with the StringIndexer val labelIndexer = new StringIndexer().setInputCol("creditability").setOutputCol("label") val df3 = labelIndexer.fit(df2).transform(df2) // the transform method produced a new column: label. df3.show +-------------+-------+--------+-------+-------+------+-------+----------+-----------+----------+----------+-----------------+------+----+----------+---------+-------+----------+----------+--------+-------+--------------------+-----+ |creditability|balance|duration|history|purpose|amount|savings|employment|instPercent|sexMarried|guarantors|residenceDuration|assets| age|concCredit|apartment|credits|occupation|dependents|hasPhone|foreign| features|label| +-------------+-------+--------+-------+-------+------+-------+----------+-----------+----------+----------+-----------------+------+----+----------+---------+-------+----------+----------+--------+-------+--------------------+-----+ | 1.0| 0.0| 18.0| 4.0| 2.0|1049.0| 0.0| 1.0| 4.0| 1.0| 0.0| 3.0| 1.0|21.0| 2.0| 0.0| 0.0| 2.0| 0.0| 0.0| 0.0|(20,[1,2,3,4,6,7,...| 0.0|

下圖中,數據集被分爲訓練數據和測試數據兩個部分,70%的數據用來訓練模型,30%的數據用來測試模型。

// split the dataframe into training and test data val splitSeed = 5043 val Array(trainingData, testData) = df3.randomSplit(Array(0.7, 0.3), splitSeed)

訓練模型

圖片描述

接着,咱們按照下列參數訓練一個隨機森林分類器:

  • maxDepth:每棵樹的最大深度。增長樹的深度能夠提升模型的效果,可是會延長訓練時間。
  • maxBins:連續特徵離散化時選用的最大分桶個數,而且決定每一個節點如何分裂。
  • impurity:計算信息增益的指標
  • auto:在每一個節點分裂時是否自動選擇參與的特徵個數
  • seed:隨機數生成種子

模型的訓練過程就是將輸入特徵和這些特徵對應的樣本標籤相關聯的過程。

// create the classifier,  set parameters for training val classifier = new RandomForestClassifier().setImpurity("gini").setMaxDepth(3).setNumTrees(20).setFeatureSubsetStrategy("auto").setSeed(5043) // use the random forest classifier to train (fit) the model val model = classifier.fit(trainingData) // print out the random forest trees model.toDebugString res20: String = res5: String = "RandomForestClassificationModel (uid=rfc_6c4ceb92ba78) with 20 trees Tree 0 (weight 1.0): If (feature 0 <= 1.0) If (feature 10 <= 0.0) If (feature 3 <= 6.0) Predict: 0.0 Else (feature 3 > 6.0) Predict: 0.0 Else (feature 10 > 0.0) If (feature 12 <= 63.0) Predict: 0.0 Else (feature 12 > 63.0) Predict: 0.0 Else (feature 0 > 1.0) If (feature 13 <= 1.0) If (feature 3 <= 3.0) Predict: 0.0 Else (feature 3 > 3.0) Predict: 1.0 Else (feature 13 > 1.0) If (feature 7 <= 1.0) Predict: 0.0 Else (feature 7 > 1.0) Predict: 0.0 Tree 1 (weight 1.0): If (feature 2 <= 1.0) If (feature 15 <= 0.0) If (feature 11 <= 0.0) Predict: 0.0 Else (feature 11 > 0.0) Predict: 1.0 Else (feature 15 > 0.0) If (feature 11 <= 0.0) Predict: 0.0 Else (feature 11 > 0.0) Predict: 1.0 Else (feature 2 > 1.0) If (feature 12 <= 31.0) If (feature 5 <= 0.0) Predict: 0.0 Else (feature 5 > 0.0) Predict: 0.0 Else (feature 12 > 31.0) If (feature 3 <= 4.0) Predict: 0.0 Else (feature 3 > 4.0) Predict: 0.0 Tree 2 (weight 1.0): If (feature 8 <= 1.0) If (feature 6 <= 2.0) If (feature 4 <= 10875.0) Predict: 0.0 Else (feature 4 > 10875.0) Predict: 1.0 Else (feature 6 > 2.0) If (feature 1 <= 36.0) Predict: 0.0 Else (feature 1 > 36.0) Predict: 1.0 Else (feature 8 > 1.0) If (feature 5 <= 0.0) If (feature 4 <= 4113.0) Predict: 0.0 Else (feature 4 > 4113.0) Predict: 1.0 Else (feature 5 > 0.0) If (feature 11 <= 2.0) Predict: 0.0 Else (feature 11 > 2.0) Predict: 0.0 Tree 3 ...

測試模型

接下來,咱們對測試數據進行預測。

// run the  model on test features to get predictions
    val predictions = model.transform(testData) 
    //As you can see, the previous model transform produced a new columns: rawPrediction, probablity and prediction.
    predictions.show

    +-------------+-------+--------+-------+-------+------+-------+----------+-----------+----------+----------+-----------------+------+----+----------+---------+-------+----------+----------+--------+-------+--------------------+-----+--------------------+--------------------+----------+
    |creditability|balance|duration|history|purpose|amount|savings|employment|instPercent|sexMarried|guarantors|residenceDuration|assets| age|concCredit|apartment|credits|occupation|dependents|hasPhone|foreign|            features|label|       rawPrediction|         probability|prediction|
    +-------------+-------+--------+-------+-------+------+-------+----------+-----------+----------+----------+-----------------+------+----+----------+---------+-------+----------+----------+--------+-------+--------------------+-----+--------------------+--------------------+----------+
    |          0.0| 0.0| 12.0| 0.0| 5.0|1108.0| 0.0| 3.0| 4.0| 2.0| 0.0| 2.0| 0.0|28.0| 2.0| 1.0| 1.0| 2.0| 0.0| 0.0| 0.0|(20,[1,3,4,6,7,8,...| 1.0|[14.1964586927573...|[0.70982293463786...| 0.0|

而後,咱們用BinaryClassificationEvaluator評估預測的效果,它將預測結果與樣本的實際標籤相比較,返回一個準確度指標(ROC曲線所覆蓋的面積)。本例子中,AUC達到78%。

// create an Evaluator for binary classification, which expects two input columns: rawPrediction and label. val evaluator = new BinaryClassificationEvaluator().setLabelCol("label") // Evaluates predictions and returns a scalar metric areaUnderROC(larger is better). val accuracy = evaluator.evaluate(predictions) accuracy: Double = 0.7824906081835722

使用機器學習管道

咱們接着用管道來訓練模型,可能會取得更好的效果。管道採起了一種簡單的方式來比較各類不一樣組合的參數的效果,這個方法稱爲網格搜索法(grid search),你先設置好待測試的參數,MLLib就會自動完成這些參數的不一樣組合。管道搭建了一條工做流,一次性完成了整個模型的調優,而不是獨立對每一個參數進行調優。

下面咱們就用ParamGridBuilder工具來構建參數網格。

// We use a ParamGridBuilder to construct a grid of parameters to search over val paramGrid = new ParamGridBuilder() .addGrid(classifier.maxBins, Array(25, 28, 31)) .addGrid(classifier.maxDepth, Array(4, 6, 8)) .addGrid(classifier.impurity, Array("entropy", "gini")) .build()

建立並完成一條管道。一條管道由一系列stage組成,每一個stage至關於一個Estimator或是Transformer。

val steps: Array[PipelineStage] = Array(classifier) val pipeline = new Pipeline().setStages(steps)

咱們用CrossValidator類來完成模型篩選。CrossValidator類使用一個Estimator類,一組ParamMaps類和一個Evaluator類。注意,使用CrossValidator類的開銷很大。

// Evaluate model on test instances and compute test error val evaluator = new BinaryClassificationEvaluator() .setLabelCol("label") val cv = new CrossValidator() .setEstimator(pipeline) .setEvaluator(evaluator) .setEstimatorParamMaps(paramGrid) .setNumFolds(10)

管道在參數網格上不斷地爬行,自動完成了模型優化的過程:對於每一個ParamMap類,CrossValidator訓練獲得一個Estimator,而後用Evaluator來評價結果,而後用最好的ParamMap和整個數據集來訓練最優的Estimator。

圖片描述

// When fit is called, the stages are executed in order. // Fit will run cross-validation, and choose the best set of parameters //The fitted model from a Pipeline is an PipelineModel, which consists of fitted models and transformers val pipelineFittedModel = cv.fit(trainingData)

如今,咱們能夠用管道訓練獲得的最優模型進行預測,將預測結果與標籤作比較。預測結果取得了82%的準確率,相比以前78%的準確率有提升。

//  call tranform to make predictions on test data. The fitted model will use the best model found val predictions = pipelineFittedModel.transform(testData) val accuracy = evaluator.evaluate(predictions) Double = 0.8204386232104784 val rm2 = new RegressionMetrics( predictions.select("prediction", "label").rdd.map(x => (x(0).asInstanceOf[Double], x(1).asInstanceOf[Double]))) println("MSE: " + rm2.meanSquaredError) println("MAE: " + rm2.meanAbsoluteError) println("RMSE Squared: " + rm2.rootMeanSquaredError) println("R Squared: " + rm2.r2) println("Explained Variance: " + rm2.explainedVariance + "\n") MSE: 0.2575250836120402 MAE: 0.25752508361204013 RMSE Squared: 0.5074692932700856 R Squared: -0.1687988628287138 Explained Variance: 0.15466269952237702
相關文章
相關標籤/搜索