梯度加強樹(GBT)是使用決策樹組合的流行迴歸方法
相對於 Random forest 來講,GBT 在實際應用中效果更好
直接上代碼
package mllib

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.GBTClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature._
import org.apache.spark.sql.SparkSession

/**
 * Created by dongdong on 17/7/10.
 *
 * One parsed user-profile record from the '\001'-delimited Hive export.
 * Categorical fields are Strings ("N" marks a missing value); numeric fields
 * are Doubles (0 marks a missing value). `text` is a "|"-joined concatenation
 * of the categorical fields (later split by RegexTokenizer and embedded with
 * Word2Vec); `flag` is the raw class label.
 */
case class Fearture_One(
  cid: String,
  population_gender: String,
  population_age: Double,
  population_registered_gps_city: String,
  population_education_nature: String,
  population_university_level: String,
  sociality_channel_type: String,
  action_registered_channel: String,
  action_this_month_once_week_average_login_count: Double,
  population_censu_city: String,
  population_gps_city: String,
  population_own_cell_city: String,
  population_rank1_cell_city: String,
  population_rank1_cell_cnt: Double,
  population_rank2_cell_city: String,
  population_rank2_cell_cnt: Double,
  population_rank3_cell_city: String,
  population_rank3_cell_cnt: Double,
  population_gps_censu_flag: Double,
  population_own_censu_flag: Double,
  population_gps_own_flag: Double,
  population_own_txl_flag: Double,
  population_gps_txl_flag: Double,
  population_censu_txl_flag: Double,
  population_cnt_7day_province: Double,
  population_cnt_7day_city: Double,
  population_cnt_login: Double,
  population_before_apply_city: String,
  population_after_apply_city: String,
  population_before_in_apply_address: Double,
  population_before_after_apply_address: Double,
  population_in_after_apply_address: Double,
  population_re_address_steady: String,
  population_apply_address_steady: String,
  population_score_fake_gps: Double,
  population_score_fake_contacts: Double,
  text: String,
  flag: String
)

/**
 * Trains a Gradient-Boosted-Tree classifier on user-profile features:
 * categorical fields are concatenated into `text`, tokenized and embedded
 * with Word2Vec, then assembled together with the numeric fields into a
 * single feature vector fed to the GBT.
 */
object GBT_Profile {

  /** Missing-value sentinel written by the upstream export. */
  private val NullMarker = "\\N"

  /** Parse a categorical column: map the null marker to the literal "N". */
  private def cat(arr: Array[String], i: Int): String =
    arr(i).replace(NullMarker, "N")

  /** Parse a numeric column: map the null marker to 0 and convert to Double. */
  private def num(arr: Array[String], i: Int): Double =
    arr(i).replace(NullMarker, "0").toDouble

