RandomForest隨機森林總結

1.隨機森林原理介紹html

隨機森林,指的是利用多棵樹對樣本進行訓練並預測的一種分類器。該分類器最先由Leo Breiman和Adele Cutler提出,並被註冊成了商標。簡單來講,隨機森林就是由多棵CART(Classification And Regression Tree)構成的。對於每棵樹,它們使用的訓練集是從總的訓練集中有放回採樣出來的,這意味着,總的訓練集中的有些樣本可能屢次出如今一棵樹的訓練集中,也可能從未出如今一棵樹的訓練集中。在訓練每棵樹的節點時,使用的特徵是從全部特徵中按照必定比例隨機地無放回的抽取的,根據Leo Breiman的建議,假設總的特徵數量爲M,這個比例能夠是sqrt(M),1/2sqrt(M),2sqrt(M)。node

所以,隨機森林的訓練過程能夠總結以下:git

(1)給定訓練集S,測試集T,特徵維數F。肯定參數:使用到的CART的數量t,每棵樹的深度d,每一個節點使用到的特徵數量f,終止條件:節點上最少樣本數s,節點上最少的信息增益mgithub

對於第1-t棵樹,i=1-t:算法

(2)從S中有放回的抽取大小和S同樣的訓練集S(i),做爲根節點的樣本,從根節點開始訓練數組

(3)若是當前節點上達到終止條件,則設置當前節點爲葉子節點,若是是分類問題,該葉子節點的預測輸出爲當前節點樣本集合中數量最多的那一類c(j),機率p爲c(j)佔當前樣本集的比例;若是是迴歸問題,預測輸出爲當前節點樣本集各個樣本值的平均值。而後繼續訓練其餘節點。若是當前節點沒有達到終止條件,則從F維特徵中無放回的隨機選取f維特徵。利用這f維特徵,尋找分類效果最好的一維特徵k及其閾值th,當前節點上樣本第k維特徵小於th的樣本被劃分到左節點,其他的被劃分到右節點。繼續訓練其餘節點。有關分類效果的評判標準在後面會講。dom

(4)重複(2)(3)直到全部節點都訓練過了或者被標記爲葉子節點。函數

(5)重複(2),(3),(4)直到全部CART都被訓練過。post

利用隨機森林的預測過程以下:學習

對於第1-t棵樹,i=1-t:

(1)從當前樹的根節點開始,根據當前節點的閾值th,判斷是進入左節點(<th)仍是進入右節點(>=th),直到到達,某個葉子節點,並輸出預測值。

(2)重複執行(1)直到全部t棵樹都輸出了預測值。若是是分類問題,則輸出爲全部樹中預測機率總和最大的那一個類,即對每一個c(j)的p進行累計;若是是迴歸問題,則輸出爲全部樹的輸出的平均值。

注:有關分類效果的評判標準,由於使用的是CART,所以使用的也是CART的評判標準,和C3.0,C4.5都不相同。

對於分類問題(將某個樣本劃分到某一類),也就是離散變量問題,CART使用Gini值做爲評判標準。定義爲Gini=1-∑(P(i)*P(i)),P(i)爲當前節點上數據集中第i類樣本的比例。例如:分爲2類,當前節點上有100個樣本,屬於第一類的樣本有70個,屬於第二類的樣本有30個,則Gini=1-0.7×07-0.3×03=0.42,能夠看出,類別分佈越平均,Gini值越大,類分佈越不均勻,Gini值越小。在尋找最佳的分類特徵和閾值時,評判標準爲:argmax(Gini-GiniLeft-GiniRight),即尋找最佳的特徵f和閾值th,使得當前節點的Gini值減去左子節點的Gini和右子節點的Gini值最大。

對於迴歸問題,相對更加簡單,直接使用argmax(Var-VarLeft-VarRight)做爲評判標準,即當前節點訓練集的方差Var減去減去左子節點的方差VarLeft和右子節點的方差VarRight值最大。

 

2.OpenCV函數使用

