PLA感知學習算法

  1 #include <vector>
  2 #include<iomanip>
  3 #include <string>
  4 #include<stdio.h>
  5 #include<string.h> 
  6 #include <fstream>
  7 #include <iostream>
  8 #include<set>
  9 #include<algorithm>
 10 #include<cstdio>
 11 #include<iomanip>
 12 #include<map>
 13 #include<cmath>
 14 #define col 41
 15 #define row 7000
 16 
 17 using namespace std;
 18 double label[8010][80];           //訓練集 
 19 double test_label[8010][80];      //測試集
 20 double valition_label[8010][80];  //驗證集 
 21 string s[8010]; 
 22 string ss[8010]; 
 23 string s2[8010]; 
 24 
 25 
 26 //logitic函數,將負無窮到正無窮 轉化 -1到 1 
 27  
 28 double logistic(double n){
 29     
 30    return 1/(1+exp(-1.0*n));
 31     
 32 }
 33 
 34 double geterror(int n){
 35     return 0.001;
 36 }
 37 
 38 double cut_t(string s, int t){
 39     string str = s;
 40     int r = 0;
 41     double count = 0.0;
 42     bool flag = true;
 43     double flag1 = 1.0;
 44     double sum = 0.0;
 45     for(int i=0;i<str.length();i++){
 46         
 47         if(r==t && str[i] == '-'){
 48             flag1 = -1;
 49             continue;
 50         }
 51         if(str[i]==','){
 52             r++;
 53             continue;
 54         }
 55         if(r==t){
 56             if(flag == false){
 57                 count ++;
 58             }
 59             if(str[i] == '.'){
 60                 flag = false;
 61             }
 62             else {
 63                 sum = sum + (str[i] - '0') * 1.0;
 64                 sum = sum * 10;
 65             }
 66 
 67         }
 68     }
 69   
 70     for(int i=0;i<=count;i++){
 71          sum = sum/10;
 72     }
 73     return sum*flag1;
 74 }
 75 
 76 
 77 
 78 int main()
 79 {
 80 
 81 
 82 /*************************************************讀文件***********************************************************/ 
 83   
 84     fstream myfile("C:\\AI_data\\lab5\\train.csv");
 85     fstream valition("C:\\AI_data\\lab5\\valition.csv");
 86     fstream test("C:\\AI_data\\lab5\\test.csv");
 87     
 88     
 89     int num=0;
 90     string temp;
 91     if(!myfile.is_open())
 92     {
 93         cout << "1未成功打開文件" << endl;
 94     }
 95     while(getline(myfile,temp))   //讀入文本中的詞 
 96     {
 97         s[num] = temp;
 98         num++;
 99     }        
100     
101    
102     int num1 = 0;
103     string temp1;
104     if(!test.is_open())
105     {
106         cout << "2未成功打開文件" << endl;
107     }
108     while(getline(test,temp1))   //讀入文本中的詞 
109     {
110         ss[num1] = temp1;
111         num1++;
112     }
113        
114        
115     int num2 = 0;
116     string temp2;
117     if(!valition.is_open())
118     {
119         cout << "3未成功打開文件" << endl;
120     }
121     while(getline(valition,temp2))   //讀入文本中的詞 
122     {
123         s2[num2] = temp2;
124         num2++;
125     }
126     
127    
128 /***********************************************處理文本********************************************************************/ 
129     for(int i=0; i<num; i++){
130         
131         int len = s[i].length(); 
132         string str = s[i];
133         char t[8000]="";
134         for(int j=0;j<col;j++){
135             label[i][0] = 1.0;                //須要在每個樣例前面加上一個 1 
136             label[i][j+1] = cut_t(s[i],j);
137         }
138         // for(int j=0;j<=col;j++)  cout<<label[i][j]<<" ";
139         //cout<<endl;
140         for(int w=0;w<len;w++){
141              t[w] = str[w]; 
142         }
143         const char *d = " , \n" ;
144         char* p = strtok(t,d); 
145         while(p)
146         {
147             p=strtok(NULL,d);
148         }
149     }
150     
151     for(int i=0; i<num1; i++){
152         
153         int len1 = ss[i].length(); 
154         string str1 = ss[i];
155         char tt[8000]="";
156         //cout<<ss[i]<<endl;
157         for(int j=0;j<col-1;j++){
158             test_label[i][0] = 1.0;                //須要在每個樣例前面加上一個 1 
159             test_label[i][j+1] = cut_t(ss[i],j);
160             //cout<<test_label[i][j]<<" ";
161         }
162          //for(int j=0;j<67;j++)  cout<<label[i][j]<<endl;
163         //cout<<endl;
164         for(int w=0;w<len1;w++){
165              tt[w] = str1[w]; 
166         }
167         const char *d = " , \n" ;
168         char* p = strtok(tt,d); 
169         while(p)
170         {
171             p=strtok(NULL,d);
172         }
173     }
174     
175     for(int i=0; i<num2; i++){
176         
177         int len2 = s2[i].length(); 
178         string str2 = s2[i];
179         char t2[8000]="";
180         for(int j=0;j<col;j++){
181             valition_label[i][0] = 1.0;                //須要在每個樣例前面加上一個 1 
182             valition_label[i][j+1] = cut_t(s2[i],j);
183         }
184         // for(int j=0;j<=col;j++)  cout<<label[i][j]<<" ";
185         //cout<<endl;
186         for(int w=0;w<len2;w++){
187              t2[w] = str2[w]; 
188         }
189         const char *d2 = " , \n" ;
190         char* p2 = strtok(t2,d2); 
191         while(p2)
192         {
193             p2=strtok(NULL,d2);
194         }
195     }
196 /***************************************************************** PLA算法執行  ************************************************/  
197  
198 
199  
200  
201     double w[col];                   //初始的 w[] 數組 
202     double new_w[col];
203     double zhishu[row];
204     for(int j=0;j<col;j++){
205         w[j] = 1.0; 
206     }
207     for(int ui=0;ui<row;ui++){
208         zhishu[ui] = 0.0;
209     }
210     
211     int a = 6000;                    //因爲不能所有劃分,因此設立一個最大次數 
212     double error = 0.5;
213     while(a--){
214         //double error = geterror(a);  //調整步長 
215                 
216         for(int j=0;j<col;j++){    //初始化數組,用來更新w[]數組 
217             new_w[j] = 0.0;
218         } 
219         for(int i=0;i<num;i++){    // 遍歷全部樣本進行一輪迭代 
220             for(int j=0;j<col;j++){
221                zhishu[i] += label[i][j]*w[j]; // 對每個導數進行存儲 
222             }                     
223             //進行logistic變換 
224             zhishu[i] = logistic(zhishu[i]) - label[i][col];  
225         }
226         bool flag = true;    //判斷是否收斂 
227         for(int jt=0;jt<col;jt++){
228             for(int it=0;it<num;it++){  //更新 w[] 
229                  new_w[jt] +=  label[it][jt]*zhishu[it];
230             }
231             new_w[jt] = new_w[jt]*error;   
232             if(new_w[jt] != 0) flag = false; 
233             w[jt] = w[jt] - new_w[jt];  //爲下一次迭代 w[] 
234         }
235         if(flag){  //若是收斂 
236              cout<<a<<endl;
237              cout<<"完美收斂,提早結束"<<endl; 
238              break;
239         }
240         
241     }
242     
243     //統計各個指標 
244     double TP = 0.0;
245     double FN = 0.0;
246     double TN = 0.0;
247     double FP = 0.0;
248     double Acc = 0.0;
249     double Rec = 0.0;
250     double Pre = 0.0;
251     double F1 = 0.0;
252      
253  
254     for(int i=0;i<num2;i++){
255        int flag1 = 1;
256        double sum2 = 0.0;
257        for(int j=0;j<col;j++){
258             sum2 += valition_label[i][j]*w[j];
259        }
260        if(logistic(sum2) < 0.5) flag1 = 0;
261        else flag1 = 1;
262        
263        
264        if(flag1 == 1 && valition_label[i][col] == 1) TP++;
265        else if(flag1 == 0 && valition_label[i][col] == 1) FN++;
266        else if(flag1 == 0 && valition_label[i][col] == 0) TN++;
267        else FP++;
268     }
269     cout<<"TP = "<<TP<<endl;
270     cout<<"TN = "<<TN<<endl;
271     cout<<"FN = "<<FN<<endl;
272     cout<<"FP = "<<FP<<endl;
273     Acc = (TP+TN)/(TP+TN+FP+FN);
274     Rec = TP/(TP+FN);
275     Pre = TP/(TP+FP);
276     F1 = 2*Pre*Rec / (Pre+Rec);
277 
278     cout<<"Acc = "<<Acc<<endl;
279     cout<<"Rec = "<<Rec<<endl;
280     cout<<"Pre = "<<Pre<<endl;
281     cout<<"F1 = "<<F1<<endl;
282 
283 
284 
285     for(int k=0;k<num1;k++){
286         
287        int flag3 = 0;
288        double sum3 = 0.0;
289        for(int j=0;j<col;j++){
290             sum3 += test_label[k][j]*w[j];
291        }
292 
293        if(logistic(sum3) <  0.5 ) flag3 = 0;
294        else flag3 = 1;
295        cout<<flag3<<endl; 
296     }
297     
298     test.close(); 
299     myfile.close();
300     return 0;
301 }

