Caffe源碼-SGDSolver類

SGDSolver類簡介

Solver類用於網絡參數的更新,而SGDSolver類實現了優化方法中的隨機梯度降低法(stochastic gradient descent),此外還具有縮放、正則化梯度等功能。caffe中其餘的優化方法都是SGDSolver類的派生類,重載了基類的ComputeUpdateValue()函數,用於各自計算更新的梯度。python

sgd_solver.cpp源碼

// Return the current learning rate. The currently implemented learning rate
// policies are as follows:
//    - fixed: always return base_lr.
//    - step: return base_lr * gamma ^ (floor(iter / step))
//    - exp: return base_lr * gamma ^ iter
//    - inv: return base_lr * (1 + gamma * iter) ^ (- power)
//    - multistep: similar to step but it allows non uniform steps defined by
//      stepvalue
//    - poly: the effective learning rate follows a polynomial decay, to be
//      zero by the max_iter. return base_lr (1 - iter/max_iter) ^ (power)
//    - sigmoid: the effective learning rate follows a sigmod decay
//      return base_lr ( 1/(1 + exp(-gamma * (iter - stepsize))))
//
// where base_lr, max_iter, gamma, step, stepvalue and power are defined
// in the solver parameter protocol buffer, and iter is the current iteration.
template <typename Dtype>
Dtype SGDSolver<Dtype>::GetLearningRate() {   //根據當前的迭代次數和學習率的更新策略計算並返回當前的學習率
  Dtype rate;
  const string& lr_policy = this->param_.lr_policy();   //獲取學習率的更新策略
  if (lr_policy == "fixed") {                 //每次迭代的學習率爲固定值
    rate = this->param_.base_lr();
  } else if (lr_policy == "step") {           //每隔stepsize_次,當前的學習率乘上係數gamma_
    CHECK_GT(this->param_.stepsize(), 0);
    this->current_step_ = this->iter_ / this->param_.stepsize();  //current_step_爲int類型,爲當前階數
    CHECK_GE(this->param_.gamma(), 0);
    rate = this->param_.base_lr() * pow(this->param_.gamma(), this->current_step_); //lr = base_lr_ * (gamma_ ^ current_step_)
  } else if (lr_policy == "exp") {            //每次迭代,當前的學習率乘上係數gamma_
    CHECK_GE(this->param_.gamma(), 0);
    rate = this->param_.base_lr() * pow(this->param_.gamma(), this->iter_); //lr = base_lr_ * (gamma_ ^ iter_)
  } else if (lr_policy == "inv") {            //計算公式: lr = base_lr_ * (1 + gamma_ * iter_) ^ (-power_)
    CHECK_GE(this->param_.gamma(), 0);
    rate = this->param_.base_lr() * pow(Dtype(1) + this->param_.gamma() * this->iter_, - this->param_.power());
  } else if (lr_policy == "multistep") {
    //stepvalue_中保存了每一個階段須要的迭代次數,stepvalue_[0] stepvalue_[1] stepvalue_[2] ...
    //當前迭代次數每遞增到一個新的stepvalue_[n]時,當前的學習率乘上係數gamma_
    if (this->current_step_ < this->param_.stepvalue_size() &&
          this->iter_ >= this->param_.stepvalue(this->current_step_)) { //迭代次數遞增到stepvalue_[n]
      this->current_step_++;    //進入下一階段
      LOG(INFO) << "MultiStep Status: Iteration " <<
      this->iter_ << ", step = " << this->current_step_;
    }
    CHECK_GE(this->param_.gamma(), 0);
    rate = this->param_.base_lr() * pow(this->param_.gamma(), this->current_step_);
  } else if (lr_policy == "poly") {     //計算公式: lr = base_lr_ * (1 - iter_ / max_iter_) ^ power_
    rate = this->param_.base_lr() * pow(Dtype(1.) -
        (Dtype(this->iter_) / Dtype(this->param_.max_iter())), this->param_.power());
  } else if (lr_policy == "sigmoid") {  //計算公式: lr = base_lr_ * (1 / (1 + exp(-gamma_ * (iter_ - stepsize_))))
    CHECK_GE(this->param_.gamma(), 0);  //檢查參數的範圍, gamma_ >= 0, stepsize_ > 0
    CHECK_GT(this->param_.stepsize(), 0);
    rate = this->param_.base_lr() * (Dtype(1.) /
        (Dtype(1.) + exp(-this->param_.gamma() * (Dtype(this->iter_) - Dtype(this->param_.stepsize())))));
  } else {
    LOG(FATAL) << "Unknown learning rate policy: " << lr_policy;
  }
  return rate;
}