OpenCV提供了隨機森林的相關類和函數。具體使用方法以下:

(1)首先利用CvRTParams定義本身的參數,其格式以下

 

 CvRTParams::CvRTParams(int max_depth, int min_sample_count, float regression_accuracy, bool use_surrogates, int max_categories, const float* priors, bool calc_var_importance, int nactive_vars, int max_num_of_trees_in_the_forest, float forest_accuracy, int termcrit_type)

 

大部分參數描述都在http://docs.opencv.org/modules/ml/doc/random_trees.html上面有,說一下沒有描述的幾個參數的意義

bool use_surrogates:是否使用代理,指的是,若是當前的測試樣本缺乏某些特徵,可是在當前節點上的分類or迴歸特徵正是缺乏的這個特徵,那麼這個樣本就無法繼續沿着樹向下走了,達不到葉子節點的話,就沒有預測輸出,這種狀況下,能夠利用當前節點下面的全部子節點中的葉子節點預測輸出的平均值,做爲這個樣本的預測輸出。

const float*priors:先驗知識,這個指的是,能夠根據各個類別樣本數量的先驗分佈,對其進行加權。好比:若是一共有3類,第一類樣本佔整個訓練集的80%,其他兩類各佔10%,那麼這個數據集裏面的數據就很不平均,若是每類的樣本都加權的話,就算把全部樣本都預測成第一類,那麼準確率也有80%,這顯然是不合理的,所以咱們須要提升後兩類的權重,使得後兩類的分類正確率也不會過低。

float regression_accuracy:迴歸樹的終止條件,若是當前節點上全部樣本的真實值和預測值之間的差小於這個數值時,中止生產這個節點,並將其做爲葉子節點。

後來發現這些參數在決策樹裏面有解釋,英文說明在這裏http://docs.opencv.org/modules/ml/doc/decision_trees.html#cvdtreeparams

具體例子以下,網上找了個別人的例子,本身改爲了能夠讀取MNIST數據而且作分類的形式,以下:

 

#include <cv.h>       // opencv general include file
#include <ml.h>          // opencv machine learning include file
#include <stdio.h>

using namespace cv; // OpenCV API is in the C++ "cv" namespace

/******************************************************************************/
// global definitions (for speed and ease of use) //手寫體數字識別

#define NUMBER_OF_TRAINING_SAMPLES 60000
#define ATTRIBUTES_PER_SAMPLE 784
#define NUMBER_OF_TESTING_SAMPLES 10000

#define NUMBER_OF_CLASSES 10

// N.B. classes are integer handwritten digits in range 0-9

/******************************************************************************/

// loads the sample database from file (which is a CSV text file)
 inline void revertInt(int&x) { x=((x&0x000000ff)<<24)|((x&0x0000ff00)<<8)|((x&0x00ff0000)>>8)|((x&0xff000000)>>24); }; int read_data_from_csv(const char* samplePath,const char* labelPath, Mat data, Mat classes, int n_samples ) { FILE* sampleFile=fopen(samplePath,"rb"); FILE* labelFile=fopen(labelPath,"rb"); int mbs=0,number=0,col=0,row=0; fread(&mbs,4,1,sampleFile); fread(&number,4,1,sampleFile); fread(&row,4,1,sampleFile); fread(&col,4,1,sampleFile); revertInt(mbs); revertInt(number); revertInt(row); revertInt(col); fread(&mbs,4,1,labelFile); fread(&number,4,1,labelFile); revertInt(mbs); revertInt(number); unsigned char temp; for(int line = 0; line < n_samples; line++) { // for each attribute on the line in the file
        for(int attribute = 0; attribute < (ATTRIBUTES_PER_SAMPLE + 1); attribute++) { if (attribute < ATTRIBUTES_PER_SAMPLE) { // first 64 elements (0-63) in each line are the attributes
                fread(&temp,1,1,sampleFile); //fscanf(f, "%f,", &tmp);
                data.at<float>(line, attribute) = static_cast<float>(temp); // printf("%f,", data.at<float>(line, attribute));
 } else if (attribute == ATTRIBUTES_PER_SAMPLE) { // attribute 65 is the class label {0 ... 9}
                fread(&temp,1,1,labelFile); //fscanf(f, "%f,", &tmp);
                classes.at<float>(line, 0) = static_cast<float>(temp); // printf("%f\n", classes.at<float>(line, 0));
 } } } fclose(sampleFile); fclose(labelFile); return 1; // all OK
} /******************************************************************************/

