spark 決策樹分類 DecisionTreeClassifier

決策樹分類是一個非機率模型。測試數據集用的是網上公開的泰坦尼克號乘客數據,使用決策樹 DecisionTreeClassifier 數據挖掘算法,經過 Pclass、Sex、Age 三個參數來預測乘客是否獲救。
pom.xml

<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
  xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
  <modelVersion>4.0.0</modelVersion>
  <groupId>com.penngo.spark.ml</groupId>
  <artifactId>sparkml</artifactId>
  <packaging>jar</packaging>
  <version>1.0-SNAPSHOT</version>
  <name>sparkml</name>
  <url>http://maven.apache.org</url>
  <properties>
    <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
    <project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
    <!-- Java language level; referenced by the compiler plugin below. -->
    <java.version>1.8</java.version>
    <!-- Single source of truth for the Spark artifacts: every Spark module must share
         the same Spark version and the same Scala binary suffix, so both are declared
         once here instead of being repeated per dependency. -->
    <scala.binary.version>2.11</scala.binary.version>
    <spark.version>2.2.3</spark.version>
  </properties>
  <dependencies>
    <dependency>
      <groupId>junit</groupId>
      <artifactId>junit</artifactId>
      <version>3.8.1</version>
      <scope>test</scope>
    </dependency>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-core_${scala.binary.version}</artifactId>
      <version>${spark.version}</version>
    </dependency>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-sql_${scala.binary.version}</artifactId>
      <version>${spark.version}</version>
    </dependency>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-mllib_${scala.binary.version}</artifactId>
      <version>${spark.version}</version>
    </dependency>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-streaming_${scala.binary.version}</artifactId>
      <version>${spark.version}</version>
    </dependency>
  </dependencies>
  <build>
    <plugins>
      <plugin>
        <groupId>org.apache.maven.plugins</groupId>
        <artifactId>maven-compiler-plugin</artifactId>
        <version>3.7.0</version>
        <configuration>
          <!-- Driven by the properties above so the values cannot drift apart. -->
          <source>${java.version}</source>
          <target>${java.version}</target>
          <encoding>${project.build.sourceEncoding}</encoding>
        </configuration>
      </plugin>
    </plugins>
  </build>
</project>

DecisionTreeClassification.java

package com.penngo.spark.ml.main;

import org.apache.log4j.Logger;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.DecisionTreeClassificationModel;
import org.apache.spark.ml.classification.DecisionTreeClassifier;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.ml.feature.*;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import java.io.File;
import org.apache.spark.sql.functions;
import static org.apache.spark.sql.types.DataTypes.DoubleType;

/**
 * spark 決策樹分類 DecisionTreeClassifier
 *
 */
public class DecisionTreeClassification {
    private static Logger log = Logger.getLogger(DecisionTreeClassification.class);
    private static SparkSession spark = null;



    public static void initSpark(){
        if (spark == null) {
            String os = System.getProperty("os.name").toLowerCase();
            // linux上運行
            if(os.indexOf("windows") == -1){
                spark = SparkSession
                        .builder()
                        .appName("DecisionTreeClassification")
                        .getOrCreate();
            }
            // window上運行,本機調試
            else{
                System.setProperty("hadoop.home.dir", "D:/hadoop/hadoop-2.7.6");
                System.setProperty("HADOOP_USER_NAME", "hadoop");
                spark = SparkSession
                        .builder()
                        .appName("DecisionTreeClassification" ).master("local[3]")
                        .getOrCreate();
            }
        }
        log.warn("spark.conf().getAll()=============" + spark.conf().getAll());
    }

