Machine Learning in Action (2): Decision Trees

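When constructing a decision tree, the first question to answer is which feature of the current dataset plays the decisive role in separating the classes. To find that feature and obtain the best split, we must evaluate every feature. Once the evaluation is done, the original dataset is split into several subsets, which are distributed across all branches of the first decision node. If the data under a branch all belong to the same class, that branch needs no further splitting. If the data in a subset do not all belong to the same class, the subset must be split again. Subsets are split with the same algorithm used for the original dataset, and the process repeats until every subset contains data of a single class.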

1. Tree construction

The pseudocode for the branch-creating function createBranch() is as follows:

Check if every item in the dataset is in the same class:
    If so return the class label
    Else
        find the best feature to split the data
        split the dataset
        create a branch node
        for each split
            call createBranch and add the result to the branch node
    return branch node                

Common criteria for splitting a dataset are information gain (ID3), information gain ratio (C4.5), and the Gini index (CART).

1.1 Information gain

Training set: $D=\{(x_1,y_1), (x_2,y_2), \ldots, (x_m, y_m)\}$

Attribute set: $A=\{a_1, a_2, \ldots, a_d\}$

def createDataSet():
    dataSet = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
    feature = ['no surfacing', 'flippers']
    return dataSet, feature

$a_{\star} = \arg\max\limits_{a\in{A}} Gain(D, a)$

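The guiding principle for splitting a dataset is to make disordered data more ordered. The change in information before and after a split is called the information gain. We compute the information gain obtained by splitting on each feature, and the feature with the highest gain is the best one to split on. The measure of information in a set is called entropy. If the proportion of samples of class $k$ in the current sample set $D$ is $p_k \ (k=1,2,\ldots,\lvert{y}\rvert)$, then the information entropy of $D$ is: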

$Ent(D) = - \sum_{k=1}^{\lvert{y}\rvert} p_k \log_2 p_k$

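Suppose splitting $D$ on an attribute $a$ produces $V$ branch nodes, where $D_v$ is the subset of samples that fall into the $v$-th branch. We can compute the information entropy of each $D_v$, and because the branch nodes contain different numbers of samples, each branch is weighted by $\frac{\lvert{D_v}\rvert}{\lvert{D}\rvert}$. This gives the information gain: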

$Gain(D,a) = Ent(D) - \sum_{v=1}^{V} \frac{\lvert{D_v}\rvert}{\lvert{D}\rvert}Ent(D_v)$

# 1. Compute the Shannon entropy of a dataset
from math import log
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]   # class label of the current sample
        if currentLabel not in labelCounts.keys(): # first time this label is seen
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key])/numEntries # class probability
        shannonEnt -= prob * log(prob,2) # accumulate -p * log2(p)
    return shannonEnt

The higher the entropy, the more mixed the data and the higher its impurity.
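As a quick check of the entropy formula, the sketch below recomputes it for the class labels of the five-sample dataset returned by createDataSet() (two 'yes', three 'no'). It is written for Python 3 and uses collections.Counter rather than a hand-built dictionary; the name entropy_of_labels is hypothetical and not part of the original code.

```python
from collections import Counter
from math import log2


def entropy_of_labels(labels):
    """Shannon entropy (in bits) of a list of class labels."""
    total = len(labels)
    counts = Counter(labels)
    return sum(-(n / total) * log2(n / total) for n in counts.values())


labels = ['yes', 'yes', 'no', 'no', 'no']   # last column of the sample dataset
base = entropy_of_labels(labels)
print(round(base, 4))                        # about 0.971

# A pure set has zero entropy; adding a third class makes the set more mixed.
pure = entropy_of_labels(['no', 'no'])
mixed = entropy_of_labels(labels + ['maybe'])
print(pure, mixed > base)
```

The comparison at the end illustrates the remark above: the more classes are mixed together, the larger the entropy.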

1.2 Splitting the dataset

# 2. Split the dataset. Arguments: dataset, index of the splitting feature, feature value of the child node
def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:]) # extend adds the elements one by one
            retDataSet.append(reducedFeatVec)      # append adds the list as a single element
    return retDataSet
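To see what a split produces, the sketch below applies the same idea to the sample dataset. It is a compact list-comprehension variant under the hypothetical name split_rows, not the original function: it keeps the rows whose chosen feature equals the given value and drops that feature's column.

```python
def split_rows(rows, axis, value):
    """Keep rows whose feature `axis` equals `value`, dropping that column."""
    return [row[:axis] + row[axis + 1:] for row in rows if row[axis] == value]


data = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]

print(split_rows(data, 0, 1))   # [[1, 'yes'], [1, 'yes'], [0, 'no']]
print(split_rows(data, 0, 0))   # [[1, 'no'], [1, 'no']]
```

Slicing builds new lists, so the input dataset is left unchanged and can be reused for the next candidate feature.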

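Next we iterate over every feature of the dataset, calling splitDataSet() and computing the entropy of each resulting split, to find the best feature to split on (the one with the largest information gain).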

# 3. Choose the best feature to split the dataset on
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1 # the last column is the class label
    baseEntropy = calcShannonEnt(dataSet) # entropy at the root node
    bestInfoGain = 0.0; bestFeature = -1  # initialize the best gain and the best splitting feature
    for i in range(numFeatures): # compute the information gain of each feature and keep the largest
        featList = [example[i] for example in dataSet] # values of this feature across all samples
        uniqueVals = set(featList) # distinct values of this feature
        newEntropy = 0.0
        for value in uniqueVals: # for each value, split off the subset and compute its entropy
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet)/float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature

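The data passed to these functions must meet two requirements. First, the data must be a list of lists, and every inner list must have the same length. Second, the last column of the data, i.e. the last element of each instance, is the class label of that instance. Once a dataset meets these requirements, the first line of the function can determine how many features the dataset contains. We do not need to restrict the type of the values in the lists: they can be numbers or strings without affecting the computation.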
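To make the selection step concrete, the following Python 3 sketch computes the information gain of each feature in the sample dataset. The helper names entropy and info_gain are hypothetical stand-ins written for this illustration; they follow the formulas above rather than reproducing the original functions line by line.

```python
from collections import Counter
from math import log2


def entropy(rows):
    """Shannon entropy of the class labels stored in the last column."""
    counts = Counter(row[-1] for row in rows)
    return sum(-(n / len(rows)) * log2(n / len(rows)) for n in counts.values())


def info_gain(rows, axis):
    """Entropy reduction obtained by splitting `rows` on feature `axis`."""
    remainder = 0.0
    for value in set(row[axis] for row in rows):
        subset = [row for row in rows if row[axis] == value]
        remainder += len(subset) / len(rows) * entropy(subset)
    return entropy(rows) - remainder


data = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
features = ['no surfacing', 'flippers']

gains = [info_gain(data, i) for i in range(len(features))]
for name, gain in zip(features, gains):
    print(name, round(gain, 3))          # about 0.42 and 0.171

best = max(range(len(features)), key=lambda i: gains[i])
print(features[best])                    # no surfacing
```

Splitting on 'no surfacing' leaves one pure branch and one branch of three samples, which is why its gain is larger than that of 'flippers'.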

1.3 Recursively building the tree

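So far, we start from the original dataset and split it on the best attribute. Because a feature may take more than two values, a split may produce more than two branches. After the first split, the data are passed down to the next node along each branch, where they can be split again. We can therefore process the dataset recursively.

Recursion ends under two conditions. First, when the program reaches a child node with no attributes left to split on, the node becomes a leaf and its class is decided by majority vote. Second, when all instances have the same class, the node becomes a leaf, and any data reaching that leaf necessarily belong to its class.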

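# 4. Decide the class of a leaf node by majority vote
import operator
def majorityCnt(classList):
    classCount = {}
    # count the occurrences of each class in a dictionary
    for vote in classList:
        if vote not in classCount.keys():classCount[vote] = 0
        classCount[vote] += 1
    # items() works in both Python 2 and 3 (iteritems() exists only in Python 2)
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True) # sort by count, descending
    return  sortedClassCount[0][0]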
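The same majority vote can be written more compactly with the standard library. The sketch below is an alternative under the hypothetical name majority_label, not the original function; it relies on collections.Counter.most_common, which returns (label, count) pairs sorted by count.

```python
from collections import Counter


def majority_label(class_list):
    """Return the most frequent label in `class_list`."""
    return Counter(class_list).most_common(1)[0][0]


print(majority_label(['yes', 'no', 'no']))   # no
```

When two labels are tied, which one wins depends on insertion order, so a tree built this way should not rely on a particular tie-breaking rule.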

 

# 5. Build the decision tree. Input: dataset and feature names; output: the model as a nested dictionary
def createTree(dataSet, features):
    classList = [example[-1] for example in dataSet] # class labels of all samples in the current data
    if classList.count(classList[0]) == len(classList): # all classes identical: stop splitting and return that class
        return classList[0]
    if len(dataSet[0]) == 1: # all features used up: stop splitting and return the majority class of this node
        return majorityCnt(classList)
    bestFeatureIndex = chooseBestFeatureToSplit(dataSet) # choose the best feature to split on
    bestFeature = features[bestFeatureIndex]
    myTree = {bestFeature:{}} # the key is the best feature; its value starts as an empty dictionary
    featValues = [example[bestFeatureIndex] for example in dataSet] # values of this feature across all samples
    uniqueVals = set(featValues) # distinct values of the feature, one subset per value
    for value in uniqueVals:
        subFeatures = features[:] # copy so that the caller's features list is not modified
        del (subFeatures[bestFeatureIndex])  # remove the feature that has just been used
        myTree[bestFeature][value] = createTree(splitDataSet(dataSet, bestFeatureIndex, value), subFeatures) # recurse; returns a class label when no further split is possible
    return myTree

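The code above first creates a list variable named classList containing all class labels in the dataset. The first stopping condition of the recursive function is that all class labels are identical, in which case that label is returned. The second stopping condition is that all features have been used up and the dataset still cannot be split into groups containing a single class; here the class is chosen by majority vote.

Next, the program starts building the tree, which is stored in a Python dictionary. The dictionary holds all the information about the tree, which matters for plotting the tree later. Because Python passes list arguments by reference, the program uses subFeatures = features[:] to copy the feature list so that the caller's original list is not modified.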
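The effect of the slice copy can be seen in isolation. In this small sketch, the two helper functions are hypothetical and exist only to contrast deleting from the caller's list with deleting from a copy.

```python
def drop_first_in_place(names):
    del names[0]            # mutates the caller's list
    return names


def drop_first_copy(names):
    local = names[:]        # shallow copy; the caller's list is untouched
    del local[0]
    return local


features = ['no surfacing', 'flippers']
print(drop_first_copy(features), features)       # original still has two entries
print(drop_first_in_place(features), features)   # original now has one entry
```

Without the copy, each recursive call would shrink the list shared by its sibling branches, and feature indices would no longer match the columns of the data.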

myData, feature = createDataSet()
myTree = createTree(myData, feature)
print(myTree)

Output:
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

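The variable myTree holds nested dictionaries that represent the tree structure. Reading from the left, the first key, no surfacing, is the first splitting attribute, and its value is another dictionary. The keys of that second dictionary are the values of the no surfacing feature, one per branch, and each value is either a class label or another dictionary. If the value is a class label, the node is a leaf; if it is another dictionary, the node is a decision node. This pattern repeats to form the whole tree.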
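One way to confirm this reading of the structure is to walk the dictionary for a given sample. The function below is a minimal sketch under the hypothetical name classify; it assumes every feature value in the sample appears as a branch key and would raise KeyError otherwise.

```python
def classify(tree, feature_names, sample):
    """Follow `sample` down a nested-dict tree until a class label is reached."""
    if not isinstance(tree, dict):        # leaf node: the value is the label
        return tree
    feature = next(iter(tree))            # the single key names the split feature
    branches = tree[feature]
    value = sample[feature_names.index(feature)]
    return classify(branches[value], feature_names, sample)


tree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
names = ['no surfacing', 'flippers']

print(classify(tree, names, [1, 1]))   # yes
print(classify(tree, names, [1, 0]))   # no
print(classify(tree, names, [0, 1]))   # no
```

Each recursive call descends one level: a dictionary means another decision node, anything else is the leaf's class label.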

2. Plotting trees in Python with Matplotlib annotations

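If matplotlib cannot display Chinese characters, edit Lib\site-packages\matplotlib\mpl-data\matplotlibrc so that font.family is enabled and SimHei is listed first in font.sans-serif: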

### MATPLOTLIBRC FORMAT

# This is a sample matplotlib configuration file - you can find a copy
# of it on your system in
# site-packages/matplotlib/mpl-data/matplotlibrc.  If you edit it
# there, please note that it will be overwritten in your next install.
# If you want to keep a permanent local copy that will not be
# overwritten, place it in the following location:
# unix/linux:
#     $HOME/.config/matplotlib/matplotlibrc or
#     $XDG_CONFIG_HOME/matplotlib/matplotlibrc (if $XDG_CONFIG_HOME is set)
# other platforms:
#     $HOME/.matplotlib/matplotlibrc
#
# See http://matplotlib.org/users/customizing.html#the-matplotlibrc-file for
# more details on the paths which are checked for the configuration file.
#
# This file is best viewed in a editor which supports python mode
# syntax highlighting. Blank lines, or lines starting with a comment
# symbol, are ignored, as are trailing comments.  Other lines must
# have the format
#    key : val # optional comment
#
# Colors: for the color values below, you can either use - a
# matplotlib color string, such as r, k, or b - an rgb tuple, such as
# (1.0, 0.5, 0.0) - a hex string, such as ff00ff or #ff00ff - a scalar
# grayscale intensity such as 0.75 - a legal html color name, e.g., red,
# blue, darkslategray

#### CONFIGURATION BEGINS HERE

# The default backend; one of GTK GTKAgg GTKCairo GTK3Agg GTK3Cairo
# CocoaAgg MacOSX Qt4Agg Qt5Agg TkAgg WX WXAgg Agg Cairo GDK PS PDF SVG
# Template.
# You can also deploy your own backend outside of matplotlib by
# referring to the module name (which must be in the PYTHONPATH) as
# 'module://my_backend'.
backend      : Qt5Agg

# If you are using the Qt4Agg backend, you can choose here
# to use the PyQt4 bindings or the newer PySide bindings to
# the underlying Qt4 toolkit.
backend.qt5 : PyQt5

# Note that this can be overridden by the environment variable
# QT_API used by Enthought Tool Suite (ETS); valid values are
# "pyqt" and "pyside".  The "pyqt" setting has the side effect of
# forcing the use of Version 2 API for QString and QVariant.

# The port to use for the web server in the WebAgg backend.
# webagg.port : 8888

# If webagg.port is unavailable, a number of other random ports will
# be tried until one that is available is found.
# webagg.port_retries : 50

# When True, open the webbrowser to the plot that is shown
# webagg.open_in_browser : True

# When True, the figures rendered in the nbagg backend are created with
# a transparent background.
# nbagg.transparent : True

# if you are running pyplot inside a GUI and your backend choice
# conflicts, we will automatically try to find a compatible one for
# you if backend_fallback is True
#backend_fallback: True

#interactive  : False
#toolbar      : toolbar2   # None | toolbar2  ("classic" is deprecated)
#timezone     : UTC        # a pytz timezone string, e.g., US/Central or Europe/Paris

# Where your matplotlib data lives if you installed to a non-default
# location.  This is where the matplotlib fonts, bitmaps, etc reside
#datapath : /home/jdhunter/mpldata


### LINES
# See http://matplotlib.org/api/artist_api.html#module-matplotlib.lines for more
# information on line properties.
#lines.linewidth   : 1.0     # line width in points
#lines.linestyle   : -       # solid line
#lines.color       : blue    # has no affect on plot(); see axes.prop_cycle
#lines.marker      : None    # the default marker
#lines.markeredgewidth  : 0.5     # the line width around the marker symbol
#lines.markersize  : 6            # markersize, in points
#lines.dash_joinstyle : miter        # miter|round|bevel
#lines.dash_capstyle : butt          # butt|round|projecting
#lines.solid_joinstyle : miter       # miter|round|bevel
#lines.solid_capstyle : projecting   # butt|round|projecting
#lines.antialiased : True         # render lines in antialiased (no jaggies)

#markers.fillstyle: full # full|left|right|bottom|top|none

### PATCHES
# Patches are graphical objects that fill 2D space, like polygons or
# circles.  See
# http://matplotlib.org/api/artist_api.html#module-matplotlib.patches
# information on patch properties
#patch.linewidth        : 1.0     # edge width in points
#patch.facecolor        : blue
#patch.edgecolor        : black
#patch.antialiased      : True    # render patches in antialiased (no jaggies)

### FONT
#
# font properties used by text.Text.  See
# http://matplotlib.org/api/font_manager_api.html for more
# information on font properties.  The 6 font properties used for font
# matching are given below with their default values.
#
# The font.family property has five values: 'serif' (e.g., Times),
# 'sans-serif' (e.g., Helvetica), 'cursive' (e.g., Zapf-Chancery),
# 'fantasy' (e.g., Western), and 'monospace' (e.g., Courier).  Each of
# these font families has a default list of font names in decreasing
# order of priority associated with them.  When text.usetex is False,
# font.family may also be one or more concrete font names.
#
# The font.style property has three values: normal (or roman), italic
# or oblique.  The oblique style will be used for italic, if it is not
# present.
#
# The font.variant property has two values: normal or small-caps.  For
# TrueType fonts, which are scalable fonts, small-caps is equivalent
# to using a font size of 'smaller', or about 83% of the current font
# size.
#
# The font.weight property has effectively 13 values: normal, bold,
# bolder, lighter, 100, 200, 300, ..., 900.  Normal is the same as
# 400, and bold is 700.  bolder and lighter are relative values with
# respect to the current weight.
#
# The font.stretch property has 11 values: ultra-condensed,
# extra-condensed, condensed, semi-condensed, normal, semi-expanded,
# expanded, extra-expanded, ultra-expanded, wider, and narrower.  This
# property is not currently implemented.
#
# The font.size property is the default font size for text, given in pts.
# 12pt is the standard value.
#
font.family         : sans-serif
#font.style          : normal
#font.variant        : normal
#font.weight         : medium
#font.stretch        : normal
# note that font.size controls default text sizes.  To configure
# special text sizes tick labels, axes, labels, title, etc, see the rc
# settings for axes and ticks. Special text sizes can be defined
# relative to font.size, using the following values: xx-small, x-small,
# small, medium, large, x-large, xx-large, larger, or smaller
#font.size           : 12.0
#font.serif          : Bitstream Vera Serif, New Century Schoolbook, Century Schoolbook L, Utopia, ITC Bookman, Bookman, Nimbus Roman No9 L, Times New Roman, Times, Palatino, Charter, serif
font.sans-serif     : SimHei, Bitstream Vera Sans, Lucida Grande, Verdana, Geneva, Lucid, Arial, Helvetica, Avant Garde, sans-serif
#font.cursive        : Apple Chancery, Textile, Zapf Chancery, Sand, Script MT, Felipa, cursive
#font.fantasy        : Comic Sans MS, Chicago, Charcoal, Impact, Western, Humor Sans, fantasy
#font.monospace      : Bitstream Vera Sans Mono, Andale Mono, Nimbus Mono L, Courier New, Courier, Fixed, Terminal, monospace

### TEXT
# text properties used by text.Text.  See
# http://matplotlib.org/api/artist_api.html#module-matplotlib.text for more
# information on text properties

#text.color          : black

### LaTeX customizations. See http://wiki.scipy.org/Cookbook/Matplotlib/UsingTex
#text.usetex         : False  # use latex for all text handling. The following fonts
                              # are supported through the usual rc parameter settings:
                              # new century schoolbook, bookman, times, palatino,
                              # zapf chancery, charter, serif, sans-serif, helvetica,
                              # avant garde, courier, monospace, computer modern roman,
                              # computer modern sans serif, computer modern typewriter
                              # If another font is desired which can loaded using the
                              # LaTeX \usepackage command, please inquire at the
                              # matplotlib mailing list
#text.latex.unicode : False # use "ucs" and "inputenc" LaTeX packages for handling
                            # unicode strings.
#text.latex.preamble :  # IMPROPER USE OF THIS FEATURE WILL LEAD TO LATEX FAILURES
                            # AND IS THEREFORE UNSUPPORTED. PLEASE DO NOT ASK FOR HELP
                            # IF THIS FEATURE DOES NOT DO WHAT YOU EXPECT IT TO.
                            # preamble is a comma separated list of LaTeX statements
                            # that are included in the LaTeX document preamble.
                            # An example:
                            # text.latex.preamble : \usepackage{bm},\usepackage{euler}
                            # The following packages are always loaded with usetex, so
                            # beware of package collisions: color, geometry, graphicx,
                            # type1cm, textcomp. Adobe Postscript (PSSNFS) font packages
                            # may also be loaded, depending on your font settings

#text.dvipnghack : None      # some versions of dvipng don't handle alpha
                             # channel properly.  Use True to correct
                             # and flush ~/.matplotlib/tex.cache
                             # before testing and False to force
                             # correction off.  None will try and
                             # guess based on your dvipng version

#text.hinting : auto   # May be one of the following:
                       #   'none': Perform no hinting
                       #   'auto': Use freetype's autohinter
                       #   'native': Use the hinting information in the
                       #             font file, if available, and if your
                       #             freetype library supports it
                       #   'either': Use the native hinting information,
                       #             or the autohinter if none is available.
                       # For backward compatibility, this value may also be
                       # True === 'auto' or False === 'none'.
#text.hinting_factor : 8 # Specifies the amount of softness for hinting in the
                         # horizontal direction.  A value of 1 will hint to full
                         # pixels.  A value of 2 will hint to half pixels etc.

#text.antialiased : True # If True (default), the text will be antialiased.
                         # This only affects the Agg backend.

# The following settings allow you to select the fonts in math mode.
# They map from a TeX font name to a fontconfig font pattern.
# These settings are only used if mathtext.fontset is 'custom'.
# Note that this "custom" mode is unsupported and may go away in the
# future.
#mathtext.cal : cursive
#mathtext.rm  : serif
#mathtext.tt  : monospace
#mathtext.it  : serif:italic
#mathtext.bf  : serif:bold
#mathtext.sf  : sans
#mathtext.fontset : cm # Should be 'cm' (Computer Modern), 'stix',
                       # 'stixsans' or 'custom'
#mathtext.fallback_to_cm : True  # When True, use symbols from the Computer Modern
                                 # fonts when a symbol can not be found in one of
                                 # the custom math fonts.

#mathtext.default : it # The default font to use for math.
                       # Can be any of the LaTeX font names, including
                       # the special name "regular" for the same font
                       # used in regular text.

### AXES
# default face and edge color, default tick sizes,
# default fontsizes for ticklabels, and so on.  See
# http://matplotlib.org/api/axes_api.html#module-matplotlib.axes
#axes.hold           : True    # whether to clear the axes by default on
#axes.facecolor      : white   # axes background color
#axes.edgecolor      : black   # axes edge color
#axes.linewidth      : 1.0     # edge linewidth
#axes.grid           : False   # display grid or not
#axes.titlesize      : large   # fontsize of the axes title
#axes.labelsize      : medium  # fontsize of the x any y labels
#axes.labelpad       : 5.0     # space between label and axis
#axes.labelweight    : normal  # weight of the x and y labels
#axes.labelcolor     : black
#axes.axisbelow      : False   # whether axis gridlines and ticks are below
                               # the axes elements (lines, text, etc)

#axes.formatter.limits : -7, 7 # use scientific notation if log10
                               # of the axis range is smaller than the
                               # first or larger than the second
#axes.formatter.use_locale : False # When True, format tick labels
                                   # according to the user's locale.
                                   # For example, use ',' as a decimal
                                   # separator in the fr_FR locale.
#axes.formatter.use_mathtext : False # When True, use mathtext for scientific
                                     # notation.
#axes.formatter.useoffset      : True    # If True, the tick label formatter
                                         # will default to labeling ticks relative
                                         # to an offset when the data range is very
                                         # small compared to the minimum absolute
                                         # value of the data.

#axes.unicode_minus  : True    # use unicode for the minus symbol
                               # rather than hyphen.  See
                               # http://en.wikipedia.org/wiki/Plus_and_minus_signs#Character_codes
#axes.prop_cycle    : cycler('color', 'bgrcmyk')
                                            # color cycle for plot lines
                                            # as list of string colorspecs:
                                            # single letter, long name, or
                                            # web-style hex
#axes.xmargin        : 0  # x margin.  See `axes.Axes.margins`
#axes.ymargin        : 0  # y margin See `axes.Axes.margins`

#polaraxes.grid      : True    # display grid on polar axes
#axes3d.grid         : True    # display grid on 3d axes

### TICKS
# see http://matplotlib.org/api/axis_api.html#matplotlib.axis.Tick
#xtick.major.size     : 4      # major tick size in points
#xtick.minor.size     : 2      # minor tick size in points
#xtick.major.width    : 0.5    # major tick width in points
#xtick.minor.width    : 0.5    # minor tick width in points
#xtick.major.pad      : 4      # distance to major tick label in points
#xtick.minor.pad      : 4      # distance to the minor tick label in points
#xtick.color          : k      # color of the tick labels
#xtick.labelsize      : medium # fontsize of the tick labels
#xtick.direction      : in     # direction: in, out, or inout

#ytick.major.size     : 4      # major tick size in points
#ytick.minor.size     : 2      # minor tick size in points
#ytick.major.width    : 0.5    # major tick width in points
#ytick.minor.width    : 0.5    # minor tick width in points
#ytick.major.pad      : 4      # distance to major tick label in points
#ytick.minor.pad      : 4      # distance to the minor tick label in points
#ytick.color          : k      # color of the tick labels
#ytick.labelsize      : medium # fontsize of the tick labels
#ytick.direction      : in     # direction: in, out, or inout


### GRIDS
#grid.color       :   black   # grid color
#grid.linestyle   :   :       # dotted
#grid.linewidth   :   0.5     # in points
#grid.alpha       :   1.0     # transparency, between 0.0 and 1.0

### Legend
#legend.fancybox      : False  # if True, use a rounded box for the
                               # legend, else a rectangle
#legend.isaxes        : True
#legend.numpoints     : 2      # the number of points in the legend line
#legend.fontsize      : large
#legend.borderpad     : 0.5    # border whitespace in fontsize units
#legend.markerscale   : 1.0    # the relative size of legend markers vs. original
# the following dimensions are in axes coords
#legend.labelspacing  : 0.5    # the vertical space between the legend entries in fraction of fontsize
#legend.handlelength  : 2.     # the length of the legend lines in fraction of fontsize
#legend.handleheight  : 0.7     # the height of the legend handle in fraction of fontsize
#legend.handletextpad : 0.8    # the space between the legend line and legend text in fraction of fontsize
#legend.borderaxespad : 0.5   # the border between the axes and legend edge in fraction of fontsize
#legend.columnspacing : 2.    # the border between the axes and legend edge in fraction of fontsize
#legend.shadow        : False
#legend.frameon       : True   # whether or not to draw a frame around legend
#legend.framealpha    : None    # opacity of of legend frame
#legend.scatterpoints : 3 # number of scatter points

### FIGURE
# See http://matplotlib.org/api/figure_api.html#matplotlib.figure.Figure
#figure.titlesize : medium     # size of the figure title
#figure.titleweight : normal   # weight of the figure title
#figure.figsize   : 8, 6    # figure size in inches
#figure.dpi       : 80      # figure dots per inch
#figure.facecolor : 0.75    # figure facecolor; 0.75 is scalar gray
#figure.edgecolor : white   # figure edgecolor
#figure.autolayout : False  # When True, automatically adjust subplot
                            # parameters to make the plot fit the figure
#figure.max_open_warning : 20  # The maximum number of figures to open through
                               # the pyplot interface before emitting a warning.
                               # If less than one this feature is disabled.

# The figure subplot parameters.  All dimensions are a fraction of the
# figure width or height
#figure.subplot.left    : 0.125  # the left side of the subplots of the figure
#figure.subplot.right   : 0.9    # the right side of the subplots of the figure
#figure.subplot.bottom  : 0.1    # the bottom of the subplots of the figure
#figure.subplot.top     : 0.9    # the top of the subplots of the figure
#figure.subplot.wspace  : 0.2    # the amount of width reserved for blank space between subplots
#figure.subplot.hspace  : 0.2    # the amount of height reserved for white space between subplots

### IMAGES
#image.aspect : equal             # equal | auto | a number
#image.interpolation  : bilinear  # see help(imshow) for options
#image.cmap   : jet               # gray | jet etc...
#image.lut    : 256               # the size of the colormap lookup table
#image.origin : upper             # lower | upper
#image.resample  : False
#image.composite_image : True     # When True, all the images on a set of axes are 
                                  # combined into a single composite image before 
                                  # saving a figure as a vector graphics file, 
                                  # such as a PDF.

### CONTOUR PLOTS
#contour.negative_linestyle : dashed # dashed | solid
#contour.corner_mask        : True   # True | False | legacy

### ERRORBAR PLOTS
#errorbar.capsize : 3             # length of end cap on error bars in pixels

### Agg rendering
### Warning: experimental, 2008/10/10
#agg.path.chunksize : 0           # 0 to disable; values in the range
                                  # 10000 to 100000 can improve speed slightly
                                  # and prevent an Agg rendering failure
                                  # when plotting very large data sets,
                                  # especially if they are very gappy.
                                  # It may cause minor artifacts, though.
                                  # A value of 20000 is probably a good
                                  # starting point.
### SAVING FIGURES
#path.simplify : True   # When True, simplify paths by removing "invisible"
                        # points to reduce file size and increase rendering
                        # speed
#path.simplify_threshold : 0.1  # The threshold of similarity below which
                                # vertices will be removed in the simplification
                                # process
#path.snap : True # When True, rectilinear axis-aligned paths will be snapped to
                  # the nearest pixel when certain criteria are met.  When False,
                  # paths will never be snapped.
#path.sketch : None # May be none, or a 3-tuple of the form (scale, length,
                    # randomness).
                    # *scale* is the amplitude of the wiggle
                    # perpendicular to the line (in pixels).  *length*
                    # is the length of the wiggle along the line (in
                    # pixels).  *randomness* is the factor by which
                    # the length is randomly scaled.

# the default savefig params can be different from the display params
# e.g., you may want a higher resolution, or to make the figure
# background white
#savefig.dpi         : 100      # figure dots per inch
#savefig.facecolor   : white    # figure facecolor when saving
#savefig.edgecolor   : white    # figure edgecolor when saving
#savefig.format      : png      # png, ps, pdf, svg
#savefig.bbox        : standard # 'tight' or 'standard'.
                                # 'tight' is incompatible with pipe-based animation
                                # backends but will workd with temporary file based ones:
                                # e.g. setting animation.writer to ffmpeg will not work,
                                # use ffmpeg_file instead
#savefig.pad_inches  : 0.1      # Padding to be used when bbox is set to 'tight'
#savefig.jpeg_quality: 95       # when a jpeg is saved, the default quality parameter.
#savefig.directory   : ~        # default directory in savefig dialog box,
                                # leave empty to always use current working directory
#savefig.transparent : False    # setting that controls whether figures are saved with a
                                # transparent background by default

# tk backend params
#tk.window_focus   : False    # Maintain shell focus for TkAgg

# ps backend params
#ps.papersize      : letter   # auto, letter, legal, ledger, A0-A10, B0-B10
#ps.useafm         : False    # use of afm fonts, results in small files
#ps.usedistiller   : False    # can be: None, ghostscript or xpdf
                                          # Experimental: may produce smaller files.
                                          # xpdf intended for production of publication quality files,
                                          # but requires ghostscript, xpdf and ps2eps
#ps.distiller.res  : 6000      # dpi
#ps.fonttype       : 3         # Output Type 3 (Type3) or Type 42 (TrueType)

# pdf backend params
#pdf.compression   : 6 # integer from 0 to 9
                       # 0 disables compression (good for debugging)
#pdf.fonttype       : 3         # Output Type 3 (Type3) or Type 42 (TrueType)

# svg backend params
#svg.image_inline : True       # write raster image data directly into the svg file
#svg.image_noscale : False     # suppress scaling of raster data embedded in SVG
#svg.fonttype : 'path'         # How to handle SVG fonts:
#    'none': Assume fonts are installed on the machine where the SVG will be viewed.
#    'path': Embed characters as paths -- supported by most SVG renderers
#    'svgfont': Embed characters as SVG fonts -- supported only by Chrome,
#               Opera and Safari

# docstring params
#docstring.hardcopy = False  # set this when you want to generate hardcopy docstring

# Set the verbose flags.  This controls how much information
# matplotlib gives you at runtime and where it goes.  The verbosity
# levels are: silent, helpful, debug, debug-annoying.  Any level is
# inclusive of all the levels below it.  If your setting is "debug",
# you'll get all the debug and helpful messages.  When submitting
# problems to the mailing-list, please set verbose to "helpful" or "debug"
# and paste the output into your report.
#
# The "fileo" gives the destination for any calls to verbose.report.
# These objects can a filename, or a filehandle like sys.stdout.
#
# You can override the rc default verbosity from the command line by
# giving the flags --verbose-LEVEL where LEVEL is one of the legal
# levels, e.g., --verbose-helpful.
#
# You can access the verbose instance in your code
#   from matplotlib import verbose.
#verbose.level  : silent      # one of silent, helpful, debug, debug-annoying
#verbose.fileo  : sys.stdout  # a log filename, sys.stdout or sys.stderr

# Event keys to interact with figures/plots via keyboard.
# Customize these settings according to your needs.
# Leave the field(s) empty if you don't need a key-map. (i.e., fullscreen : '')

#keymap.fullscreen : f               # toggling
#keymap.home : h, r, home            # home or reset mnemonic
#keymap.back : left, c, backspace    # forward / backward keys to enable
#keymap.forward : right, v           #   left handed quick navigation
#keymap.pan : p                      # pan mnemonic
#keymap.zoom : o                     # zoom mnemonic
#keymap.save : s                     # saving current figure
#keymap.quit : ctrl+w, cmd+w         # close the current figure
#keymap.grid : g                     # switching on/off a grid in current axes
#keymap.yscale : l                   # toggle scaling of y-axes ('log'/'linear')
#keymap.xscale : L, k                # toggle scaling of x-axes ('log'/'linear')
#keymap.all_axes : a                 # enable all axes

# Control location of examples data files
#examples.directory : ''   # directory to look in for custom installation

###ANIMATION settings
#animation.html : 'none'           # How to display the animation as HTML in
                                   # the IPython notebook. 'html5' uses
                                   # HTML5 video tag.
#animation.writer : ffmpeg         # MovieWriter 'backend' to use
#animation.codec : mpeg4           # Codec to use for writing movie
#animation.bitrate: -1             # Controls size/quality tradeoff for movie.
                                   # -1 implies let utility auto-determine
#animation.frame_format: 'png'     # Controls frame format used by temp files
#animation.ffmpeg_path: 'ffmpeg'   # Path to ffmpeg binary. Without full path
                                   # $PATH is searched
#animation.ffmpeg_args: ''         # Additional arguments to pass to ffmpeg
#animation.avconv_path: 'avconv'   # Path to avconv binary. Without full path
                                   # $PATH is searched
#animation.avconv_args: ''         # Additional arguments to pass to avconv
#animation.mencoder_path: 'mencoder'
                                   # Path to mencoder binary. Without full path
                                   # $PATH is searched
#animation.mencoder_args: ''       # Additional arguments to pass to mencoder
#animation.convert_path: 'convert' # Path to ImageMagick's convert binary.
                                   # On Windows use the full path since convert
                                   # is also the name of a system tool.

2.1 Matplotlib annotations

We first use text annotations to draw the tree nodes.

# coding:utf-8
import matplotlib.pyplot as plt

decisionNode = dict(boxstyle="sawtooth", fc='0.8')  # box style for decision nodes
leafNode = dict(boxstyle='round4', fc='0.8')        # box style for leaf nodes
arrow_args = dict(arrowstyle="<-")                  # arrow style

# Draw one node. Arguments: node text, node center, parent center, node style.
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    # The text box sits at centerPt; the arrow connects it to parentPt.
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType,
                            arrowprops=arrow_args)

def createPlot():
    fig = plt.figure(1, facecolor='white')  # create a new figure
    fig.clf()                               # clear the drawing area
    createPlot.ax1 = plt.subplot(111, frameon=False)
    plotNode('decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
    plotNode('leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
    plt.show()

createPlot()

2.2 Constructing a tree of annotations

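Drawing a complete tree takes some care. We have x and y coordinates, but placing all the tree nodes is the real problem. We need to know how many leaf nodes there are so that we can size the x axis, and how many levels the tree has so that we can size the y axis. Here we define two new functions, getNumLeafs() and getTreeDepth(), to get the number of leaf nodes and the number of levels in the tree: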

# 2.1. Count the leaf nodes
def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = next(iter(myTree))  # the single key is the splitting feature
    secondDict = myTree[firstStr]
    # A child whose value is a dict is a decision node; anything else is a leaf
    for key in secondDict.keys():
        if isinstance(secondDict[key], dict):
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs

# 2.2. Count the depth of the tree
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = next(iter(myTree))  # the single key is the splitting feature
    secondDict = myTree[firstStr]
    # A child whose value is a dict is a decision node; anything else is a leaf
    for key in secondDict.keys():
        if isinstance(secondDict[key], dict):
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth

To make testing easier, we hard-code dictionary representations of two trees:

def retrieveTree(i):
    listOfTree = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                  {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
                  ]
    return listOfTree[i]

myTree = retrieveTree(1)
print(getNumLeafs(myTree))   # prints 4
print(getTreeDepth(myTree))  # prints 3
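
The two helpers each walk the tree separately. As a minimal sketch, both numbers can also be collected in a single traversal of the same nested-dict representation. The name `tree_stats` is hypothetical and not part of the original code.

```python
def tree_stats(tree):
    """Return (leaf_count, depth) for a nested-dict decision tree."""
    root = next(iter(tree))          # the single key is the splitting feature
    leaves, depth = 0, 0
    for child in tree[root].values():
        if isinstance(child, dict):  # decision node: recurse into the subtree
            sub_leaves, sub_depth = tree_stats(child)
            leaves += sub_leaves
            depth = max(depth, 1 + sub_depth)
        else:                        # leaf node
            leaves += 1
            depth = max(depth, 1)
    return leaves, depth

tree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
print(tree_stats(tree))  # (3, 2)
```

The result should agree with getNumLeafs() and getTreeDepth() on the same tree.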

Now we can combine the techniques above to draw a complete tree.

# 4. Draw a tree
def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)
    firstStr = next(iter(myTree))  # splitting feature at this node
    # Center this node horizontally above the leaves of its subtree
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)  # label the branch leading to this node
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD  # move down one level
    for key in secondDict.keys():
        if isinstance(secondDict[key], dict):
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW  # next leaf slot
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD  # move back up one level

# 5. Create the canvas
def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])  # hide the axis ticks
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))   # canvas width, in leaf slots
    plotTree.totalD = float(getTreeDepth(inTree))  # canvas height, in levels
    plotTree.xOff = -0.5/plotTree.totalW  # start half a slot left of the axes
    plotTree.yOff = 1.0                   # start at the top
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()

myTree = retrieveTree(0)
createPlot(myTree)
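
The offset arithmetic is easier to follow with numbers. The sketch below reproduces only the x-coordinate bookkeeping from plotTree() and createPlot(), without any plotting. The function names `leaf_x_positions` and `root_x` are hypothetical.

```python
def leaf_x_positions(num_leafs):
    """x-coordinates (axes fraction) assigned to the leaves, left to right."""
    total_w = float(num_leafs)
    x_off = -0.5 / total_w       # start half a slot left of the axes
    positions = []
    for _ in range(num_leafs):
        x_off += 1.0 / total_w   # advance one slot per leaf
        positions.append(x_off)
    return positions

def root_x(num_leafs):
    """x-coordinate of the root node, computed as in plotTree()."""
    total_w = float(num_leafs)
    return -0.5 / total_w + (1.0 + num_leafs) / 2.0 / total_w

print(leaf_x_positions(3))  # roughly 1/6, 1/2 and 5/6
print(root_x(3))            # 0.5, up to floating-point rounding
```

Each leaf lands in the middle of its own slot of width 1/totalW, and the root comes out at the horizontal center of the axes for any number of leaves.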

3. Testing and storing the classifier

# 6. Classification function that uses the decision tree
def classify(inputTree, featLabels, testVec):
    firstStr = next(iter(inputTree))        # splitting feature at this node
    secondDict = inputTree[firstStr]        # branches below this node
    featIndex = featLabels.index(firstStr)  # index of that feature in the label list
    classLabel = None                       # stays None if no branch matches testVec
    # Follow the matching branch: return the label at a leaf, recurse at a decision node
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if isinstance(secondDict[key], dict):
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:
                classLabel = secondDict[key]
    return classLabel

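myData, feature = createDataSet()  # get the training samples
print(feature)
# Pass a copy of the label list in case createTree() modifies the list it receives
myTree = createTree(myData, feature[:])  # build the decision tree model
print(myTree)
predict = classify(myTree, feature, [1, 0])  # predict
print(predict)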
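
To see the descent without building a tree first, here is a minimal iterative sketch of the same lookup on a hard-coded tree. The name `classify_iter` and its `default` parameter are hypothetical additions for illustration. Unlike the recursive version, it returns `default` explicitly when the test vector holds a value that has no branch.

```python
def classify_iter(tree, featLabels, testVec, default=None):
    """Walk a nested-dict decision tree from the root down to a leaf."""
    node = tree
    while isinstance(node, dict):         # still at a decision node
        feature = next(iter(node))        # splitting feature at this node
        value = testVec[featLabels.index(feature)]
        branches = node[feature]
        if value not in branches:         # feature value never seen in training
            return default
        node = branches[value]
    return node                           # a leaf holds the class label

tree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
labels = ['no surfacing', 'flippers']
print(classify_iter(tree, labels, [1, 0]))  # no
print(classify_iter(tree, labels, [1, 1]))  # yes
print(classify_iter(tree, labels, [2, 1]))  # None
```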

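We now have a classifier built on a decision tree, but the tree has to be rebuilt every time the classifier is used. Next we show how to store the decision tree classifier on disk.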

# 7.1 Save the model
def storeTree(inputTree, filename):
    import pickle
    # pickle writes bytes, so open the file in binary mode
    with open(filename, 'wb') as fw:
        pickle.dump(inputTree, fw)

# 7.2 Load the model
def grabTree(filename):
    import pickle
    with open(filename, 'rb') as fr:
        return pickle.load(fr)
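
A quick round-trip check, as a sketch: the tree is written to a temporary file and read back, and the restored dictionary should equal the original. The helper names `store_tree` and `grab_tree` and the temporary path are placeholders chosen for this example.

```python
import os
import pickle
import tempfile

def store_tree(tree, filename):
    """Serialize the tree dictionary to disk in binary mode."""
    with open(filename, 'wb') as fw:
        pickle.dump(tree, fw)

def grab_tree(filename):
    """Load a previously stored tree dictionary."""
    with open(filename, 'rb') as fr:
        return pickle.load(fr)

tree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

fd, path = tempfile.mkstemp(suffix='.pkl')  # throwaway file for the demo
os.close(fd)
try:
    store_tree(tree, path)
    restored = grab_tree(path)
finally:
    os.remove(path)                         # clean up the temporary file

print(restored == tree)  # True
```

Note that pickle.load() can execute arbitrary code, so only load files from a source you trust.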

4. Example: using decision trees to predict contact lens type

1. Collect: Text file provided.
2. Prepare: Parse tab-delimited lines.
3. Analyze: Quickly review data visually to make sure it was parsed properly. The final tree will be plotted with createPlot().
4. Train: Use createTree() from section 3.1.
5. Test: Write a function to descend the tree for a given instance.
6. Use: Persist the tree data structure so it can be recalled without building the tree; then use it in any application.

 

# 8. Use a decision tree to predict contact lens type
import treePlotter

# Each line of lenses.txt is one tab-delimited sample; the last field is the class label
with open('lenses.txt') as fr:
    lenses = [inst.strip().split('\t') for inst in fr.readlines()]
lensesFeatures = ['age', 'prescript', 'astigmatic', 'tearRate']
lensesTree = createTree(lenses, lensesFeatures)
treePlotter.createPlot(lensesTree)
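
For the "Prepare" and "Analyze" steps, a small sanity check on the parsed rows can catch malformed lines before training. The sketch below runs on in-memory lines so that it does not depend on lenses.txt. The sample rows only illustrate the tab-delimited layout and are not presented as the file's contents, and the name `parse_lenses` is hypothetical.

```python
def parse_lenses(lines):
    """Split tab-delimited lines into field lists, skipping blank lines."""
    return [line.strip().split('\t') for line in lines if line.strip()]

# Illustrative rows in the assumed layout: four feature values, then the label
sample = [
    "young\tmyope\tno\treduced\tno lenses\n",
    "young\tmyope\tno\tnormal\tsoft\n",
    "\n",
]
lensesFeatures = ['age', 'prescript', 'astigmatic', 'tearRate']

rows = parse_lenses(sample)
# Every row should hold one value per feature plus the class label
bad = [row for row in rows if len(row) != len(lensesFeatures) + 1]
print(len(rows), len(bad))  # 2 0
```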

5. Summary

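A decision tree classifier is like a flowchart with terminating blocks (leaf nodes), where the terminating blocks represent classification results. When we start processing a dataset, we first measure the inconsistency of the data in the set, which is its information entropy, and then look for the best way to split the dataset, repeating until all the data in each subset belongs to the same class. The ID3 algorithm can split nominal-valued datasets; at each step it chooses the attribute that maximizes information gain. When building a decision tree, we usually use recursion to turn the dataset into a tree. We generally do not create a new data structure, and instead store the tree node information in Python's built-in dictionary.

With Matplotlib's annotation feature, we can turn the stored tree structure into an easy-to-understand diagram. Python's pickle module can be used to persist the tree structure. The contact lens example shows that a decision tree may split the dataset too many times and thereby overfit it. We can reduce overfitting by pruning the tree, that is, by merging adjacent leaf nodes that do not yield much information gain.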
