An xgboost SparkWithDataFrame implementation

  The xgboost source tree contains a SparkWithDataFrame implementation under the jvm-packages directory: https://github.com/dmlc/xgboost/tree/master/jvm-packages. For various reasons, however, that code would not compile in my IDE, so I wrote the following code for future reference.

 

package xgboost

import ml.dmlc.xgboost4j.scala.spark.{XGBoost, XGBoostModel}
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.sql.{Row, DataFrame, SparkSession}

object App {
  def main(args: Array[String]): Unit = {
    val trainPath: String = "xxx/train.txt"
    val testPath: String = "xxx/test.txt"
    val binaryModelPath: String = "xxx/model.binary"
    val textModelPath: String = "xxx/model.txt"
    val spark = SparkSession
      .builder()
      .master("yarn")
      .getOrCreate()

    // define xgboost parameters
    val maxDepth = 3
    val numRound = 4
    val nworker = 1
    val paramMap = List(
      "eta" -> 0.1,
      "max_depth" -> maxDepth,
      "objective" -> "binary:logistic").toMap

    // read the libsvm files; Spark loads them as (label, features) columns
    val dfTrain = spark.read.format("libsvm").load(trainPath).toDF("labelCol", "featureCol")
    val dfTest = spark.read.format("libsvm").load(testPath).toDF("labelCol", "featureCol")
    dfTrain.show()
    println("begin...")
    val model:XGBoostModel = XGBoost.trainWithDataFrame(dfTrain, paramMap, numRound, nworker,
      useExternalMemory = true,
      featureCol = "featureCol", labelCol = "labelCol",
      missing = 0.0f)

    //predict the test set
    val predict:DataFrame = model.transform(dfTest)
    val scoreAndLabels = predict.select(model.getPredictionCol, model.getLabelCol)
      .rdd
      .map{case Row(score:Double, label:Double) => (score, label)}

    //get the auc
    val metric = new BinaryClassificationMetrics(scoreAndLabels)
    val auc = metric.areaUnderROC()
    println("auc:" + auc)

    //save model
    this.saveBinaryModel(model, spark, binaryModelPath)
    this.saveTextModel(model, spark, textModelPath, numRound, maxDepth)
  }

  def saveBinaryModel(model:XGBoostModel, spark: SparkSession, path: String): Unit = {
    model.saveModelAsHadoopFile(path)(spark.sparkContext)
  }

  def saveTextModel(model: XGBoostModel, spark: SparkSession, path: String, numRound: Int, maxDepth: Int): Unit = {
    // dump every booster (tree) as text and prefix it with its index
    val dumpModel = model
      .booster
      .getModelDump()
      .toList
      .zipWithIndex
      .map { case (tree, i) => s"booster:[$i]\n$tree" }

    val header = s"numRound: $numRound, maxDepth: $maxDepth"
    import spark.implicits._
    val text: List[String] = header +: dumpModel
    text.toDF
      .coalesce(1)
      .write
      .mode("overwrite")
      .text(path)
  }
}

  Notes:

  1. Both the training set and the test set are in libsvm format, as shown below:

1 3:1 10:1 11:1 21:1 30:1 34:1 36:1 40:1 41:1 53:1 58:1 65:1 69:1 77:1 86:1 88:1 92:1 95:1 102:1 105:1 117:1 124:1
0 3:1 10:1 20:1 21:1 23:1 34:1 36:1 39:1 41:1 53:1 56:1 65:1 69:1 77:1 86:1 88:1 92:1 95:1 102:1 106:1 116:1 120:1
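Each line is a label followed by sparse `index:value` pairs, with indices 1-based and ascending. Spark reads this format natively via `format("libsvm")`, so no hand-written parser is needed; the sketch below, using a hypothetical `parseLibsvmLine` helper, only illustrates the line layout:

```scala
// Hypothetical helper illustrating the libsvm line layout:
// "<label> <index1>:<value1> <index2>:<value2> ...", indices 1-based.
def parseLibsvmLine(line: String): (Double, Seq[(Int, Double)]) = {
  val tokens = line.trim.split("\\s+")
  val label = tokens.head.toDouble
  val features = tokens.tail.map { t =>
    val Array(idx, value) = t.split(":")
    (idx.toInt, value.toDouble)
  }
  (label, features.toSeq)
}

val (label, features) = parseLibsvmLine("1 3:1 10:1 11:1")
// label == 1.0; features == Seq((3, 1.0), (10, 1.0), (11, 1.0))
```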

  2. The final saved text model looks like this:

numRound: 4, maxDepth: 3
booster:[0]
0:[f29<2] yes=1,no=2,missing=2
    1:leaf=0.152941
    2:leaf=-0.191209

booster:[1]
0:[f29<2] yes=1,no=2,missing=2
    1:leaf=0.141901
    2:leaf=-0.174499

booster:[2]
0:[f29<2] yes=1,no=2,missing=2
    1:leaf=0.132731
    2:leaf=-0.161685

booster:[3]
0:[f29<2] yes=1,no=2,missing=2
    1:leaf=0.124972
    2:leaf=-0.15155

  Explanation: "numRound: 4, maxDepth: 3" means four trees were generated, each with a maximum depth of 3; booster:[n] denotes the n-th tree. The lines that follow describe the tree structure: node 0 is the root, each internal node has two children, and nodes are numbered in level order, so nodes 1 and 2 are the children of root node 0. Nodes on the same level share the same indentation, one level deeper than their parent.
  Each node line starts with the node number; the brackets name the feature the node splits on (f29 is feature 29 of the training data) together with the split condition. "[f29<2] yes=1,no=2,missing=2" means: when f29 < 2 the sample goes to leaf node 1, when f29 >= 2 it goes to leaf node 2, and when the feature is absent (missing) it also goes to node 2.
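As a sanity check, the probabilities in the prediction output can be reproduced by hand from this dump: with the binary:logistic objective, xgboost sums the leaf score chosen in each tree and applies a sigmoid. A minimal sketch, with the leaf values copied from the four-tree dump above:

```scala
// binary:logistic: probability = sigmoid(sum of chosen leaf scores)
def sigmoid(margin: Double): Double = 1.0 / (1.0 + math.exp(-margin))

// leaf values taken from the model dump above
val yesLeaves = Seq(0.152941, 0.141901, 0.132731, 0.124972)    // f29 < 2 in every tree
val noLeaves  = Seq(-0.191209, -0.174499, -0.161685, -0.15155) // f29 >= 2 or missing

val pYes = sigmoid(yesLeaves.sum) // ≈ 0.6347, matches the first predicted row
val pNo  = sigmoid(noLeaves.sum)  // ≈ 0.3365, matches the second predicted row
```

This matches the probabilities columns in the prediction table below to four decimal places, which confirms the reading of the split rule.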

  3. The prediction results are as follows (note that Spark stores the libsvm features as a 0-based sparse vector, so index 3 in the input file appears as index 2 here):

|labelCol|featureCol                                                                                                                                                  |probabilities                          |prediction|
|1.0     |(126,[2,9,10,20,29,33,35,39,40,52,57,64,68,76,85,87,91,94,101,104,116,123],[1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0])|[0.3652743101119995,0.6347256898880005]|1.0       |
|0.0     |(126,[2,9,19,20,22,33,35,38,40,52,55,64,68,76,85,87,91,94,101,105,115,119],[1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0])|[0.6635029911994934,0.3364970088005066]|0.0       |