【R語言學習筆記】 Day1 CART 邏輯迴歸、分類樹以及隨機森林的應用及對比

 

1. 目的:根據人口普查數據來預測收入(預測每一個個體年收入是否超過$50,000)api

 

2. 數據來源:1994年美國人口普查數據,數據中共含31978個觀測值,每一個觀測值表明一個個體dom

 

3. 變量介紹:測試

(1)age: 年齡(以年表示)rest

(2)workclass: 工做類別/性質 (e.g., 國家機關工做人員、當地政府工做人員、無收入人員等)orm

(3)education: 受教育水平 (e.g., 小學、初中、高中、本科、碩士、博士等)blog

(4)maritalstatus: 婚姻狀態(e.g., 未婚、離異等)ip

(5)occupation: 工做類型 (e.g., 行政/文員、農業養殖人員、銷售人員等)it

(6)relationship: 家庭身份 (e.g., 丈夫、妻子、孩子等)io

(7)race: 種族 table

(8)sex: 性別

(9)capitalgain: 1994年的資本收入 (買賣股票、債券等)

(10)capitalloss: 1994年的資本支出 (買賣股票、債券等)

(11)hoursperweek: 每週工做時長

(12)nativecountry: 國籍

(13)over50k: 1994年整年工資是否超過$50,000

 

 

4. 應用及分析

census <- read.csv("census.csv") #讀取文件

  

library(caTools) # 加載caTools包
# 將數據分爲測試集和訓練集
set.seed(2000)
spl <- sample.split(census$over50k, SplitRatio = 0.6)
census.train <- subset(census, spl == T) # 測試集
census.test <- subset(census, spl == F) # 訓練集

  

# 構建邏輯迴歸模型
census.logistic <- glm(over50k ~ ., data = census.train, family = 'binomial')
summary(census.logistic) # 查看模型擬合結果

 

 

 

# 在臨界值爲0.5的狀況下,邏輯迴歸模型應用到測試集的準確性
## method1
census.logistic.pred <- predict(census.logistic, newdata = census.test, type = 'response')
library(caret)
confusionMatrix(as.factor(ifelse(census.logistic.pred >= 0.5, " >50K", " <=50K")), as.factor(census.test$over50k))

## method2
table(census.test$over50k, census.logistic.pred>= 0.5)
sum(diag(table(census.test$over50k, census.logistic.pred>= 0.5)))/nrow(census.test) #0.8552


# 測試集的基礎準確性
table(census.test$over50k)/nrow(census.test) #0.759

  

# ROC 以及 AUC
library(ROCR)
census.pred <- prediction(census.logistic.pred, census.test$over50k)
census.perf <- performance(census.pred, 'tpr', 'fpr')
plot(census.perf, colorize = T) #ROC curve
as.numeric(performance(census.pred, 'auc')@y.values) #AUC value is 0.9061598

 

雖然邏輯迴歸模型準確率高達0.8572,且變量的顯著性有助於咱們判斷個體的收入狀況;可是在自變量中的分類變量類別太多的狀況下,咱們沒法判斷哪些變量更重要。

所以,接下來構建CART模型。

 

# 默認的CART模型
library(rpart)
library(rpart.plot)
census.cart <- rpart(over50k ~ ., data = census.train, method = 'class')
prp(census.cart) # 做圖

 

# 模型準確性
census.cart.pred <- predict(census.cart, newdata = census.test, type = 'class')
## method1
table(census.test$over50k, census.cart.pred)
sum(diag(table(census.test$over50k, census.cart.pred)))/nrow(census.test)
## method2
confusionMatrix(census.cart.pred, as.factor(census.test$over50k)) # 模型準確性爲0.8474

 

# ROC 以及 AUC
census.cart.pred2 <- predict(census.cart, newdata = census.test)
census.cart.pred2
census.cart.pred3 <- prediction(census.cart.pred2[,2], census.test$over50k)
census.cart.perf <- performance(census.cart.pred3, 'tpr', 'fpr')
plot(census.cart.perf, colorize = T) # ROC

as.numeric(performance(census.cart.pred3, 'auc')@y.values) #AUC value is 0.8470256

 

# 隨機森林模型
set.seed(1)
census.train.small <- census.train[sample(nrow(census.train), 2000),]
## 構建隨機森林模型以前先減少訓練集樣本數量。
## 由於隨機森林過程當中包含大量運算過程,小樣本更益於模型的創建

library(randomForest)
census.train.small.rf <- randomForest(over50k ~ ., data = census.train.small)

# 模型預測
census.train.small.rf.pred <- predict(census.train.small.rf, newdata = census.test)

# 模型準確性
confusionMatrix(census.train.small.rf.pred, as.factor(census.test$over50k)) # 0.8533

  

 由於隨機森林模型是一系列分類決策樹的集合,所以與分類決策樹相比,隨機森林模型的解釋性稍差,但仍可用一些方法來衡量變量的重要性

# 方法一:統計隨機過程當中每一個變量出現的次數
vu <- varUsed(census.train.small.rf, count=TRUE)
vusrted <- sort(vu, decreasing = FALSE, index.return = TRUE)
# draw a Cleveland dot plot
dotchart(vusorted$x, names(census.train.small.rf$forest$xlevels[vusorted$ix]))

其中,age出現次數最多,sex出現次數最少。

 

# 方法二:比較平均Gini指數的降低程度
varImpPlot(census.train.small.rf)

 

其中,occupation、education、age的平均Gini指數減小的最多,sex的平均Gini指數減小的最少

 

# 改進的CART模型(考慮cp值)
library(caret)
library(lattice)
library(ggplot2)
library(e1071)

# 找出使得準確率最高的cp值
set.seed(2)
numFolds <- trainControl(method = 'cv', number = 10)
cpGrid <- expand.grid(.cp = seq(0.002,0.1,0.002))
train(over50k ~ ., data = census.train,
      method = 'rpart', trControl = numFolds, tuneGrid = cpGrid) # cp = 0.002時模型準確度最高

# 構建新的CART模型(cp=0.002)
census.bestTree <- rpart(over50k ~ ., data = census.train, method = 'class', cp = 0.002)
prp(census.bestTree) # 做圖

# 模型預測
predCV <- predict(census.bestTree, newdata = census.test, type = 'class')

# 計算新模型的準確率
## method1
table(census.test$over50k, predCV)
sum(diag(table(census.test$over50k, predCV)))/nrow(census.test)
## method2
confusionMatrix(predCV, as.factor(census.test$over50k)) # 0.8612

 

 

考慮cp值之後的CART模型的準確性比默認模型高了1%左右,可是模型明顯複雜了更多,所以須要在模型簡潔性及準確性之間作出權衡。

本案例中,默認模型足夠簡潔且準確度也很高,因此傾向使用默認模型。

相關文章
相關標籤/搜索