libsvm代碼閱讀:關於svm_group_classes函數分析

目前libsvm最新的version是3.17,主要的改變是在svm_group_classes函數中加了幾行代碼。官方的說明以下:web

Version 3.17 released on April Fools' day, 2013. We slightly adjust the way class labels are handled internally. By default labels are ordered by their first occurrence in the training set. Hence for a set with -1/+1 labels, if -1 appears first, then internally -1 becomes +1. This has caused confusion. Now for data with -1/+1 labels, we specifically ensure that internally the binary SVM has positive data corresponding to the +1 instances. For developers, see changes in the subrouting svm_group_classes of svm.cpp. 數組

本文就對這個函數進行分析:app

svm_group_classes函數的功能是:group training data of the same class函數

Important:如何將一堆數據歸類到一塊兒,同類的連續存儲!可參考這個函數。
oop

函數原型以下:this

[cpp]   view plain copy 在CODE上查看代碼片 派生到個人代碼片
<EMBED id=ZeroClipboardMovie_1 height=18 name=ZeroClipboardMovie_1 type=application/x-shockwave-flash align=middle pluginspage=http://www.macromedia.com/go/getflashplayer width=18 src=http://static.blog.csdn.net/scripts/ZeroClipboard/ZeroClipboard.swf wmode="transparent" flashvars="id=1&width=18&height=18" allowfullscreen="false" allowscriptaccess="always" bgcolor="#ffffff" quality="best" menu="false" loop="false">
  1. void svm_group_classes(const svm_problem *prob, int *nr_class_ret, int **label_ret, int **start_ret, int **count_ret, int *perm)  

主要的輸入是prob這個指針,它指向svm_group_classes將要處理的樣本數據集,另外幾個形參是指針類型,能夠至關於輸出數據,其中:spa

  1. nr_class_ret——統計得出樣本集的類別總數
  2. label_ret——指向存儲類別標號的數組
  3. start_ret——指向存儲每一個類別的起始位置的數組
  4. count_tet——指向存儲每一個類別的樣本個數的數組
  5. perm——指向原始數據的索引數組
下面,先看一部分代碼,這部分代碼中的for循環的功能:統計類別總數、將相應的相同類別y[i]賦到相應的label,並統計各個類別的樣本數量count。
設一個例子:{ 有6個樣本,總共4類,其中y[0]=y[1],y[2]=y[3],y[4],y[5] },則for循環的運行過程以下所示:
i=0  label[0]=y[0],           data_label[0]=0
i=1  label[0]=y[0]=y[1],   data_label[1]=0 count[0]=2
i=2  label[1]=y[2],           data_label[2]=1
i=3  label[1]=y[2]=y[3],   data_label[3]=1 count[1]=2
i=4  label[2]=y[4],           data_label[2]=2 count[2]=1
i=5  label[3]=y[5],           data_label[2]=3 count[3]=1