    public static void run(){
        String dataPath = new File("").getAbsolutePath() + "/data/titanic.txt";
        Dataset<Row> data = spark.read().option("header", "true").csv(dataPath);
        data.show();
        //data.describe()
        //Dataset<Row> datana2 = data.na().fill(ImmutableMap.of("age", "30", "ticket", "1111"));

        Dataset<Row> meanDataset = data.select(functions.mean("age").as("mage"));
        Double mage = meanDataset.first().getAs("mage");
        // 字符串轉換爲數據,處理空值
        Dataset<Row> data1 = data.select(
                functions.col("user_id"),
                functions.col("survived").cast(DoubleType).as("label"),
                functions.when(functions.col("pclass").equalTo("1st"), 1)
                        .when(functions.col("pclass").equalTo("2nd"), 2)
                        .when(functions.col("pclass").equalTo("3rd"), 3)
                        .cast(DoubleType).as("pclass1"),
                functions.when(functions.col("age").equalTo("NA"), mage.intValue()).otherwise(functions.col("age")).cast(DoubleType).as("age1"),
                functions.when(functions.col("sex").equalTo("female"), 0).otherwise(1).as("sex")
        );

        VectorAssembler assembler = new VectorAssembler()
                .setInputCols(new String[]{"pclass1", "age1", "sex"})
                .setOutputCol("features");
        Dataset<Row> data2 = assembler.transform(data1);
        data2.show();
        // 索引標籤,將元數據添加到標籤列中
        StringIndexerModel labelIndexer = new StringIndexer()
                .setInputCol("label")
                .setOutputCol("indexedLabel")
                .fit(data2);
        // 自動識別分類的特徵,並對它們進行索引
        // 具備大於5個不一樣的值的特徵被視爲連續。
        VectorIndexerModel featureIndexer = new VectorIndexer()
                .setInputCol("features")
                .setOutputCol("indexedFeatures")
                //.setMaxCategories(3)
                .fit(data2);
        // 將數據分爲訓練和測試集(30%進行測試)
        Dataset<Row>[] splits = data2.randomSplit(new double[]{0.7, 0.3});
        Dataset<Row> trainingData = splits[0];
        Dataset<Row> testData = splits[1];

        // 訓練決策樹模型
        DecisionTreeClassifier dt = new DecisionTreeClassifier()
                .setLabelCol("indexedLabel")
                .setFeaturesCol("indexedFeatures");
        //.setImpurity("entropy") // Gini不純度,entropy熵
        //.setMaxBins(100) // 離散化"連續特徵"的最大劃分數
        //.setMaxDepth(5) // 樹的最大深度
        //.setMinInfoGain(0.01) //一個節點分裂的最小信息增益,值爲[0,1]
        //.setMinInstancesPerNode(10) //每一個節點包含的最小樣本數
        //.setSeed(123456)

        IndexToString labelConverter = new IndexToString()
                .setInputCol("prediction")
                .setOutputCol("predictedLabel")
                .setLabels(labelIndexer.labels());


        // Chain indexers and tree in a Pipeline.
        Pipeline pipeline = new Pipeline()
                .setStages(new PipelineStage[]{labelIndexer, featureIndexer, dt, labelConverter});

        // 訓練模型
        PipelineModel model = pipeline.fit(trainingData);

        // 預測數據
        Dataset<Row> predictions = model.transform(testData);

        predictions.select("user_id", "features", "label", "prediction").show();
        //predictions.select("predictedLabel", "label", "features").show(5);

        // 計算錯誤率
        MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
                .setLabelCol("indexedLabel")
                .setPredictionCol("prediction")
                .setMetricName("accuracy");
        double accuracy = evaluator.evaluate(predictions);
        System.out.println("Test Error = " + (1.0 - accuracy));

        // 查看決策樹
        DecisionTreeClassificationModel treeModel =
                (DecisionTreeClassificationModel) (model.stages()[2]);
        System.out.println("Learned classification tree model:\n" + treeModel.toDebugString());
        // $example off$

        spark.stop();
    }
    public static void main(String[] args){
        initSpark();
        run();
    }
}

基礎數據

過濾、特徵化後的數據


預測結果


預測錯誤率和預測模型

相關文章
相關標籤/搜索