調參過程當中的參數 學習率,權重衰減,衝量(learning_rate , weight_decay , momentum)

不管是深度學習仍是機器學習,大多狀況下訓練中都會遇到這幾個參數,今天依據我本身的理解具體的總結一下,可能會存在錯誤,還請指正.
learning_rate , weight_decay , momentum這三個參數的含義. 並附上demo.
 
咱們會使用一個例子來講明一下:
            好比咱們有一堆數據 ,咱們只知道這對數據是從一個 黑盒中獲得的,咱們如今要尋找到那個具體的函數f(x),咱們定義爲目標函數T.
          咱們如今假定有存在這個函數而且這個函數爲:
                                        
         咱們如今要使用這對數據來訓練目標函數. 咱們能夠設想若是存在一個這個函數,一定知足{x,y}全部的關係,也就是說:     
                                        
         那麼最理想的狀況下 :    ,那麼咱們不妨定義這樣一個優化目標函數:
                                    
        對於這堆數據,咱們認爲當Loss(W)對於全部的pair{x,y}都知足 Loss(W)趨近於或者等於0時,咱們認爲咱們找到這個理想的目標函數T. 也就是此時  .
      以上,咱們發現尋找的目標函數的問題,已經成功的轉移爲求解: 
                                        
      也就是Loss 越小, f(x)越接近咱們尋找的目標函數T.
那麼說了這麼多,這個和咱們說的學習率learning_rate有什麼關係呢?
                既然咱們知道了咱們當前的f(x)和目標函數的T的偏差,那麼咱們能夠將這個偏差轉移到每個參數上,也就是變成每個參數w和目標函數T的參數w_t的偏差. 而後咱們就以必定的幅度stride來縮小和真實值的距離,咱們稱這個stride爲學習率learning_rate 並且咱們就是這麼作的.
                咱們用公式表述就是:
                        咱們的偏差(損失)Loss:    
                                        
 
 
 
 
                咱們這一個凸函數. 咱們先對這個函數進行各個份量求偏導.
            
對於w0的偏導數:
                
那麼對於份量w0承擔的偏差爲:
                         而且這個偏差帶方向.
那麼咱們須要使咱們當前的w0更加接近目標函數的T的w0_t參數.咱們須要作運算:
                 (梯度降低算法)
來更新wo的值. 同理其餘參數w,而這個學習率就是來控制咱們每次靠近真實值的幅度,爲何要這麼作呢?
由於咱們表述的偏差只是一種空間表述形式咱們可使用均方差也可使用絕對值,還可使用對數,以及交叉熵等等,因此只能大體的反映,並不精確,就想咱們問路同樣,別人告訴咱們直走五分鐘,有的人走的快,有的人走的慢,因此若是走的快的話,當再次問路的時候,就會發現走多了,而折回來,這就是咱們訓練過程當中的loss曲線震盪嚴重的緣由之一. 因此學習率要設置在合理的大小.

好了說了這麼多,這是學習率. 那麼什麼是權重衰減weight_decay呢? 有什麼做用呢?
          咱們接着看上面的這個Loss(w),咱們發現若是參數過多的話,對於高位的w3,咱們對其求偏導:
            
咱們發現w3開始大於1的時候,w3會調節的很快,幅度很大,從而使得特徵x3變爲異常敏感.從而出現過擬合(overfitting).
       這個時候,咱們須要約束一下w2,w3等高階參數的大小,因而咱們對Loss增長一個懲罰項,使得Loss的正反方向,不該該只由f(x) -y 決定,而還應該加上一個 ;因而Loss變成了:
   
咱們繼續對Loss求解偏導數:
對wo求偏導:
                
 
 


對w3求偏導:
            
 
咱們發現當x3值過大時,會改變Loss的導數的方向.而來抑制w2,w3等高階函數的繼續增加. 然而這樣抑制並非很靈活,因此咱們在前面加入一個係數 ,這個係數在數學上稱之爲拉格朗日乘子係數,也就是咱們用到的weight_decay. 這樣咱們能夠經過調節weight_decay係數,來調節w3,w2等高階的增加程度。加入weight_decay後的公式:
 
 
從公式能夠看出 ,weight_decay越大,抑制越大,w2,w3等係數越小,weight_decay越小,抑制越小,w2,w3等係數越大


那麼衝量momentum又是啥?
     咱們在使用梯度降低法,來調整w時公式是這樣的:
        
咱們每一次都是計算當前的梯度:
                