//求解器求解以前的預處理操做,清空求解器內部數據,並根據網絡的各個可學習參數blob的大小建立新的空blob
template <typename Dtype>
void SGDSolver<Dtype>::PreSolve() {
  // Initialize the history
  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();  //網絡中全部可學習參數
  history_.clear();   //清空歷史梯度數據,更新數據,臨時數據
  update_.clear();    //三個數據的形狀均與參數blob的形狀一致
  temp_.clear();
  for (int i = 0; i < net_params.size(); ++i) {
    const vector<int>& shape = net_params[i]->shape();   //第i個可學習參數blob的形狀
    history_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape))); //使用該形狀建立空blob,保存指針
    update_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));
    temp_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));
  }
}

//裁剪梯度,參數的梯度數據的l2範數值不能超過設定值clip_gradients,不然會縮放梯度數據
template <typename Dtype>
void SGDSolver<Dtype>::ClipGradients() {
  const Dtype clip_gradients = this->param_.clip_gradients();   //設定的裁剪的閾值
  if (clip_gradients < 0) { return; }   //設定值大於0纔有效,默認-1
  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();  //網絡中全部可學習參數
  Dtype sumsq_diff = 0;
  for (int i = 0; i < net_params.size(); ++i) {
    sumsq_diff += net_params[i]->sumsq_diff();    //累加全部參數blob的梯度數據diff_的平方和
  }
  const Dtype l2norm_diff = std::sqrt(sumsq_diff);  //參數梯度的l2範數
  if (l2norm_diff > clip_gradients) {               //大於設定值
    Dtype scale_factor = clip_gradients / l2norm_diff;    //縮放係數
    LOG(INFO) << "Gradient clipping: scaling down gradients (L2 norm "
        << l2norm_diff << " > " << clip_gradients << ") "
        << "by scale factor " << scale_factor;      //打印信息
    for (int i = 0; i < net_params.size(); ++i) {
      net_params[i]->scale_diff(scale_factor);      //縮放全部參數blob的梯度數據
    }
  }
}

//根據參數的梯度,網絡的學習率和權重衰減等計算實際更新時的梯度,並更新網絡中的全部可學習參數
template <typename Dtype>
void SGDSolver<Dtype>::ApplyUpdate() {
  Dtype rate = GetLearningRate();   //獲取當前的學習率
  if (this->param_.display() && this->iter_ % this->param_.display() == 0) {  //設置了打印,而且當前迭代次數須要顯示打印信息
    LOG_IF(INFO, Caffe::root_solver()) << "Iteration " << this->iter_
        << ", lr = " << rate;
  }
  ClipGradients();    //裁剪梯度,縮放網絡中全部可學習參數的梯度
  for (int param_id = 0; param_id < this->net_->learnable_params().size(); ++param_id) {
    Normalize(param_id);    //將參數的梯度縮小iter_size倍,獲得單次迭代時可學習參數的平均梯度
    Regularize(param_id);   //施加l1或l2正則化,衰減參數的梯度

    //其餘梯度更新的方法都繼承於SGDSolver類,都實現了各自的ComputeUpdateValue()函數,肯定了用於參數更新的梯度值
    ComputeUpdateValue(param_id, rate); //根據衝量,學習率參數和歷史梯度值更新當前的梯度值
  }
  this->net_->Update();   //使用計算後的梯度值更新網絡中的全部可學習參數, data_ = Dtype(-1) * diff_ + data_

  // Increment the internal iter_ counter -- its value should always indicate
  // the number of times the weights have been updated.
  ++this->iter_;
}

