算法實踐-隨機森林

這是一份隨機森林算法的python實踐代碼,若是你還不知道隨機森林算法是幹什麼用的可先參考《機器學習算法-隨機森林》python

代碼運行環境:git

python2.7github

擴展包依賴:算法

jieba==0.37
scikit-learn==0.17
# -*- coding: utf-8 -*-

import os
import gc

import jieba

from sklearn.externals import joblib
from sklearn.ensemble import RandomForestClassifier
from sklearn.cross_validation import train_test_split
jieba.initialize()


class RandomForest(object):
    """ 隨機森林算法實踐 """
    def __init__(self, is_save=False):
        self.__is_save = is_save
        self.__clf = None
        self.__train_data_feature = None

        # 加載已保存的訓練集
        clf = RandomForestTools.train_data_load()
        train_data_feature = RandomForestTools.feature_data_load()
        if clf and train_data_feature:
            self.__clf = clf
            self.__train_data_feature = train_data_feature

    def build_train_data(self, pre_train_data_list, result_list, train_size=0.9):
        """
        構建訓練集
        :param pre_train_data_list: list 須要訓練的數據
        :param result_list list 訓練數據對應的結果
        :param train_size: float 0<train_size<=1 訓練集佔總數據的比例
        :return: object 訓練集
        """
        # 數據預處理
        print('Start pre-treat data.')
        train_data_list = []
        train_data_feature = set()
        for pre_train_data in pre_train_data_list:
            # 分詞
            train_data = self.word_segmentation(pre_train_data)
            train_data_list.append(train_data)
            # 提取分詞特徵
            for feature in train_data:
                train_data_feature.add(feature)

        # 數據預處理
        data = self.pre_treat_data(train_data_list, train_data_feature)

        # 將訓練集隨機分紅數份,以便自校驗訓練集準確率
        print('Start split train and test data.')
        data_train, data_test, result_train, result_test = train_test_split(data, result_list, train_size=train_size)

        # 開始訓練隨機森林,n_jobs設爲-1自動按內核數處理數據
        print('Start training random forest.')
        clf = RandomForestClassifier(n_jobs=-1)
        self.__clf = clf.fit(data_train, result_train)
        self.__train_data_feature = train_data_feature
        if self.__is_save:
            # 保存訓練集,及各項數據的特徵值
            print('Save training result.')
            RandomForestTools.train_data_save(self.__clf)
            RandomForestTools.feature_data_save(train_data_feature)
        print("Build train data finish and accuracy is:%.2f ." %
              (self.__clf.score(data_test, result_test)))

    @staticmethod
    def word_segmentation(train_data):
        """
        分詞處理
        :param train_data string 帶分詞數據
        :return set 分詞結果
        """
        word_segmentation_result = set()
        for word in jieba.lcut(train_data):
            word_segmentation_result.add(word)
        return word_segmentation_result

    @staticmethod
    def pre_treat_data(train_data_list, train_data_feature, is_gc_collect=False):
        """
        數據預處理,從已有數據中處理出最終訓練集數據
        :param train_data_list: list/set 待訓練集內容分詞
        :param train_data_feature set 待訓練集內容分詞特徵
        :param is_gc_collect: boolean 數據預處理完成後是否執行垃圾回收
        :return: list 通過預處理的待訓練數據
        """
        # 爲規避特徵保存及取出的過程當中順序打亂而形成數據不對應的狀況統一排序
        train_data_feature = sorted(train_data_feature)
        message_list = RandomForestTools.one_hot_encode_feature(train_data_list, train_data_feature)
        if is_gc_collect:
            print('Finish one hot encoder.')
            # 手動執行垃圾回收避免內存佔用太高被系統強制kill
            print('Garbage collector: collected %d objects.' % gc.collect())
        return message_list

    def predict(self, predict_data):
        """
        預測輸入數據是否爲壞樣本
        :param predict_data: string 待預測數據
        :return: 預測結果
        """
        predict_data = self.word_segmentation(predict_data)

        data_test = self.pre_treat_data([predict_data], self.__train_data_feature)
        result = self.__clf.predict(data_test)
        return result[0]