這樣會發現對於那些梯度比較小的地方,參數w更新的幅度比較小,訓練變得漫長,或者收斂慢.有時候遇到非最優的凸點,會出現衝不出去的現象.
而衝量加進來是一種快速效應.藉助上一次的勢能來和當前的梯度來調節當前的參數w.
公式表達爲:
            
這樣能夠有效的避免掉入凸點沒法衝出來,並且收斂速度也快不少.
 
附上demo: 使用mxnet編碼.
  1 //
  2 // Created by xijun1 on 2017/12/14.
  3 //
  4 
  5 #include <iostream>
  6 #include <vector>
  7 #include <string>
  8 #include <mxnet/mxnet-cpp/MxNetCpp.h>
  9 #include <mxnet/mxnet-cpp/op.h>
 10 
 11 namespace  mlp{
 12     class MlpNet{
 13     public:
 14         static mx_float OutputAccuracy(mx_float* pred, mx_float* target) {
 15             int right = 0;
 16             for (int i = 0; i < 128; ++i) {
 17                 float mx_p = pred[i * 10 + 0];
 18                 float p_y = 0;
 19                 for (int j = 0; j < 10; ++j) {
 20                     if (pred[i * 10 + j] > mx_p) {
 21                         mx_p = pred[i * 10 + j];
 22                         p_y = j;
 23                     }
 24                 }
 25                 if (p_y == target[i]) right++;
 26             }
 27             return right / 128.0;
 28         }
 29 
 30        static void net(){
 31             using mxnet::cpp::Symbol;
 32             using mxnet::cpp::NDArray;
 33 
 34             Symbol x = Symbol::Variable("X");
 35             Symbol y = Symbol::Variable("label");
 36 
 37             std::vector<std::int32_t> shapes({512 , 10});
 38             //定義一個兩層的網絡. wx + b
 39             Symbol weight_0 = Symbol::Variable("weight_0");
 40             Symbol biases_0 = Symbol::Variable("biases_0");
 41 
 42             Symbol fc_0 = mxnet::cpp::FullyConnected("fc_0",x,weight_0,biases_0
 43                     ,512);
 44 
 45             Symbol output_0 = mxnet::cpp::LeakyReLU("relu_0",fc_0,mxnet::cpp::LeakyReLUActType::kLeaky);
 46 
 47             Symbol weight_1 = Symbol::Variable("weight_1");
 48             Symbol biases_1 = Symbol::Variable("biases_1");
 49             Symbol fc_1 = mxnet::cpp::FullyConnected("fc_1",output_0,weight_1,biases_1,10);
 50             Symbol output_1 = mxnet::cpp::LeakyReLU("relu_1",fc_1,mxnet::cpp::LeakyReLUActType::kLeaky);
 51             Symbol pred = mxnet::cpp::SoftmaxOutput("softmax",output_1,y);  //目標函數,loss函數
 52             mxnet::cpp::Context ctx = mxnet::cpp::Context::cpu( 0);
 53 
 54             //定義輸入數據
 55             std::shared_ptr< mx_float > aptr_x(new mx_float[128*28] , [](mx_float* aptr_x){ delete [] aptr_x ;});
 56             std::shared_ptr< mx_float > aptr_y(new mx_float[128] , [](mx_float * aptr_y){ delete [] aptr_y ;});
 57 
 58             //初始化數據
 59             for(int i=0 ; i<128 ; i++){
 60                 for(int j=0;j<28 ; j++){
 61                     //定義x
 62                     aptr_x.get()[i*28+j]= i % 10 +0.1f;
 63                 }
 64 
 65                 //定義y
 66                 aptr_y.get()[i]= i % 10;
 67             }
 68            std::map<std::string, mxnet::cpp::NDArray> args_map;
 69            //導入數據
 70            NDArray arr_x(mxnet::cpp::Shape(128,28),ctx, false);
 71            NDArray arr_y(mxnet::cpp::Shape( 128 ),ctx,false);
 72            //將數據轉換到NDArray中
 73            arr_x.SyncCopyFromCPU(aptr_x.get(),128*28);
 74            arr_x.WaitToRead();
 75 
 76            arr_y.SyncCopyFromCPU(aptr_y.get(),128);
 77            arr_y.WaitToRead();
 78 
 79            args_map["X"]=arr_x.Slice(0,128).Copy(ctx) ;    
 80            args_map["label"]=arr_y.Slice(0,128).Copy(ctx);
 81            NDArray::WaitAll();
 82             //綁定網絡
 83            mxnet::cpp::Executor *executor = pred.SimpleBind(ctx,args_map);
 84             //選擇優化器
 85            mxnet::cpp::Optimizer *opt = mxnet::cpp::OptimizerRegistry::Find("sgd");
 86            mx_float learning_rate = 0.0001; //學習率
 87            mx_float weight_decay = 1e-4; //權重
 88            opt->SetParam("momentum", 0.9)
 89                    ->SetParam("lr", learning_rate)
 90                    ->SetParam("wd", weight_decay);
 91            //定義各個層參數的數組
 92            NDArray arr_w_0(mxnet::cpp::Shape(512,28),ctx, false);
 93            NDArray arr_b_0(mxnet::cpp::Shape( 512 ),ctx,false);
 94            NDArray arr_w_1(mxnet::cpp::Shape(10 , 512 ) , ctx , false);
 95            NDArray arr_b_1(mxnet::cpp::Shape( 10 ) , ctx , false);
 96 
 97            //初始化權重參數
 98            arr_w_0 = 0.01f;
 99            arr_b_1 = 0.01f;
100            arr_w_1 = 0.01f;
101            arr_b_1 = 0.01f;
102 
103             //初始化參數
104             executor->arg_dict()["weight_0"]=arr_w_0;
105             executor->arg_dict()["biases_0"]=arr_b_0;
106             executor->arg_dict()["weight_1"]=arr_w_1;
107             executor->arg_dict()["biases_1"]=arr_b_1;
108 
109             mxnet::cpp::NDArray::WaitAll();
110             //訓練
111             std::cout<<" Training "<<std::endl;
112 
113             int max_iters = 20000;  //最大迭代次數
114            //獲取訓練網絡的參數列表
115            std::vector<std::string>  args_name = pred.ListArguments();
116             for (int iter = 0; iter < max_iters ; ++iter) {
117                 executor->Forward(true);
118                 executor->Backward();
119 
120                 if(iter % 100 == 0){
121                     std::vector<NDArray> & out = executor->outputs;
122                     std::shared_ptr<mx_float> tp_x( new mx_float[128*28] ,
123                                                     [](mx_float * tp_x){ delete [] tp_x ;});
124                     out[0].SyncCopyToCPU(tp_x.get(),128*10);
125                     NDArray::WaitAll();
126                     std::cout<<"epoch "<<iter<<"  "<<"Accuracy: "<<  OutputAccuracy(tp_x.get() , aptr_y.get())<<std::endl;
127                 }
128                 //args_name.
129                 for(size_t arg_ind=0; arg_ind<args_name.size(); ++arg_ind){
130                     //執行
131                     if(args_name[arg_ind]=="X" || args_name[arg_ind]=="label")
132                         continue;
133 
134                     opt->Update(arg_ind,executor->arg_arrays[arg_ind],executor->grad_arrays[arg_ind]);
135                 }
136                 NDArray::WaitAll();
137 
138             }
139 
140 
141         }
142     };
143 }
144 
145 int main(int argc , char * argv[]){
146     mlp::MlpNet::net();
147     MXNotifyShutdown();
148     return EXIT_SUCCESS;
149 }
View Code