[cpp]   view plain copy 在CODE上查看代碼片 派生到個人代碼片
<EMBED id=ZeroClipboardMovie_2 height=18 name=ZeroClipboardMovie_2 type=application/x-shockwave-flash align=middle pluginspage=http://www.macromedia.com/go/getflashplayer width=18 src=http://static.blog.csdn.net/scripts/ZeroClipboard/ZeroClipboard.swf wmode="transparent" flashvars="id=2&width=18&height=18" allowfullscreen="false" allowscriptaccess="always" bgcolor="#ffffff" quality="best" menu="false" loop="false">
  1. // label: label name, start: begin of each class, count: #data of classes, perm: indices to the original data  
  2. // perm, length l, must be allocated before calling this subroutine  
  3. static void svm_group_classes(const svm_problem *prob, int *nr_class_ret, int **label_ret, int **start_ret, int **count_ret, int *perm)  
  4. {  
  5.     int l = prob->l;//樣本總數  
  6.     int max_nr_class = 16;//不夠的話,自動增加爲原來的兩倍(見下文)  
  7.     int nr_class = 0;  
  8.     int *label = Malloc(int,max_nr_class);//Malloc(type,n) (type *)malloc((n)*sizeof(type))  
  9.     int *count = Malloc(int,max_nr_class);  
  10.     int *data_label = Malloc(int,l);      
  11.     int i;  
  12.   
  13.     for(i=0;i<l;i++)  
  14.     {  
  15.         int this_label = (int)prob->y[i];//將類別賦給this_label  
  16.         int j;  
  17.         for(j=0;j<nr_class;j++)  
  18.         {  
  19.             if(this_label == label[j])//雖然剛開始label裏面沒值,可是第一步循環本內層也沒有被運行  
  20.             {  
  21.                 ++count[j];  
  22.                 break;  
  23.             }  
  24.         }  
  25.         data_label[i] = j;  
  26.         if(j == nr_class)  
  27.         {  
  28.             if(nr_class == max_nr_class)  
  29.             {  
  30.                 max_nr_class *= 2;//擴大最大類別數  
  31.                 label = (int *)realloc(label,max_nr_class*sizeof(int));  
  32.                 count = (int *)realloc(count,max_nr_class*sizeof(int));  
  33.             }  
  34.             label[nr_class] = this_label;  
  35.             count[nr_class] = 1;//這個是1  
  36.             ++nr_class;  
  37.         }  
  38.     }  


本version更新部分:本部分主要是處理二類分類,當第一個出現的是-1時,負責把-1和+1的數據對調。.net

[cpp]   view plain copy 在CODE上查看代碼片 派生到個人代碼片
<EMBED id=ZeroClipboardMovie_3 height=18 name=ZeroClipboardMovie_3 type=application/x-shockwave-flash align=middle pluginspage=http://www.macromedia.com/go/getflashplayer width=18 src=http://static.blog.csdn.net/scripts/ZeroClipboard/ZeroClipboard.swf wmode="transparent" flashvars="id=3&width=18&height=18" allowfullscreen="false" allowscriptaccess="always" bgcolor="#ffffff" quality="best" menu="false" loop="false">
  1. //  
  2. // Labels are ordered by their first occurrence in the training set.   
  3. // However, for two-class sets with -1/+1 labels and -1 appears first,   
  4. // we swap labels to ensure that internally the binary SVM has positive data corresponding to the +1 instances.  
  5. //  
  6. if (nr_class == 2 && label[0] == -1 && label[1] == 1)  
  7. {  
  8.     swap(label[0],label[1]);  
  9.     swap(count[0],count[1]);  
  10.     for(i=0;i<l;i++)  
  11.     {  
  12.         if(data_label[i] == 0)  
  13.             data_label[i] = 1;  
  14.         else  
  15.             data_label[i] = 0;  
  16.     }  
  17. }  


下面這一部分代碼是用來計算每一個類別的起始位置start、以及各個樣本分類後的在原始數據中的索引位置perm數組。其中perm[i]=j: i表示當前同類樣本位置,j表示原始數據位置。指針

Important:如何將一堆數據歸類到一塊兒,同類的連續存儲!可參考這個函數。code

[cpp]   view plain copy 在CODE上查看代碼片 派生到個人代碼片
<EMBED id=ZeroClipboardMovie_4 height=18 name=ZeroClipboardMovie_4 type=application/x-shockwave-flash align=middle pluginspage=http://www.macromedia.com/go/getflashplayer width=18 src=http://static.blog.csdn.net/scripts/ZeroClipboard/ZeroClipboard.swf wmode="transparent" flashvars="id=4&width=18&height=18" allowfullscreen="false" allowscriptaccess="always" bgcolor="#ffffff" quality="best" menu="false" loop="false">
  1. int *start = Malloc(int,nr_class);  
  2. start[0] = 0;  
  3. for(i=1;i<nr_class;i++)  
  4.     start[i] = start[i-1]+count[i-1];  
  5. for(i=0;i<l;i++)  
  6. {  
  7.     perm[start[data_label[i]]] = i;  
  8.     ++start[data_label[i]];  
  9. }  
  10. start[0] = 0;  
  11. for(i=1;i<nr_class;i++)  
  12.     start[i] = start[i-1]+count[i-1];  
  13.   
  14. *nr_class_ret = nr_class;  
  15. *label_ret = label;  
  16. *start_ret = start;  
  17. *count_ret = count;  
  18. free(data_lab
相關文章
相關標籤/搜索