//將net中的第param_id個可學習參數的梯度數據縮小 1/iter_size 倍
//單次迭代會執行iter_size次的前向和反向過程,每次反向過程都會累加梯度,因此須要先縮小
template <typename Dtype>
void SGDSolver<Dtype>::Normalize(int param_id) {
  if (this->param_.iter_size() == 1) { return; }    //iter_size=1就不用此操做了
  // Scale gradient to counterbalance accumulation.
  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();  //全部可學習參數
  const Dtype accum_normalization = Dtype(1.) / this->param_.iter_size();   // 1/iter_size
  switch (Caffe::mode()) {
  case Caffe::CPU: {  //cpu模式下,net_params[param_id]的diff_數據所有乘上係數 1/iter_size
    caffe_scal(net_params[param_id]->count(), accum_normalization,
        net_params[param_id]->mutable_cpu_diff());
    break;
  }
  case Caffe::GPU: {  //同理,gpu模式下全部參數的diff_也都乘上係數
#ifndef CPU_ONLY
    caffe_gpu_scal(net_params[param_id]->count(), accum_normalization,
        net_params[param_id]->mutable_gpu_diff());
#else
    NO_GPU;
#endif
    break;
  }
  default:
    LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
  }
}

//將網絡中的第param_id個參數blob進行l1或l2正則化
template <typename Dtype>
void SGDSolver<Dtype>::Regularize(int param_id) {
  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();  //全部可學習參數
  const vector<float>& net_params_weight_decay =
      this->net_->params_weight_decay();      //全部可學習參數對應的權重衰減係數
  Dtype weight_decay = this->param_.weight_decay();   //求解器參數中設置的基礎權重衰減值
  string regularization_type = this->param_.regularization_type();  //求解器參數中設置的正則化類型
  Dtype local_decay = weight_decay * net_params_weight_decay[param_id]; //該參數對應的權重衰減值
  switch (Caffe::mode()) {
  case Caffe::CPU: {
    if (local_decay) {    //非0
      if (regularization_type == "L2") {    //l2正則化
        // add weight decay
        //L2正則化會在損失函數中增長項 1/2 * λ * θ^2, 所以計算參數的梯度時,每一個參數的梯度會增長項 λ * θ
        //θ對應參數的data_數據, λ對應參數的權重衰減值local_decay
        caffe_axpy(net_params[param_id]->count(),
            local_decay,
            net_params[param_id]->cpu_data(),
            net_params[param_id]->mutable_cpu_diff());  //公式 diff_ += local_decay * data_
      } else if (regularization_type == "L1") {
        //l1正則化會在損失函數中增長 λ * |θ|, 對應參數的梯度增長 λ * sign(θ). sign(θ)表示參數θ的符號,正(1),負(-1)
        caffe_cpu_sign(net_params[param_id]->count(),
            net_params[param_id]->cpu_data(),
            temp_[param_id]->mutable_cpu_data());   //判斷data_中數據的符號,結果存在臨時變量temp_的data_中
        caffe_axpy(net_params[param_id]->count(),
            local_decay,
            temp_[param_id]->cpu_data(),
            net_params[param_id]->mutable_cpu_diff());  //公式 diff_ += local_decay * sign(data_)
      } else {
        LOG(FATAL) << "Unknown regularization type: " << regularization_type;
      }
    }
    break;
  }
  case Caffe::GPU: {    //如下操做同理,在gpu上實現
#ifndef CPU_ONLY
    if (local_decay) {
      if (regularization_type == "L2") {
        // add weight decay
        caffe_gpu_axpy(net_params[param_id]->count(),
            local_decay,
            net_params[param_id]->gpu_data(),
            net_params[param_id]->mutable_gpu_diff());  // diff_ += local_decay * data_
      } else if (regularization_type == "L1") {
        caffe_gpu_sign(net_params[param_id]->count(),
            net_params[param_id]->gpu_data(),
            temp_[param_id]->mutable_gpu_data());       //temp_data_ = sign(data_)
        caffe_gpu_axpy(net_params[param_id]->count(),
            local_decay,
            temp_[param_id]->gpu_data(),
            net_params[param_id]->mutable_gpu_diff());  //diff_ += local_decay * sign(data_)
      } else {
        LOG(FATAL) << "Unknown regularization type: " << regularization_type;
      }
    }
#else
    NO_GPU;
#endif
    break;
  }
  default:
    LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
  }
}

#ifndef CPU_ONLY
template <typename Dtype>
void sgd_update_gpu(int N, Dtype* g, Dtype* h, Dtype momentum,
    Dtype local_rate);    //該函數定義在 sgd_solver.cu 文件中
