決策樹分類是一種非機率模型。測試數據集使用網上公開的泰坦尼克號乘客數據,以決策樹 DecisionTreeClassifier 數據挖掘算法,根據 Pclass(艙等)、Sex(性別)、Age(年齡)三個特徵來預測乘客的獲救率。
pom.xml
<!--
  Maven build for the Spark ML decision-tree example (com.penngo.spark.ml:sparkml).
  All Spark modules are the Scala 2.11 builds of Spark 2.2.3; sources are compiled for Java 1.8.
-->
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>com.penngo.spark.ml</groupId>
    <artifactId>sparkml</artifactId>
    <version>1.0-SNAPSHOT</version>
    <packaging>jar</packaging>
    <name>sparkml</name>
    <url>http://maven.apache.org</url>

    <properties>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
        <java.version>1.8</java.version>
        <!-- Every Spark module below must share the same Spark version and Scala binary version. -->
        <spark.version>2.2.3</spark.version>
        <scala.binary.version>2.11</scala.binary.version>
    </properties>

    <dependencies>
        <dependency>
            <groupId>junit</groupId>
            <artifactId>junit</artifactId>
            <version>3.8.1</version>
            <scope>test</scope>
        </dependency>

        <!-- Spark runtime, DataFrame/SQL API and ML pipelines. -->
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-core_${scala.binary.version}</artifactId>
            <version>${spark.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-sql_${scala.binary.version}</artifactId>
            <version>${spark.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-mllib_${scala.binary.version}</artifactId>
            <version>${spark.version}</version>
        </dependency>
        <!-- Not referenced by the DecisionTreeClassification example. -->
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-streaming_${scala.binary.version}</artifactId>
            <version>${spark.version}</version>
        </dependency>
    </dependencies>

    <build>
        <plugins>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <version>3.7.0</version>
                <configuration>
                    <source>${java.version}</source>
                    <target>${java.version}</target>
                    <encoding>UTF-8</encoding>
                </configuration>
            </plugin>
        </plugins>
    </build>
</project>
DecisionTreeClassification.java
package com.penngo.spark.ml.main; import org.apache.log4j.Logger; import org.apache.spark.ml.Pipeline; import org.apache.spark.ml.PipelineModel; import org.apache.spark.ml.PipelineStage; import org.apache.spark.ml.classification.DecisionTreeClassificationModel; import org.apache.spark.ml.classification.DecisionTreeClassifier; import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator; import org.apache.spark.ml.feature.*; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SparkSession; import java.io.File; import org.apache.spark.sql.functions; import static org.apache.spark.sql.types.DataTypes.DoubleType; /** * spark 決策樹分類 DecisionTreeClassifier * */ public class DecisionTreeClassification { private static Logger log = Logger.getLogger(DecisionTreeClassification.class); private static SparkSession spark = null; public static void initSpark(){ if (spark == null) { String os = System.getProperty("os.name").toLowerCase(); // linux上運行 if(os.indexOf("windows") == -1){ spark = SparkSession .builder() .appName("DecisionTreeClassification") .getOrCreate(); } // window上運行,本機調試 else{ System.setProperty("hadoop.home.dir", "D:/hadoop/hadoop-2.7.6"); System.setProperty("HADOOP_USER_NAME", "hadoop"); spark = SparkSession .builder() .appName("DecisionTreeClassification" ).master("local[3]") .getOrCreate(); } } log.warn("spark.conf().getAll()=============" + spark.conf().getAll()); } public static void run(){ String dataPath = new File("").getAbsolutePath() + "/data/titanic.txt"; Dataset<Row> data = spark.read().option("header", "true").csv(dataPath); data.show(); //data.describe() //Dataset<Row> datana2 = data.na().fill(ImmutableMap.of("age", "30", "ticket", "1111")); Dataset<Row> meanDataset = data.select(functions.mean("age").as("mage")); Double mage = meanDataset.first().getAs("mage"); // 字符串轉換爲數據,處理空值 Dataset<Row> data1 = data.select( functions.col("user_id"), functions.col("survived").cast(DoubleType).as("label"), 
functions.when(functions.col("pclass").equalTo("1st"), 1) .when(functions.col("pclass").equalTo("2nd"), 2) .when(functions.col("pclass").equalTo("3rd"), 3) .cast(DoubleType).as("pclass1"), functions.when(functions.col("age").equalTo("NA"), mage.intValue()).otherwise(functions.col("age")).cast(DoubleType).as("age1"), functions.when(functions.col("sex").equalTo("female"), 0).otherwise(1).as("sex") ); VectorAssembler assembler = new VectorAssembler() .setInputCols(new String[]{"pclass1", "age1", "sex"}) .setOutputCol("features"); Dataset<Row> data2 = assembler.transform(data1); data2.show(); // 索引標籤,將元數據添加到標籤列中 StringIndexerModel labelIndexer = new StringIndexer() .setInputCol("label") .setOutputCol("indexedLabel") .fit(data2); // 自動識別分類的特徵,並對它們進行索引 // 具備大於5個不一樣的值的特徵被視爲連續。 VectorIndexerModel featureIndexer = new VectorIndexer() .setInputCol("features") .setOutputCol("indexedFeatures") //.setMaxCategories(3) .fit(data2); // 將數據分爲訓練和測試集(30%進行測試) Dataset<Row>[] splits = data2.randomSplit(new double[]{0.7, 0.3}); Dataset<Row> trainingData = splits[0]; Dataset<Row> testData = splits[1]; // 訓練決策樹模型 DecisionTreeClassifier dt = new DecisionTreeClassifier() .setLabelCol("indexedLabel") .setFeaturesCol("indexedFeatures"); //.setImpurity("entropy") // Gini不純度,entropy熵 //.setMaxBins(100) // 離散化"連續特徵"的最大劃分數 //.setMaxDepth(5) // 樹的最大深度 //.setMinInfoGain(0.01) //一個節點分裂的最小信息增益,值爲[0,1] //.setMinInstancesPerNode(10) //每一個節點包含的最小樣本數 //.setSeed(123456) IndexToString labelConverter = new IndexToString() .setInputCol("prediction") .setOutputCol("predictedLabel") .setLabels(labelIndexer.labels()); // Chain indexers and tree in a Pipeline. 
Pipeline pipeline = new Pipeline() .setStages(new PipelineStage[]{labelIndexer, featureIndexer, dt, labelConverter}); // 訓練模型 PipelineModel model = pipeline.fit(trainingData); // 預測數據 Dataset<Row> predictions = model.transform(testData); predictions.select("user_id", "features", "label", "prediction").show(); //predictions.select("predictedLabel", "label", "features").show(5); // 計算錯誤率 MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator() .setLabelCol("indexedLabel") .setPredictionCol("prediction") .setMetricName("accuracy"); double accuracy = evaluator.evaluate(predictions); System.out.println("Test Error = " + (1.0 - accuracy)); // 查看決策樹 DecisionTreeClassificationModel treeModel = (DecisionTreeClassificationModel) (model.stages()[2]); System.out.println("Learned classification tree model:\n" + treeModel.toDebugString()); // $example off$ spark.stop(); } public static void main(String[] args){ initSpark(); run(); } }
基礎數據
過濾、特徵化後的數據
預測結果
預測錯誤率和預測模型