int main( int argc, char** argv ) { for (int i=0; i< argc; i++) std::cout<<argv[i]<<std::endl; // lets just check the version first
    printf ("OpenCV version %s (%d.%d.%d)\n", CV_VERSION, CV_MAJOR_VERSION, CV_MINOR_VERSION, CV_SUBMINOR_VERSION); //定義訓練數據與標籤矩陣
    Mat training_data = Mat(NUMBER_OF_TRAINING_SAMPLES, ATTRIBUTES_PER_SAMPLE, CV_32FC1); Mat training_classifications = Mat(NUMBER_OF_TRAINING_SAMPLES, 1, CV_32FC1); //定義測試數據矩陣與標籤
    Mat testing_data = Mat(NUMBER_OF_TESTING_SAMPLES, ATTRIBUTES_PER_SAMPLE, CV_32FC1); Mat testing_classifications = Mat(NUMBER_OF_TESTING_SAMPLES, 1, CV_32FC1); // define all the attributes as numerical // alternatives are CV_VAR_CATEGORICAL or CV_VAR_ORDERED(=CV_VAR_NUMERICAL) // that can be assigned on a per attribute basis
 Mat var_type = Mat(ATTRIBUTES_PER_SAMPLE + 1, 1, CV_8U ); var_type.setTo(Scalar(CV_VAR_NUMERICAL) ); // all inputs are numerical // this is a classification problem (i.e. predict a discrete number of class // outputs) so reset the last (+1) output var_type element to CV_VAR_CATEGORICAL
 var_type.at<uchar>(ATTRIBUTES_PER_SAMPLE, 0) = CV_VAR_CATEGORICAL; double result; // value returned from a prediction //加載訓練數據集和測試數據集
    if (read_data_from_csv(argv[1],argv[2], training_data, training_classifications, NUMBER_OF_TRAINING_SAMPLES) && read_data_from_csv(argv[3],argv[4], testing_data, testing_classifications, NUMBER_OF_TESTING_SAMPLES)) { /********************************步驟1:定義初始化Random Trees的參數******************************/
        float priors[] = {1,1,1,1,1,1,1,1,1,1};  // weights of each classification for classes
        CvRTParams params = CvRTParams(20, // max depth
                                       50, // min sample count
                                       0, // regression accuracy: N/A here
                                       false, // compute surrogate split, no missing data
                                       15, // max number of categories (use sub-optimal algorithm for larger numbers)
                                       priors, // the array of priors
                                       false,  // calculate variable importance
                                       50,       // number of variables randomly selected at node and used to find the best split(s).
                                       100,     // max number of trees in the forest
                                       0.01f,                // forest accuracy
                                       CV_TERMCRIT_ITER |    CV_TERMCRIT_EPS // termination cirteria
 ); /****************************步驟2:訓練 Random Decision Forest(RDF)分類器*********************/ printf( "\nUsing training database: %s\n\n", argv[1]); CvRTrees* rtree = new CvRTrees; bool train_result=rtree->train(training_data, CV_ROW_SAMPLE, training_classifications, Mat(), Mat(), var_type, Mat(), params); // float train_error=rtree->get_train_error(); // printf("train error:%f\n",train_error); // perform classifier testing and report results
 Mat test_sample; int correct_class = 0; int wrong_class = 0; int false_positives [NUMBER_OF_CLASSES] = {0,0,0,0,0,0,0,0,0,0}; printf( "\nUsing testing database: %s\n\n", argv[2]); for (int tsample = 0; tsample < NUMBER_OF_TESTING_SAMPLES; tsample++) { // extract a row from the testing matrix
            test_sample = testing_data.row(tsample); /********************************步驟3:預測*********************************************/ result = rtree->predict(test_sample, Mat()); printf("Testing Sample %i -> class result (digit %d)\n", tsample, (int) result); // if the prediction and the (true) testing classification are the same // (N.B. openCV uses a floating point decision tree implementation!)
            if (fabs(result - testing_classifications.at<float>(tsample, 0)) >= FLT_EPSILON) { // if they differ more than floating point error => wrong class
                wrong_class++; false_positives[(int) result]++; } else { // otherwise correct
                correct_class++; } } printf( "\nResults on the testing database: %s\n"
                "\tCorrect classification: %d (%g%%)\n"
                "\tWrong classifications: %d (%g%%)\n", argv[2], correct_class, (double) correct_class*100/NUMBER_OF_TESTING_SAMPLES, wrong_class, (double) wrong_class*100/NUMBER_OF_TESTING_SAMPLES); for (int i = 0; i < NUMBER_OF_CLASSES; i++) { printf( "\tClass (digit %d) false postives %d (%g%%)\n", i, false_positives[i], (double) false_positives[i]*100/NUMBER_OF_TESTING_SAMPLES); } // all matrix memory free by destructors // all OK : main returns 0
        return 0; } // not OK : main returns -1
    return -1; }

 