#endif

//根據衝量參數,學習率參數和歷史梯度數據,更新當前的梯度值
template <typename Dtype>
void SGDSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();  //網絡中全部參數blob
  const vector<float>& net_params_lr = this->net_->params_lr();   //每一個參數對應的學習率係數
  Dtype momentum = this->param_.momentum();                       //求解器參數中設置的衝量
  Dtype local_rate = rate * net_params_lr[param_id];              //乘上係數,獲得當前參數的學習率
  // Compute the update to history, then copy it to the parameter diff.
  switch (Caffe::mode()) {
  case Caffe::CPU: {
    //計算帶衝量的梯度值,並將梯度保存在history_中,供下次迭代使用
    caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
        net_params[param_id]->cpu_diff(), momentum,
        history_[param_id]->mutable_cpu_data());  //history_data = local_rate * param_diff + momentum * history_data
    caffe_copy(net_params[param_id]->count(),
        history_[param_id]->cpu_data(),
        net_params[param_id]->mutable_cpu_diff());  //param_diff = history_data
    break;
  }
  case Caffe::GPU: {
#ifndef CPU_ONLY
    //與cpu操做相似,該函數先是 history_data = local_rate * param_diff + momentum * history_data,
    //再是 param_diff = history_data
    sgd_update_gpu(net_params[param_id]->count(),
        net_params[param_id]->mutable_gpu_diff(),
        history_[param_id]->mutable_gpu_data(),
        momentum, local_rate);
#else
    NO_GPU;
#endif
    break;
  }
  default:
    LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
  }
}

//model_filename爲網絡的快照的文件名,將求解器的狀態保存在快照文件中
template <typename Dtype>
void SGDSolver<Dtype>::SnapshotSolverState(const string& model_filename) {
  switch (this->param_.snapshot_format()) {   //快照的格式,二進制proto類型仍是hdf5類型
    case caffe::SolverParameter_SnapshotFormat_BINARYPROTO:
      SnapshotSolverStateToBinaryProto(model_filename); //將求解器的狀態存爲二進制proto類型文件
      break;
    case caffe::SolverParameter_SnapshotFormat_HDF5:
      SnapshotSolverStateToHDF5(model_filename);    //將求解器的狀態存爲hdf5類型文件
      break;
    default:
      LOG(FATAL) << "Unsupported snapshot format.";
  }
}

//將SGDSolver的狀態存入SolverState消息中,並存爲文件
template <typename Dtype>
void SGDSolver<Dtype>::SnapshotSolverStateToBinaryProto(const string& model_filename) {
  SolverState state;
  state.set_iter(this->iter_);                  //將當前的迭代次數存入SolverState消息中
  state.set_learned_net(model_filename);        //將網絡的快照文件名存入
  state.set_current_step(this->current_step_);  //存入迭代的階段
  state.clear_history();    //清空歷史數據,SolverState消息中的各個參數的歷史數據均爲BlobProto類型的消息
  for (int i = 0; i < history_.size(); ++i) {
    // Add history
    BlobProto* history_blob = state.add_history();  //增長參數的歷史梯度信息
    history_[i]->ToProto(history_blob); //並將求解器中blob類型history_的數據寫入其中
  }
  string snapshot_filename = Solver<Dtype>::SnapshotFilename(".solverstate"); //生成".solverstate"擴展名的快照狀態文件名
  LOG(INFO) << "Snapshotting solver state to binary proto file " << snapshot_filename;  //打印
  WriteProtoToBinaryFile(state, snapshot_filename.c_str()); //將SolverState消息寫入二進制的proto類型文件中
}

