機器學習算法—Regression Tree 迴歸樹

1. 引言

AI時代,機器學習算法成爲了研究、應用的熱點。當前,最火的兩類算法莫過於神經網絡算法(CNN、RNN、LSTM等)與樹形算法(隨機森林、GBDT、XGBoost等),樹形算法的基礎就是決策樹。決策樹因其易理解、易構建、速度快的特性,被普遍應用於統計學、數據挖掘、機器學習領域。所以,對決策樹的學習,是機器學習之路必不可少的一步。git

根據處理數據類型的不一樣,決策樹又分爲兩類:分類決策樹迴歸決策樹,前者可用於處理離散型數據,後者可用於處理連續型數據,下面的英文引用自維基百科github

Classification tree analysis is when the predicted outcome is the class to which the data belongs.算法

Regression tree analysis is when the predicted outcome can be considered a real number (e.g. the price of a house, or a patient's length of stay in a hospital).網絡

網絡上有關於分類決策樹的介紹可謂數不勝數,可是對迴歸決策樹(迴歸樹)的介紹卻少之又少。李航教授的統計學習方法 對迴歸樹有一個簡單介紹,惋惜篇幅較短,沒有給出一個具體實例;Google搜索迴歸樹,有一篇介紹迴歸樹的博客(點擊),該博客所舉的實例有誤,其過程事實上是基於殘差的GBDT框架

基於以上緣由,本文簡單介紹了迴歸樹(Regression Tree),簡單描述了CART算法,給出了迴歸樹的算法描述,輔以簡單實例以加深理解,最後是總結部分。機器學習

2. 迴歸樹

決策樹其實是將空間用超平面進行劃分的一種方法,每次分割的時候,都將當前的空間一分爲二, 這樣使得每個葉子節點都是在空間中的一個不相交的區域,在進行決策的時候,會根據輸入樣本每一維feature的值,一步一步往下,最後使得樣本落入N個區域中的一個(假設有N個葉子節點),以下圖所示。ide

決策樹

三種比較常見的分類決策樹分支劃分方式包括:ID3, C4.5, CART。函數

分類決策樹

分類與迴歸樹(classificationandregressiontree, CART)模型由Breiman等人在1984年提出,是應用普遍的決策樹學習方法。CART一樣由特徵選擇、樹的生成及剪枝組成,既能夠用於分類也能夠用於迴歸。下面的英文引用自維基百科學習

The term Classification And Regression Tree (CART) analysis is an umbrella term used to refer to both of the above procedures, first introduced by Breiman et al. Trees used for regression and trees used for classification have some similarities - but also some differences, such as the procedure used to determine where to split.spa

下面介紹迴歸樹。

2.1 原理概述

既然是決策樹,那麼必然會存在如下兩個核心問題:如何選擇劃分點?如何決定葉節點的輸出值?

一個迴歸樹對應着輸入空間(即特徵空間)的一個劃分以及在劃分單元上的輸出值。分類樹中,咱們採用信息論中的方法,經過計算選擇最佳劃分點。而在迴歸樹中,採用的是啓發式的方法。假如咱們有n個特徵,每一個特徵有s_i(i \in (1,n))個取值,那咱們遍歷全部特徵,嘗試該特徵全部取值,對空間進行劃分,直到取到特徵j的取值s,使得損失函數最小,這樣就獲得了一個劃分點。描述該過程的公式以下:(若是看不到圖請點擊永久地址

損失

假設將輸入空間劃分爲M個單元:R_1,R_2,...,R_m 那麼每一個區域的輸出值就是:c_m=ave(y_i|x_i \in R_m)也就是該區域內全部點y值的平均數。

舉個例子。以下圖所示,假如咱們想要對樓內居民的年齡進行迴歸,將樓劃分爲3個區域R_1, R_2, R_3(紅線),那麼R_1的輸出就是第一列四個居民年齡的平均值,R_2的輸出就是第二列四個居民年齡的平均值,R_3的輸出就是第3、四列八個居民年齡的平均值。

一個例子

2.2 算法描述

截圖來自李航教授的統計學習方法

CART算法描述

2.3 一個簡單實例

爲了便於理解,下面舉一個簡單實例。訓練數據見下表,目標是獲得一棵最小二乘迴歸樹。

x 1 2 3 4 5 6 7 8 9 10
y 5.56 5.7 5.91 6.4 6.8 7.05 8.9 8.7 9 9.05
  1. 選擇最優切分變量j與最優切分點s

在本數據集中,只有一個變量,所以最優切分變量天然是x。

接下來咱們考慮9個切分點[1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5]你可能會問,爲何會帶小數點呢?類比於籃球比賽的博彩,假若兩隊比分是96:95,而盤口是「讓1分 A隊勝B隊」,那A隊讓1分以後,究竟是A隊贏仍是B隊贏了?因此咱們常常能夠看到「讓0.5分 A隊勝B隊」這樣的盤口。在這個實例中,也是這個道理。

損失函數定義爲平方損失函數 Loss(y, f(x))=(f(x)-y)^2,將上述9個切分點一依此代入下面的公式,其中 c_m=ave(y_i|x_i \in R_m) (若是看不到圖請點擊永久地址

損失

例如,取 s=1.5。此時 R_1=\{1\} , R_2=\{2,3,4,5,6,7,8,9,10\},這兩個區域的輸出值分別爲: c_1=5.56, c_2= \frac{1}{9}(5.7+5.91+6.4+6.8+7.05+8.9+8.7+9+9.05)=7.50。獲得下表:

s 1.5 2.5 3.5 4.5 5.5 6.5 7.5 8.5 9.5
c_1 5.56 5.63 5.72 5.89 6.07 6.24 6.62 6.88 7.11
c_2 7.5 7.73 7.99 8.25 8.54 8.91 8.92 9.03 9.05
把$c_1, c_2$的值代入到上式,如:$m(1.5)=0+15.72=15.72$。同理,可得到下表:
複製代碼
s 1.5 2.5 3.5 4.5 5.5 6.5 7.5 8.5 9.5
m(s) 15.72 12.07 8.36 5.78 3.91 1.93 8.01 11.73 15.74
顯然取 $s=6.5$時,$m(s)$最小。所以,第一個劃分變量$j=x, s=6.5$
複製代碼
  1. 用選定的(j,s)劃分區域,並決定輸出值

    兩個區域分別是:R_1=\{1,2,3,4,5,6\} , R_2=\{7,8,9,10\}輸出值c_m=ave(y_i|x_i \in R_m),c_1=6.24,c_2=8.91

  2. 對兩個子區域繼續調用步驟一、步驟2

    R_1繼續進行劃分:

    x 1 2 3 4 5 6
    y 5.56 5.7 5.91 6.4 6.8 7.05

    取切分點[1.5, 2.5, 3.5, 4.5, 5.5],則各區域的輸出值c以下表

    s 1.5 2.5 3.5 4.5 5.5
    c_1 5.56 5.63 5.72 5.89 6.07
    c_2 6.37 6.54 6.75 6.93 7.05

    計算m(s):

    s 1.5 2.5 3.5 4.5 5.5
    m(s) 1.3087 0.754 0.2771 0.4368 1.0644

    s=3.5時m(s)最小。

    以後的過程再也不贅述。

  3. 生成迴歸樹

    假設在生成3個區域以後中止劃分,那麼最終生成的迴歸樹形式以下:

    T=\left\{\begin{matrix}5.72 & x\leq 3.5\\ 6.75 &3.5\leqslant x\leq 6.5\\ 8.91 & 6.5<x\end{matrix}\right.

2.4 迴歸樹VS線性迴歸

很少說了,直接看圖甩代碼

迴歸樹VS線性迴歸

3. 總結

實際上,迴歸樹整體流程相似於分類樹,分枝時窮舉每個特徵的每個閾值,來尋找最優切分特徵j和最優切分點s,衡量的方法是平方偏差最小化。分枝直到達到預設的終止條件(如葉子個數上限)就中止。

固然,處理具體問題時,單一的迴歸樹確定是不夠用的。能夠利用集成學習中的boosting框架,對迴歸樹進行改良升級,獲得的新模型就是提高樹(Boosting Decision Tree),在進一步,能夠獲得梯度提高樹(Gradient Boosting Decision Tree,GBDT),再進一步能夠升級到XGBoost。


做者郵箱: mr.yxj@foxmail.com

轉載請告知做者,感謝!

相關文章
相關標籤/搜索