MNIST樣本能夠在這個網址http://yann.lecun.com/exdb/mnist/下載,改一下路徑能夠直接跑的。

3.如何本身設計隨機森林程序

有時現有的庫沒法知足要求,就須要本身設計一個分類器算法,這部分講一下如何設計本身的隨機森林分類器,代碼實現就不貼了,由於在工做中用到了,所以比較敏感。

首先,要有一個RandomForest類,裏面保存整個樹須要的一些參數,包括但不限於:訓練樣本數量、測試樣本數量、特徵維數、每一個節點隨機提取的特徵維數、CART樹的數量、樹的最大深度、類別數量(若是是分類問題)、一些終止條件、指向全部樹的指針,指向訓練集和測試集的指針,指向訓練集label的指針等。還要有一些函數,至少要有train和predict吧。train裏面直接調用每棵樹的train方法便可,predict同理,但要對每棵樹的預測輸出作處理,獲得森林的預測輸出。

 

其次,要有一個sample類,這個類可不是用來存儲訓練集和對應label的,這是由於,每棵樹、每一個節點都有本身的樣本集和,若是你的存儲每一個樣本集和的話,須要的內存實在是太過巨大了,假設樣本數量爲M,特徵維數爲N,則整個訓練集大小爲M×N,而每棵樹的每層都有這麼多樣本,樹的深度爲D,共有S棵樹的話,則須要存儲M×N×D×S的存儲空間。這實在是太大了。所以,每一個節點訓練時用到的訓練樣本和特徵,咱們都用序號數組來代替,sample類就是幹這個的。sample的函數基本須要兩個就行,第一個是從現有訓練集有放回的隨機抽取一個新的訓練集,固然,只包含樣本的序號。第二個函數是從現有的特徵中無放回的隨機抽取必定數量的特徵,同理,也是特徵序號便可。

而後,須要一個Tree類,表明每棵樹,裏面保存樹的一些參數以及一個指向全部節點的指針。

最後,須要一個Node類,表明樹的每一個節點。

 

須要說明的是,保存樹的方式能夠是最普通的數組,也但是是vector。Node的保存方式同理,可是我的不建議用鏈表的方式,在程序設計以及函數處理上太麻煩,可是在省空間上並無太多的體現。

目前先寫這麼多,最後這部分我還會再擴充一些。

#2017.2.28

在github上開源了一個簡單的隨機森林程序,包含訓練、預測部分,支持分類和迴歸問題,裏面有mnist訓練的實例,附了很多註釋,比較適合入門學習,地址:

https://github.com/handspeaker/RandomForests

相關文章
相關標籤/搜索