//將SGDSolver的iter_/model_filename/current_step_/history_寫入到hdf5文件中
template <typename Dtype>
void SGDSolver<Dtype>::SnapshotSolverStateToHDF5(const string& model_filename) {
// This code is taken from https://github.com/sh1r0/caffe-android-lib
#ifdef USE_HDF5
  string snapshot_filename = Solver<Dtype>::SnapshotFilename(".solverstate.h5");  //先生成文件名
  LOG(INFO) << "Snapshotting solver state to HDF5 file " << snapshot_filename;
  hid_t file_hid = H5Fcreate(snapshot_filename.c_str(), H5F_ACC_TRUNC, H5P_DEFAULT, H5P_DEFAULT); //建立hdf5文件
  CHECK_GE(file_hid, 0) << "Couldn't open " << snapshot_filename << " to save solver state."; //檢查是否建立成功
  hdf5_save_int(file_hid, "iter", this->iter_); //在file_hid中建立名爲"iter"的整形數據集,並將iter_值寫入其中
  hdf5_save_string(file_hid, "learned_net", model_filename);    //建立"learned_net", 並將model_filename寫入其中
  hdf5_save_int(file_hid, "current_step", this->current_step_); //建立"current_step", 並寫入
  hid_t history_hid = H5Gcreate2(file_hid, "history", H5P_DEFAULT, H5P_DEFAULT, H5P_DEFAULT); //建立"history"組
  CHECK_GE(history_hid, 0) << "Error saving solver state to " << snapshot_filename << ".";
  for (int i = 0; i < history_.size(); ++i) {
    ostringstream oss;
    oss << i;
    hdf5_save_nd_dataset<Dtype>(history_hid, oss.str(), *history_[i]);  //建立Dtype類型的數據集,並將blob中的數據寫入其中
  }
  H5Gclose(history_hid);
  H5Fclose(file_hid);
// This code is taken from https://github.com/sh1r0/caffe-android-lib
#else
  LOG(FATAL) << "SnapshotSolverStateToHDF5 requires hdf5;"
             << " compile with USE_HDF5.";
#endif  // USE_HDF5
}

//從二進制proto文件state_file中讀取求解器的狀態,並存入當前求解器中.若是求解器狀態中還設置了模型參數文件,則還會加載模型參數
template <typename Dtype>
void SGDSolver<Dtype>::RestoreSolverStateFromBinaryProto(
    const string& state_file) {
  SolverState state;
  ReadProtoFromBinaryFile(state_file, &state);  //從state_file文件中讀取消息到state中
  this->iter_ = state.iter();     //使用state中的值設置當前的求解器
  if (state.has_learned_net()) {  //若是設置了模型參數文件的路徑
    NetParameter net_param;
    ReadNetParamsFromBinaryFileOrDie(state.learned_net().c_str(), &net_param);  //從文件中讀取網絡參數
    this->net_->CopyTrainedLayersFrom(net_param); //數據拷貝至當前網絡中
  }
  this->current_step_ = state.current_step();   //設置
  CHECK_EQ(state.history_size(), history_.size())
      << "Incorrect length of history blobs.";  //檢查state中歷史數據的個數與當前求解器中歷史數據的個數是否匹配
  LOG(INFO) << "SGDSolver: restoring history";
  for (int i = 0; i < history_.size(); ++i) {
    history_[i]->FromProto(state.history(i));   //從state中拷貝歷史梯度數據至當前求解器中
  }
}

//從hdf5文件state_file中讀取求解器的狀態,並存入當前求解器中.若是求解器狀態中還設置了模型參數文件,則還會加載模型參數
template <typename Dtype>
void SGDSolver<Dtype>::RestoreSolverStateFromHDF5(const string& state_file) {
#ifdef USE_HDF5
  hid_t file_hid = H5Fopen(state_file.c_str(), H5F_ACC_RDONLY, H5P_DEFAULT);  //打開文件
  CHECK_GE(file_hid, 0) << "Couldn't open solver state file " << state_file;  //檢查操做是否成功
  this->iter_ = hdf5_load_int(file_hid, "iter");    //從file_hid中讀取"iter"數據集中的整數,存入iter_中
  if (H5LTfind_dataset(file_hid, "learned_net")) {  //判斷file_hid中是否存在名爲"learned_net"的數據集
    //讀取該數據集中的字符串,"learned_net"中存放的是網絡模型參數文件的文件名(**.caffemodel.h5)
    string learned_net = hdf5_load_string(file_hid, "learned_net");
    this->net_->CopyTrainedLayersFrom(learned_net); //讀取模型參數文件,加載網絡參數
  }
  this->current_step_ = hdf5_load_int(file_hid, "current_step");  //讀取"current_step"中的值
  hid_t history_hid = H5Gopen2(file_hid, "history", H5P_DEFAULT); //打開"history"數據集
  CHECK_GE(history_hid, 0) << "Error reading history from " << state_file;
  int state_history_size = hdf5_get_num_links(history_hid);   //獲取其中links(元素)的個數
  CHECK_EQ(state_history_size, history_.size())
      << "Incorrect length of history blobs.";  //一樣檢查是否與當前求解器中的history_匹配
  for (int i = 0; i < history_.size(); ++i) {
    ostringstream oss;
    oss << i;
    hdf5_load_nd_dataset<Dtype>(history_hid, oss.str().c_str(), 0,
                                kMaxBlobAxes, history_[i].get()); //從history_hid中讀取數據,存入history_[i]中
  }
  H5Gclose(history_hid);
  H5Fclose(file_hid);
#else
  LOG(FATAL) << "RestoreSolverStateFromHDF5 requires hdf5;"
             << " compile with USE_HDF5.";
#endif  // USE_HDF5
}