  def main(args: Array[String]): Unit = {
    val inpath1 = "/Users/ant_git/src/data/user_profile_train/part-00000"

    val spark = SparkSession
      .builder()
      .master("local[3]")
      .appName("GBT_Profile")
      .getOrCreate()
    import spark.implicits._

    // Read the raw '\001'-delimited lines and map each one to a typed record.
    // Column indices follow the upstream export's schema.
    val originalData = spark.sparkContext
      .textFile(inpath1)
      .map { line =>
        val arr = line.split("\001")

        val cid = arr(0)
        val population_gender = cat(arr, 3)
        val population_age = num(arr, 4)
        val population_registered_gps_city = cat(arr, 7)
        val population_education_nature = cat(arr, 10)
        val population_university_level = cat(arr, 11)
        val sociality_channel_type = cat(arr, 13)
        val action_registered_channel = cat(arr, 44)
        val action_this_month_once_week_average_login_count = num(arr, 54)
        val population_censu_city = cat(arr, 63)
        val population_gps_city = cat(arr, 64)
        val population_own_cell_city = cat(arr, 67)
        val population_rank1_cell_city = cat(arr, 68)
        val population_rank1_cell_cnt = num(arr, 69)
        val population_rank2_cell_city = cat(arr, 70)
        val population_rank2_cell_cnt = num(arr, 71)
        val population_rank3_cell_city = cat(arr, 72)
        val population_rank3_cell_cnt = num(arr, 73)
        val population_gps_censu_flag = num(arr, 78)
        val population_own_censu_flag = num(arr, 84)
        val population_gps_own_flag = num(arr, 87)
        val population_own_txl_flag = num(arr, 91)
        val population_gps_txl_flag = num(arr, 92)
        val population_censu_txl_flag = num(arr, 93)
        val population_cnt_7day_province = num(arr, 96)
        val population_cnt_7day_city = num(arr, 97)
        val population_cnt_login = num(arr, 102)
        val population_before_apply_city = cat(arr, 107)
        val population_after_apply_city = cat(arr, 108)
        val population_before_in_apply_address = num(arr, 111)
        val population_before_after_apply_address = num(arr, 112)
        val population_in_after_apply_address = num(arr, 113)
        val population_re_address_steady = cat(arr, 116)
        val population_apply_address_steady = cat(arr, 117)
        val population_score_fake_gps = num(arr, 127)
        val population_score_fake_contacts = num(arr, 128)

        // All categorical fields joined with "|"; RegexTokenizer splits on
        // the same delimiter downstream.
        val text = Seq(
          population_gender,
          population_registered_gps_city,
          population_education_nature,
          population_university_level,
          sociality_channel_type,
          action_registered_channel,
          population_censu_city,
          population_gps_city,
          population_own_cell_city,
          population_rank1_cell_city,
          population_rank2_cell_city,
          population_rank3_cell_city,
          population_before_apply_city,
          population_after_apply_city,
          population_re_address_steady,
          population_apply_address_steady
        ).mkString("|")

        val flag = arr(141)

        // NOTE: the original code "called" the constructor with type
        // ascriptions (Fearture_One(cid: String, ...)) — a copy-paste of the
        // declaration. Pass the parsed values explicitly instead.
        Fearture_One(
          cid,
          population_gender,
          population_age,
          population_registered_gps_city,
          population_education_nature,
          population_university_level,
          sociality_channel_type,
          action_registered_channel,
          action_this_month_once_week_average_login_count,
          population_censu_city,
          population_gps_city,
          population_own_cell_city,
          population_rank1_cell_city,
          population_rank1_cell_cnt,
          population_rank2_cell_city,
          population_rank2_cell_cnt,
          population_rank3_cell_city,
          population_rank3_cell_cnt,
          population_gps_censu_flag,
          population_own_censu_flag,
          population_gps_own_flag,
          population_own_txl_flag,
          population_gps_txl_flag,
          population_censu_txl_flag,
          population_cnt_7day_province,
          population_cnt_7day_city,
          population_cnt_login,
          population_before_apply_city,
          population_after_apply_city,
          population_before_in_apply_address,
          population_before_after_apply_address,
          population_in_after_apply_address,
          population_re_address_steady,
          population_apply_address_steady,
          population_score_fake_gps,
          population_score_fake_contacts,
          text,
          flag
        )
      }
      .toDS()

    // Map the raw string label to a numeric index for the classifier.
    val labelIndexer = new StringIndexer()
      .setInputCol("flag")
      .setOutputCol("indexedLabel")
      .fit(originalData)

    // Split the "|"-joined categorical text back into tokens.
    val tokenizer = new RegexTokenizer()
      .setInputCol("text")
      .setOutputCol("words")
      .setPattern("\\|")

    // Embed the categorical tokens as a dense 100-dim vector.
    val word2Vec = new Word2Vec()
      .setInputCol("words")
      .setOutputCol("word2feature")
      .setVectorSize(100)
      .setMaxIter(10)

    // Numeric columns plus the Word2Vec embedding, merged into one vector.
    val featureCols = Array(
      "population_age",
      "action_this_month_once_week_average_login_count",
      "population_rank1_cell_cnt",
      "population_rank2_cell_cnt",
      "population_rank3_cell_cnt",
      "population_gps_censu_flag",
      "population_own_censu_flag",
      "population_gps_own_flag",
      "population_own_txl_flag",
      "population_gps_txl_flag",
      "population_censu_txl_flag",
      "population_cnt_7day_province",
      "population_cnt_7day_city",
      "population_cnt_login",
      "population_before_in_apply_address",
      "population_before_after_apply_address",
      "population_in_after_apply_address",
      "population_score_fake_gps",
      "population_score_fake_contacts",
      "word2feature"
    )

    val vectorAssembler = new VectorAssembler()
      .setInputCols(featureCols)
      .setOutputCol("assemblerVector")

    // Gradient-boosted-tree classifier: 25 boosting iterations, depth-5 trees.
    val gbt = new GBTClassifier()
      .setLabelCol("indexedLabel")
      .setFeaturesCol("assemblerVector")
      .setMaxIter(25)
      .setMaxDepth(5)

    // Convert the predicted numeric index back to the original label string.
    val labelConverter = new IndexToString()
      .setInputCol("prediction")
      .setOutputCol("predictedLabel")
      .setLabels(labelIndexer.labels)

    val Array(trainingData, testData) = originalData.randomSplit(Array(0.8, 0.2))

    val pipeline = new Pipeline()
      .setStages(Array(labelIndexer, tokenizer, word2Vec, vectorAssembler, gbt, labelConverter))

    // BUG FIX: the original fit on `originalData`, i.e. the model was trained
    // on the test rows too (data leakage, inflating the reported accuracy).
    // Train on the 80% split only.
    val model = pipeline.fit(trainingData)

    val predictionResultDF = model.transform(testData)
    predictionResultDF.show(false)

    // Per-class bookkeeping on the held-out split (the original computed these
    // counts but never reported them).
    val label_1 = predictionResultDF.select("cid", "flag", "predictedLabel")
      .filter($"flag" === 1)
      .count()
    val correct_1 = predictionResultDF.select("cid", "flag", "predictedLabel")
      .filter($"flag" === $"predictedLabel")
      .filter($"predictedLabel" === 1)
      .count()
    val correct_0 = predictionResultDF.select("cid", "flag", "predictedLabel")
      .filter($"flag" === $"predictedLabel")
      .filter($"predictedLabel" === 0)
      .count()
    println(s"actual positives = $label_1, true positives = $correct_1, true negatives = $correct_0")

    // Persist the predicted-positive cids for downstream review.
    predictionResultDF.select("cid", "predictedLabel")
      .filter($"predictedLabel" === 1)
      .repartition(1)
      .write.format("csv")
      .save("/Users/ant_git/Antifraud/src/data/predict/")

    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("indexedLabel")
      .setPredictionCol("prediction")
      .setMetricName("accuracy")
    val accuracy = evaluator.evaluate(predictionResultDF)
    val error = 1.0 - accuracy
    println("Test Error = " + error)

    spark.stop()
  }
}
總結:算法是別人封裝好的,最重要的是特徵如何進行處理。好的特徵,很簡單的算法均可以進行分類;很差的特徵,再好的模型也很難有好的效果。因此如何進行特徵的選擇,對於機器學習來講是很是重要的