1 // 2 // Created by xijun1 on 2017/12/14. 3 // 4 5 #include <iostream> 6 #include <vector> 7 #include <string> 8 #include <mxnet/mxnet-cpp/MxNetCpp.h> 9 #include <mxnet/mxnet-cpp/op.h> 10 11 namespace mlp{ 12 class MlpNet{ 13 public: 14 static mx_float OutputAccuracy(mx_float* pred, mx_float* target) { 15 int right = 0; 16 for (int i = 0; i < 128; ++i) { 17 float mx_p = pred[i * 10 + 0]; 18 float p_y = 0; 19 for (int j = 0; j < 10; ++j) { 20 if (pred[i * 10 + j] > mx_p) { 21 mx_p = pred[i * 10 + j]; 22 p_y = j; 23 } 24 } 25 if (p_y == target[i]) right++; 26 } 27 return right / 128.0; 28 } 29 30 static void net(){ 31 using mxnet::cpp::Symbol; 32 using mxnet::cpp::NDArray; 33 34 Symbol x = Symbol::Variable("X"); 35 Symbol y = Symbol::Variable("label"); 36 37 std::vector<std::int32_t> shapes({512 , 10}); 38 //定義一個兩層的網絡. wx + b 39 Symbol weight_0 = Symbol::Variable("weight_0"); 40 Symbol biases_0 = Symbol::Variable("biases_0"); 41 42 Symbol fc_0 = mxnet::cpp::FullyConnected("fc_0",x,weight_0,biases_0 43 ,512); 44 45 Symbol output_0 = mxnet::cpp::LeakyReLU("relu_0",fc_0,mxnet::cpp::LeakyReLUActType::kLeaky); 46 47 Symbol weight_1 = Symbol::Variable("weight_1"); 48 Symbol biases_1 = Symbol::Variable("biases_1"); 49 Symbol fc_1 = mxnet::cpp::FullyConnected("fc_1",output_0,weight_1,biases_1,10); 50 Symbol output_1 = mxnet::cpp::LeakyReLU("relu_1",fc_1,mxnet::cpp::LeakyReLUActType::kLeaky); 51 Symbol pred = mxnet::cpp::SoftmaxOutput("softmax",output_1,y); //目標函數,loss函數 52 mxnet::cpp::Context ctx = mxnet::cpp::Context::cpu( 0); 53 54 //定義輸入數據 55 std::shared_ptr< mx_float > aptr_x(new mx_float[128*28] , [](mx_float* aptr_x){ delete [] aptr_x ;}); 56 std::shared_ptr< mx_float > aptr_y(new mx_float[128] , [](mx_float * aptr_y){ delete [] aptr_y ;}); 57 58 //初始化數據 59 for(int i=0 ; i<128 ; i++){ 60 for(int j=0;j<28 ; j++){ 61 //定義x 62 aptr_x.get()[i*28+j]= i % 10 +0.1f; 63 } 64 65 //定義y 66 aptr_y.get()[i]= i % 10; 67 } 68 std::map<std::string, mxnet::cpp::NDArray> args_map; 69 
//導入數據 70 NDArray arr_x(mxnet::cpp::Shape(128,28),ctx, false); 71 NDArray arr_y(mxnet::cpp::Shape( 128 ),ctx,false); 72 //將數據轉換到NDArray中 73 arr_x.SyncCopyFromCPU(aptr_x.get(),128*28); 74 arr_x.WaitToRead(); 75 76 arr_y.SyncCopyFromCPU(aptr_y.get(),128); 77 arr_y.WaitToRead(); 78 79 args_map["X"]=arr_x.Slice(0,128).Copy(ctx) ; 80 args_map["label"]=arr_y.Slice(0,128).Copy(ctx); 81 NDArray::WaitAll(); 82 //綁定網絡 83 mxnet::cpp::Executor *executor = pred.SimpleBind(ctx,args_map); 84 //選擇優化器 85 mxnet::cpp::Optimizer *opt = mxnet::cpp::OptimizerRegistry::Find("sgd"); 86 mx_float learning_rate = 0.0001; //學習率 87 mx_float weight_decay = 1e-4; //權重 88 opt->SetParam("momentum", 0.9) 89 ->SetParam("lr", learning_rate) 90 ->SetParam("wd", weight_decay); 91 //定義各個層參數的數組 92 NDArray arr_w_0(mxnet::cpp::Shape(512,28),ctx, false); 93 NDArray arr_b_0(mxnet::cpp::Shape( 512 ),ctx,false); 94 NDArray arr_w_1(mxnet::cpp::Shape(10 , 512 ) , ctx , false); 95 NDArray arr_b_1(mxnet::cpp::Shape( 10 ) , ctx , false); 96 97 //初始化權重參數 98 arr_w_0 = 0.01f; 99 arr_b_1 = 0.01f; 100 arr_w_1 = 0.01f; 101 arr_b_1 = 0.01f; 102 103 //初始化參數 104 executor->arg_dict()["weight_0"]=arr_w_0; 105 executor->arg_dict()["biases_0"]=arr_b_0; 106 executor->arg_dict()["weight_1"]=arr_w_1; 107 executor->arg_dict()["biases_1"]=arr_b_1; 108 109 mxnet::cpp::NDArray::WaitAll(); 110 //訓練 111 std::cout<<" Training "<<std::endl; 112 113 int max_iters = 20000; //最大迭代次數 114 //獲取訓練網絡的參數列表 115 std::vector<std::string> args_name = pred.ListArguments(); 116 for (int iter = 0; iter < max_iters ; ++iter) { 117 executor->Forward(true); 118 executor->Backward(); 119 120 if(iter % 100 == 0){ 121 std::vector<NDArray> & out = executor->outputs; 122 std::shared_ptr<mx_float> tp_x( new mx_float[128*28] , 123 [](mx_float * tp_x){ delete [] tp_x ;}); 124 out[0].SyncCopyToCPU(tp_x.get(),128*10); 125 NDArray::WaitAll(); 126 std::cout<<"epoch "<<iter<<" "<<"Accuracy: "<< OutputAccuracy(tp_x.get() , aptr_y.get())<<std::endl; 127 } 128 //args_name. 
129 for(size_t arg_ind=0; arg_ind<args_name.size(); ++arg_ind){ 130 //執行 131 if(args_name[arg_ind]=="X" || args_name[arg_ind]=="label") 132 continue; 133 134 opt->Update(arg_ind,executor->arg_arrays[arg_ind],executor->grad_arrays[arg_ind]); 135 } 136 NDArray::WaitAll(); 137 138 } 139 140 141 } 142 }; 143 } 144 145 int main(int argc , char * argv[]){ 146 mlp::MlpNet::net(); 147 MXNotifyShutdown(); 148 return EXIT_SUCCESS; 149 }
結果:
 Training
epoch 0 Accuracy: 0.09375
epoch 100 Accuracy: 0.304688
epoch 200 Accuracy: 0.195312
epoch 300 Accuracy: 0.203125
epoch 400 Accuracy: 0.304688
epoch 500 Accuracy: 0.296875
epoch 600 Accuracy: 0.304688
epoch 700 Accuracy: 0.304688
epoch 800 Accuracy: 0.398438
epoch 900 Accuracy: 0.5
epoch 1000 Accuracy: 0.5
epoch 1100 Accuracy: 0.40625
epoch 1200 Accuracy: 0.5
epoch 1300 Accuracy: 0.398438
epoch 1400 Accuracy: 0.40625
epoch 1500 Accuracy: 0.703125
epoch 1600 Accuracy: 0.609375
epoch 1700 Accuracy: 0.507812
epoch 1800 Accuracy: 0.703125
epoch 1900 Accuracy: 0.703125
epoch 2000 Accuracy: 0.804688
epoch 2100 Accuracy: 0.703125
epoch 2200 Accuracy: 0.804688
epoch 2300 Accuracy: 0.804688
epoch 2400 Accuracy: 0.804688
epoch 2500 Accuracy: 0.90625
epoch 2600 Accuracy: 0.90625
epoch 2700 Accuracy: 0.90625
epoch 2800 Accuracy: 1
epoch 2900 Accuracy: 1