不一樣的學習策略

# 參照代碼中的GetLearningRate()函數,用Python簡單實現了下不一樣學習率策略的效果,方便有個直觀的瞭解
import numpy as np
from math import exp
import matplotlib.pyplot as plt

base_lr = 0.01
max_iter = np.arange(3000)

def fixed(iter):
    return base_lr

def step(iter):
    step_size = 500
    gamma = 0.7
    current_step = int(iter / step_size)
    return base_lr * pow(gamma, current_step)

def exp_policy(iter):
    gamma = 0.99
    return base_lr * pow(gamma, iter)

def inv(iter):
    gamma = 0.001
    power = 0.75
    return base_lr * pow(1 + gamma * iter, -power)

class multistep(object):
    gamma = 0.7
    stepvalue = np.array([200, 800, 1500, 2300])
    multistep_current_step = 0
    def rate(self, iter):
        if (self.multistep_current_step < self.stepvalue.shape[0] and
            iter >= self.stepvalue[self.multistep_current_step]):
            self.multistep_current_step += 1
        return base_lr * pow(self.gamma, self.multistep_current_step)

def poly(iter):
    power = 2
    return base_lr * pow(1 - iter / max_iter.shape[0], power)

def sigmoid(iter):
    gamma = -0.01
    step_size = 1500
    return base_lr * (1 / (1 + exp(-gamma * (iter - step_size))))


rate_fixed = np.array([fixed(iter) for iter in max_iter])
rate_step = np.array([step(iter) for iter in max_iter])
rate_exp_policy = np.array([exp_policy(iter) for iter in max_iter])
rate_inv = np.array([inv(iter) for iter in max_iter])
mltstp = multistep()
rate_multistep = np.array([mltstp.rate(iter) for iter in max_iter])
rate_poly = np.array([poly(iter) for iter in max_iter])
rate_sigmoid = np.array([sigmoid(iter) for iter in max_iter])


plt.figure(1)
ax1 = plt.subplot(3, 3, 1)
ax2 = plt.subplot(3, 3, 2)
ax3 = plt.subplot(3, 3, 3)
ax4 = plt.subplot(3, 3, 4)
ax5 = plt.subplot(3, 3, 5)
ax6 = plt.subplot(3, 3, 6)
ax7 = plt.subplot(3, 3, 7)

plt.sca(ax1)
ax1.set_title('fixed')
plt.plot(max_iter, rate_fixed)
plt.sca(ax2)
ax2.set_title('step')
plt.plot(max_iter, rate_step)
plt.sca(ax3)
ax3.set_title('exp')
plt.plot(max_iter, rate_exp_policy)
plt.sca(ax4)
ax4.set_title('inv')
plt.plot(max_iter, rate_inv)
plt.sca(ax5)
ax5.set_title('multistep')
plt.plot(max_iter, rate_multistep)
plt.sca(ax6)
ax6.set_title('poly')
plt.plot(max_iter, rate_poly)
plt.sca(ax7)
ax7.set_title('sigmoid')
plt.plot(max_iter, rate_sigmoid)
plt.show()

Caffe的源碼筆者是第一次閱讀,一邊閱讀一邊記錄,對代碼的理解和分析可能會存在錯誤或遺漏,但願各位讀者批評指正,謝謝支持!android

相關文章
相關標籤/搜索