結果:ios

Training 
epoch 0  Accuracy: 0.09375
epoch 100  Accuracy: 0.304688
epoch 200  Accuracy: 0.195312
epoch 300  Accuracy: 0.203125
epoch 400  Accuracy: 0.304688
epoch 500  Accuracy: 0.296875
epoch 600  Accuracy: 0.304688
epoch 700  Accuracy: 0.304688
epoch 800  Accuracy: 0.398438
epoch 900  Accuracy: 0.5
epoch 1000  Accuracy: 0.5
epoch 1100  Accuracy: 0.40625
epoch 1200  Accuracy: 0.5
epoch 1300  Accuracy: 0.398438
epoch 1400  Accuracy: 0.40625
epoch 1500  Accuracy: 0.703125
epoch 1600  Accuracy: 0.609375
epoch 1700  Accuracy: 0.507812
epoch 1800  Accuracy: 0.703125
epoch 1900  Accuracy: 0.703125
epoch 2000  Accuracy: 0.804688
epoch 2100  Accuracy: 0.703125
epoch 2200  Accuracy: 0.804688
epoch 2300  Accuracy: 0.804688
epoch 2400  Accuracy: 0.804688
epoch 2500  Accuracy: 0.90625
epoch 2600  Accuracy: 0.90625
epoch 2700  Accuracy: 0.90625
epoch 2800  Accuracy: 1
epoch 2900  Accuracy: 1
相關文章
相關標籤/搜索