class RandomForestTools(object):
    """ 訓練集數據操做類 """
    TRAIN_DATA_FILE_DIR = '/tmp/'
    TRAIN_DATA_FILE = 'train_data.pkl'
    FEATURE_DATA_FILE = 'feature_data.pkl'

    @staticmethod
    def train_data_save(clf):
        """
        保存訓練集數據
        :param clf:訓練集
        :return: boolean True爲成功保存,False爲保存失敗
        """
        filename = RandomForestTools.TRAIN_DATA_FILE_DIR+RandomForestTools.TRAIN_DATA_FILE
        return RandomForestTools.save(filename, clf)

    @staticmethod
    def train_data_load():
        """
        加載已保存的訓練集數據
        :return: object/False 訓練集數據存在時返回訓練集,不存在時返回False
        """
        filename = RandomForestTools.TRAIN_DATA_FILE_DIR+RandomForestTools.TRAIN_DATA_FILE
        if RandomForestTools.exists(filename):
            return RandomForestTools.load(filename)
        else:
            return False

    @staticmethod
    def feature_data_save(feature_data):
        """
        保存特徵數據
        :param feature_data: set 訓練集的對應的特徵數據
        :return: boolean True爲成功保存,False爲保存失敗
        """
        filename = RandomForestTools.TRAIN_DATA_FILE_DIR + RandomForestTools.TRAIN_DATA_FILE
        return RandomForestTools.save(filename, feature_data)

    @staticmethod
    def feature_data_load():
        """
        加載已保存的特徵數據
        :return: object 特徵數據
        :notice: 加載前應主動檢測數據集是否存在,返回的特徵值順序可能會被打亂
        """
        filename = RandomForestTools.TRAIN_DATA_FILE_DIR + RandomForestTools.FEATURE_DATA_FILE
        if RandomForestTools.exists(filename):
            return RandomForestTools.load(filename)
        else:
            return False

    @staticmethod
    def save(filename, python_object):
        if not os.path.isdir(os.path.dirname(filename)):
            os.mkdir(os.path.dirname(filename))
        if joblib.dump(python_object, filename):
            return True
        else:
            return False

    @staticmethod
    def load(filename):
        return joblib.load(filename)

    @staticmethod
    def exists(filename):
        if os.path.isfile(filename):
            return True
        else:
            return False

    @staticmethod
    def one_hot_encode_feature(data_list, data_set):
        """
        根據特徵在內容中是否出現將數據格式化成二維二進制數組
        :param data_list: list/set 待格式化數據
        :param data_set: set 特徵統計
        :return: list 二維數組
        """
        x, y = 0, 0
        serialize_list = []
        for data in data_list:
            tmp_serialize_list = []
            for key in data_set:
                # 分詞以list的方式判斷是否存在特徵是否存在
                if isinstance(data, list) or isinstance(data, set):
                    tmp_serialize_list.append(1 if key in data else 0)
                elif isinstance(data, basestring) or isinstance(data, int):
                    tmp_serialize_list.append(1 if key == data else 0)
                y += 1
            serialize_list.append(tmp_serialize_list)
            x += 1
        return serialize_list


class RandomForestException(Exception):
    pass

if __name__ == '__main__':
    pre_train_data = [
        u'我很開心',
        u'我很是開心',
        u'我其實很開心',
        u'我特別開心',
        u'我超級開心',
        u'我不開心',
        u'我一點都不開心',
        u'我很不開心',
        u'我很是不開心',
        u'我好久沒那麼不開心了',
    ]
    result_list = [
        u'開心',
        u'開心',
        u'開心',
        u'開心',
        u'開心',
        u'不開心',
        u'不開心',
        u'不開心',
        u'不開心',
        u'不開心',
    ]
    rf = RandomForest()
    rf.build_train_data(pre_train_data,result_list)
    print(rf.predict(
        u'你猜我開心嗎?',
    ))
 

你也能夠在 Github上找到它。數組

 

源地址 By佐柱app

轉載請註明出處,也歡迎偶爾逛逛個人小站,謝謝 :)dom

相關文章
相關標籤/搜索