上面是原始的PLA實現,下面是PLA基於口袋算法的優化:ios

  1 #include <vector>
  2 #include<iomanip>
  3 #include <string>
  4 #include<stdio.h>
  5 #include<string.h> 
  6 #include <fstream>
  7 #include <iostream>
  8 #include<set>
  9 #include<algorithm>
 10 #include<cstdio>
 11 #include<iomanip>
 12 #include<map>
 13 #include<cmath>
 14 using namespace std;
 15 double label[4010][80];
 16 string s[4010];
 17 
 18 double cut_t(string s, int t){
 19     string str = s;
 20     int r = 0;
 21     double count = 0.0;
 22     bool flag = true;
 23     double flag1 = 1.0;
 24     double sum = 0.0;
 25     for(int i=0;i<str.length();i++){
 26         
 27         if(r==t && str[i] == '-'){
 28             flag1 = -1;
 29             continue;
 30         }
 31         if(str[i]==','){
 32             r++;
 33             continue;
 34         }
 35         if(r==t){
 36             if(flag == false){
 37                 count ++;
 38             }
 39             if(str[i] == '.'){
 40                 flag = false;
 41             }
 42             else {
 43                 sum = sum + (str[i] - '0') * 1.0;
 44                 sum = sum * 10;
 45             }
 46 
 47         }
 48     }
 49   
 50     for(int i=0;i<=count;i++){
 51          sum = sum/10;
 52     }
 53     return sum*flag1;
 54 }
 55 
 56 
 57 
 58 int main()
 59 {
 60 
 61 
 62 /*************************************************讀文件***********************************************************/ 
 63   
 64     fstream myfile("F:\\AI_data\\lab3\\train.txt");
 65     int num=0;
 66     string temp;
 67     if (!myfile.is_open())
 68     {
 69         cout << "未成功打開文件" << endl;
 70     }
 71     while(getline(myfile,temp))   //讀入文本中的詞 
 72     {
 73         s[num] = temp;
 74         num++;
 75     }        
 76     
 77    
 78 /***********************************************處理文本********************************************************************/ 
 79     for(int i=0; i<num; i++){
 80         
 81         int len = s[i].length(); 
 82         string str = s[i];
 83         char t[8000]="";
 84         for(int j=0;j<66;j++){
 85             label[i][0] = 1.0;                //須要在每個樣例前面加上一個 1 
 86             label[i][j+1] = cut_t(s[i],j);
 87         }
 88          //for(int j=0;j<67;j++)  cout<<label[i][j]<<endl;
 89         // cout<<endl;
 90         for(int w=0;w<len;w++){
 91              t[w] = str[w]; 
 92         }
 93         const char *d = " , \n" ;
 94         char* p = strtok(t,d); 
 95         while(p)
 96         {
 97             p=strtok(NULL,d);
 98         }
 99     }
100   
101 /***************************************************************** PLA算法執行  ************************************************/  
102  
103     double w[66];                   //初始的 w[] 數組
104     double change_w[66]; 
105     //double w[7];
106     double store[4010];
107     double sum = 0.0;
108     for(int j=0;j<66;j++){
109         w[j] = 1.0;
110         change_w[j] = 1.0;
111     }
112     int a = 2000;
113     
114      
115     while(a--){                             //規定迭代次數 
116         
117         bool flag2 = true;
118         long double counter_right1 = 0;     //兩次的正確的數目統計 
119         long double counter_right2 = 0;
120         int dex = 0;
121          
122         for(int i=0;i<num;i++){             //遍歷全部數據 
123             
124             sum = 0.0;
125             
126             for(int j=0;j<66;j++){          //進行計算 
127                 
128                 sum += label[i][j]*w[j];
129                 //cout<< i << "   "<<j<<endl; 
130             }
131             //cout<<"sum= "<<sum<<endl;
132             int flag = 0;
133             
134             if(sum > 0.0){                 //對結果的符號進行判斷 
135                 flag = 1;
136             }
137             else{
138                 flag = -1;
139             }
140         
141             //cout<<flag << "   "<<label[i][66]<<endl; 
142             if(flag != label[i][66] ){    //判斷結果是不是正確的,不正確須要考慮這個w[] 
143                 if(flag2){
144                     for(int k=0;k<66;k++){
145                        change_w[k] = w[k] + label[i][k]*label[i][66];
146                        dex = i;
147                        //cout<<w[k]<<endl;
148                     }
149                 }
150                 flag2 = false;           //一次只考慮第一個不正確的 w[] 
151             }
152             else counter_right1++;       //記錄第一個 w[] 的正確率 
153         } 
154         
155         
156         for(int i=0;i<num;i++){         //遍歷全部數據 
157             
158             sum = 0.0;
159             
160             for(int j=0;j<66;j++){       //用第二個w[]進行迭代 
161                 
162                 sum += label[i][j]*change_w[j];
163                 //cout<< i << "   "<<j<<endl; 
164             }
165             //cout<<"sum= "<<sum<<endl;
166             int flag = 0;
167              
168             if(sum > 0.0){              //算出結果的符號 
169                 flag = 1;
170             }
171             else{
172                 flag = -1;
173             }
174                                        //記錄正確率 
175             //cout<<flag << "   "<<label[i][66]<<endl; 
176             if(flag == label[i][66] )  counter_right2++;
177         }
178         
179         
180         //兩個w[]數組正確率比較 ,第一個正確率高則返回原來的w[],不然w[] 替換爲更新後的,進入下一輪迭代 
181         if(counter_right1 > counter_right2){
182             for(int j=0;j<66;j++){
183                 w[j] = change_w[j] - label[dex][j]*label[dex][66];   
184             }
185         }
186         else{
187             for(int j=0;j<66;j++){
188                 w[j] = change_w[j] ;
189             } 
190         }         
191         
192     }
193     
194     
195     double TP = 0.0;
196     double FN = 0.0;
197     double TN = 0.0;
198     double FP = 0.0;
199     double Acc = 0.0;
200     double Rec = 0.0;
201     double Pre = 0.0;
202     double F1 = 0.0;
203      
204     for(int i=0;i<num;i++){
205        int flag1 = 1;
206        double sum1 = 0.0;
207        for(int j=0;j<66;j++){
208             sum1 += label[i][j]*w[j];
209             //cout<<" w[] = "<<w[j]<<endl; 
210        }
211        cout<<sum1<<endl;
212        
213        if(sum1 >= 0 ) flag1 = 1;
214        else flag1 = -1;
215        
216        cout<<flag1<<" ";
217        cout<<label[i][66]<<endl;
218        if(flag1 == 1 && label[i][66] == 1) TP++;
219        else if(flag1 == -1 && label[i][66] == 1) FN++;
220        else if(flag1 == -1 && label[i][66] == -1) TN++;
221        else if(flag1 == 1 && label[i][66] == -1)FP++;
222        
223     }
224     cout<<TP<<endl;
225     cout<<TN<<endl;
226     cout<<FN<<endl;
227     cout<<TN<<endl;
228     Acc = (TP+TN)/(TP+TN+FP+FN);
229     Rec = TP/(TP+FN);
230     Pre = TP/(TP+FP);
231     F1 = 2*Pre*Rec / (Pre+Rec);
232 
233     cout<<"Acc = "<<Acc<<endl;
234     cout<<"Rec = "<<Rec<<endl;
235     cout<<"Pre = "<<Pre<<endl;
236     cout<<"F1 = "<<F1<<endl;
237     for(int k=0;k<66;k++){
238         cout<< w[k] <<endl;
239     }
240      
241     myfile.close();
242     return 0;
243 }
相關文章
相關標籤/搜索