SVM算法總結

svm算法通俗的理解在二維上,就是找一分割線把兩類分開,問題是以下圖三條顏色均可以把點和星劃開,但哪條線是最優的呢,這就是咱們要考慮的問題;python

 

首先咱們先假設一條直線爲 W•X+b =0 爲最優的分割線,把兩類分開以下圖所示,那咱們就要解決的是怎麼獲取這條最優直線呢?及W 和 b 的值;在SVM中最優分割面(超平面)就是:能使支持向量和超平面最小距離的最大值;算法

咱們的目標是尋找一個超平面,使得離超平面比較近的點能有更大的間距。也就是咱們不考慮全部的點都必須遠離超平面,咱們關心求得的超平面可以讓全部點中離它最近的點具備最大間距。dom

如上面假設藍色的星星類有5個樣本,並設定此類樣本標記爲Y =1,紫色圈類有5個樣本,並設定此類標記爲 Y =-1,共 T ={(X₁ ,Y₁) , (X₂,Y₂) (X₃,Y₃) .........} 10個樣本,超平面(分割線)爲wx+b=0;  樣本點到超平面的幾何距離爲:函數

     

        此處要說明一下:函數距離和幾何距離的關係;定義上把 樣本| w▪x₁+b|的距離叫作函數距離,而上面公式爲幾何距離,你會發現當w 和b 同倍數增長時候,函數距離也會通倍數增長;簡單個例子就是,樣本 X₁ 到 2wX₁+2b =0的函數距離是wX₁+b =0的函數距離的 2倍;而幾何矩陣不變;測試

        下面咱們就要談談怎麼獲取超平面了?!優化

超平面就是知足支持向量到其最小距離最大,及是求:max [支持向量到超平面的最小距離] ;那隻要算出支持向量到超平面的距離就能夠了吧 ,而支持向量到超平面的最小距離能夠表示以下公式:atom

故最終優化的的公式爲:spa

 根據函數距離和幾何距離能夠得知,w和b增長時候,幾何距離不變,故怎能經過同倍數增長w和 b使的支持向量(距離超平面最近的樣本點)上樣本代入 y(w*x+b) =1,而不影響上面公式的優化,樣本點距離以下:如上圖其r1函數距離爲1,k1函數距離爲1,而其它.net

樣本點的函數距離大於1,及是:y(w•x+b)>=1,把此條件代入上面優化公式候,能夠獲取新的優化公式1-3:code

 


公式1-3見下方:優化最大化分數,轉化爲優化最小化分母,爲了優化方便轉化爲公式1-4

爲了優化上面公式,使用拉格朗日公式和KTT條件優化公式轉化爲:

 

對於上面的優化公式在此說明一下:好比咱們的目標問題是 minf(x)。能夠構造函數L(a,b,x):

 

 

L(a,b,x)=f(x)+ag(x)+bh(x)a0

此時 f(x) 與 maxa,bL(a,b,x) 是等價的。由於 h(x)=0,g(x)0,ag(x)0,因此只有在ag(x)=0 的狀況下 

L(a,b,x) 才能取得最大值,所以咱們的目標函數能夠寫爲minxmaxa,bL(a,b,x)。若是用對偶表達式:maxa,bminxL(a,b,x),

因爲咱們的優化是知足強對偶的(強對偶就是說對偶式子的最優值是等於原問題的最優值的),因此在取得最優值x∗ 的條件下,它知足 :

f(x)=maxa,bminxL(a,b,x)=minxmaxa,bL(a,b,x)=f(x),

 

結合上面的一度的對偶說明故咱們的優化函數以下面,其中a >0

如今的優化方案到上面了,先求最小值,對 w 和 b 分別求偏導能夠獲取以下公式:

 

把上式獲取的參數代入公式優化max值:

化解到最後一步,就能夠獲取最優的a值:


以上就能夠獲取超平面!

但在正常狀況下可能存在一些特異點,將這些特異點去掉後,剩下的大部分點都能線性可分的,有些點線性不能夠分,意味着此點的函數距離不是大於等於1,而是小於1的,爲了解決這個問題,咱們引進了鬆弛變量 ε>=0; 這樣約束條件就會變成爲:

故原先的優化函數變爲:

對加入鬆弛變量後有幾點說明以下圖因此;距離小於1的樣本點離超平面的距離爲d ,在綠線和超平面之間的樣本點都是由損失的,

其損失變量和距離d 的關係,能夠看出 ξ = 1-d , 當d >1的時候會發現ξ =0,當 d<1 的時候 ξ = 1-d ;故能夠畫出損失函數圖,以下圖1-7;樣式就像翻書同樣,咱們把這個損失函數叫作 hinge損失; 

下面咱們簡單的就來討論一下核函數:核函數的做用其實很簡單就是把低維映射到高維中,便於分類。核函數有高斯核等,下面就直接上圖看參數對模型的影響,從下圖能夠了解,當C變化時候,容錯變小,泛化能力變小;當選擇高斯核函數的時候,隨時R參數調大,準確高提升,最終有過擬合風險;

下面就直接上代碼了(鳶尾花SVM二特徵分類):

 

[python]  view plain  copy
 
  1. iris_feature = u'花萼長度', u'花萼寬度', u'花瓣長度', u'花瓣寬度'  
  2.   
  3.   
  4. if __name__ == "__main__":  
  5.     path = 'iris.data'  # 數據文件路徑  
  6.     data = pd.read_csv(path, header=None)  
  7.     x, y = data[range(4)], data[4]  
  8.     y = pd.Categorical(y).codes  
  9.     x = x[[0, 1]]  
  10.     x_train, x_test, y_train, y_test = train_test_split(x, y, random_state=1, train_size=0.6)  
  11.   
  12.     # 分類器  
  13.     clf = svm.SVC(C=0.3, kernel='linear', decision_function_shape='ovo')  
  14.     clf.fit(x_train, y_train.ravel())  
  15.   
  16.     # 準確率  
  17.     print clf.score(x_train, y_train)  # 精度  
  18.     print '訓練集準確率:', accuracy_score(y_train, clf.predict(x_train))  
  19.     print clf.score(x_test, y_test)  
  20.     print '測試集準確率:', accuracy_score(y_test, clf.predict(x_test))  
  21.     x1_min, x2_min = x.min()  
  22.     x1_max, x2_max = x.max()  
  23.     x1, x2 = np.mgrid[x1_min:x1_max:500j, x2_min:x2_max:500j]  # 生成網格採樣點  
  24.     grid_test = np.stack((x1.flat, x2.flat), axis=1)  # 測試點  
  25.   
  26.     print 'grid_test = \n', grid_test  
  27.     Z = clf.decision_function(grid_test)  
  28.     Z = Z[:,0].reshape(x1.shape)  
  29.     print "decision_function:",Z  
  30.     grid_hat = clf.predict(grid_test)  
  31.     grid_hat = grid_hat.reshape(x1.shape)  
  32.     mpl.rcParams['font.sans-serif'] = [u'SimHei']  
  33.     mpl.rcParams['axes.unicode_minus'] = False  
  34.   
  35.     cm_light = mpl.colors.ListedColormap(['#A0FFA0', '#FFA0A0', '#A0A0FF'])  
  36.     cm_dark = mpl.colors.ListedColormap(['g', 'r', 'b'])  
  37.     plt.figure(facecolor='w')  
  38.     plt.pcolormesh(x1, x2, grid_hat, cmap=cm_light)  
  39.     plt.scatter(x[0], x[1], c=y, edgecolors='k', s=50, cmap=cm_dark)      # 樣本  
  40.     plt.scatter(x_test[0], x_test[1], s=120, facecolors='none', zorder=10)     # 圈中測試集樣本  
  41.     plt.xlabel(iris_feature[0], fontsize=13)  
  42.     plt.ylabel(iris_feature[1], fontsize=13)  
  43.     plt.xlim(x1_min, x1_max)  
  44.     plt.ylim(x2_min, x2_max)  
  45.     plt.title(u'鳶尾花SVM二特徵分類', fontsize=16)  
  46.     plt.grid(b=True, ls=':')  
  47.     plt.show()  

最後畫圖以下:

 

相關文章
相關標籤/搜索