dmlc-core是Distributed (Deep) Machine Learning Community的一個基礎模塊,這個模塊用被應用到了mxnet中。dmlc-core在其中用了比軟多的宏技巧,代碼寫得很簡潔,值得你們學習。這博客中講解了其中的宏和mxnet中是怎麼向dmlc-core中註冊函數和初始化參數的。html
C/C++中的宏是編譯的預處理,主要用要文本替換。文本替換就有不少功能,好比用來控制編譯的選項、生成代碼等。在C++沒有被髮明以前,宏的技巧常常會被用於編程中,這些技巧對大部分人來講是難以快速理解的,畢竟代碼是寫給人看的,不是寫給機器看的,因此不少人稱這些爲奇技淫巧。C++出現後,發明了繼承、動態綁定、模板這些現代的面向對象編程概念以後,不少原本用宏技巧寫的代碼被類替換了。但若是宏用得對,可使代碼更加簡潔。python
#define NUM 1024
好比在預處理階段:foo = (int *) malloc (NUM*sizeof(int))
會被替換成foo = (int *) malloc (1024*sizeof(int))
另外,宏體換行須要在行末加反斜槓\
c++
#define ARRAY 1, \ 2, \ 3, \ NUM
好比預處理階段int x[] = { ARRAY }
會被擴展成int x[] = { 1, 2, 3, 1024}
通常狀況下,宏定義所有是大寫字母的,並非說小寫字母不能夠,這只是方便閱讀留下來的習慣,當你們看到全是字母都是大寫時,就會知道,這是一個宏定義。git
#define max(X, Y) ((X) > (Y) ? (X) : (Y))
如在預處理時:a = max(1, 2)
會被擴展成:a = ((1) < (2) ? (1) : (2))
github
#define PRINT(x) \ do{ \ printf("#x = %d \n", x); }\ while(0)
如PRINT(var)
:
會被擴展成:apache
do{ \ printf("var = %d \n", var); }\ while(0)
這種用法能夠用在assert中,能夠直接輸出相關的信息。編程
#define COMMAND(NAME) { #NAME, NAME ## _command } struct command { char *name; void (*function) (void); };
在用到宏的時候的:bash
struct command commands[] = { COMMAND (quit), COMMAND (help), ... };
會被擴展成:數據結構
struct command commands[] = { { "quit", quit_command }, { "help", help_command }, ... };
這樣寫法會比較簡潔,提升了編程的效率。ide
上述的前兩種用法宏的通常用法,後兩種用法則是宏的特殊用法。結果這幾種用法,宏能夠生成不少不少很繞的技巧,好比作遞歸等等。
在上一篇博客——mxnet的訓練過程——從python到C++中提到:「當用C++寫一個新的層時,都要先註冊到mxnet內核dlmc中」。這個註冊就是用宏來實現的,這裏有兩個參考的資料,一個是說了參數的數據結構,只要解讀了parameter.h這個文件,詳見:/dmlc-core/parameter.h;另外一個是說明了參數結構是怎麼工做的Parameter Structure for Machine Learning。這兩個裏面的東西我就不詳細講述了,下面是結合這兩個來講明DMLC-Core宏的工做原理的,對參數結構的描述不如/dmlc-core/parameter.h詳細。全部的代碼來自dmlc-core或者mxnet內的dmlc-core中。
下載並編譯dmlc-core的代碼,編譯出example下載的paramter可執行文件並執行:
git clone https://github.com/dmlc/dmlc-core.git cd dmlc-core make all make example ./example/parameter num_hidden=100 name=aaa activation=relu
執行結果以下:
Docstring --------- num_hidden : int, required Number of hidden unit in the fully connected layer. learning_rate : float, optional, default=0.01 Learning rate of SGD optimization. activation : {'relu', 'sigmoid'}, required Activation function type. name : string, optional, default='mnet' Name of the net. start to set parameters ... ----- param.num_hidden=100 param.learning_rate=0.010000 param.name=aaa param.activation=1
咱們以parameter.cc爲切入點,看DMLC的宏是如何擴展生成代碼的:
struct MyParam : public dmlc::Parameter<MyParam> { float learning_rate; int num_hidden; int activation; std::string name; // declare parameters in header file DMLC_DECLARE_PARAMETER(MyParam) { DMLC_DECLARE_FIELD(num_hidden).set_range(0, 1000) .describe("Number of hidden unit in the fully connected layer."); DMLC_DECLARE_FIELD(learning_rate).set_default(0.01f) .describe("Learning rate of SGD optimization."); DMLC_DECLARE_FIELD(activation).add_enum("relu", 1).add_enum("sigmoid", 2) .describe("Activation function type."); DMLC_DECLARE_FIELD(name).set_default("mnet") .describe("Name of the net."); // user can also set nhidden besides num_hidden DMLC_DECLARE_ALIAS(num_hidden, nhidden); DMLC_DECLARE_ALIAS(activation, act); } }; // register it in cc file DMLC_REGISTER_PARAMETER(MyParam);
先看下DMLC_DECLARE_PARAMETER
的定義,這個定義先聲明瞭一個函數____MANAGER__
,但並無定義,第二個是聲明瞭函數__DECLARE__
,定義在上面代碼的第8到第19行,包括在大括號內。__DECLARE__
這個函數體內也有用到了宏。
#define DMLC_DECLARE_PARAMETER(PType) \ static ::dmlc::parameter::ParamManager *__MANAGER__(); \ inline void __DECLARE__(::dmlc::parameter::ParamManagerSingleton<PType> *manager) \
要注意的DMLC_DECLARE_FIELD
是隻能用在__DECLARE__
這個函數內的宏,這個宏的定義以下,這個宏返回的是一個對象,.set_range
這些返回的也是對象。DMLC_DECLARE_ALIAS
這個是一個對齊的宏,對齊後能夠兩個名字沒有區別,均可以用。好比DMLC_DECLARE_ALIAS(num_hidden, nhidden)
,那麼num_hidden
與nhidden
是同樣的,以前的運行命令就能夠這樣執行:./example/parameter nhidden=100 name=aaa act=relu
,執行的結果沒有任何區別。
#define DMLC_DECLARE_FIELD(FieldName) this->DECLARE(manager, #FieldName, FieldName) #define DMLC_DECLARE_ALIAS(FieldName, AliasName) manager->manager.AddAlias(#FieldName, #AliasName)
相似於DECLARE
這樣的成員函數是定義在父類struct Parameter
中的,以後全部的自義MyParam
都要直接繼承這個父類。AddAlias
這個函數定義在class ParamManager
中,這些函數都在同一個文件parameter.h中。
咱們繼續來看下一個宏DMLC_REGISTER_PARAMETER
,在上一篇博客——mxnet的訓練過程——從python到C++中就提到有一個宏是註冊相關層的到內核中的,這個是註冊到參數到內核中。這個宏的定義如下:
#define DMLC_REGISTER_PARAMETER(PType) \ ::dmlc::parameter::ParamManager *PType::__MANAGER__() { \ static ::dmlc::parameter::ParamManagerSingleton<PType> inst(#PType); \ return &inst.manager; \ } \ static DMLC_ATTRIBUTE_UNUSED ::dmlc::parameter::ParamManager& \ __make__ ## PType ## ParamManager__ = \ (*PType::__MANAGER__()) \
這個宏定義了上面聲明的__MANAGER__
,這個函數新建了一個ParamManagerSingleton
的實例,並返回一個ParamManager
的實例。注意到inst
這個變量是用static
修飾的,也就是說inst
(包括他的成員manager
)只會被初始化一次。而且定義了一個全局的manager
,按上面所說的##鏈接法則,這個變量的名字是__make__MyparamParamManager__
。
新建一個ParamManagerSingleton
的實例時,咱們能夠看到它的構造函數調用了上面用宏生成的函數__DECLARE__
,對它的成員manager
中的成員進行了賦值。
template<typename PType> struct ParamManagerSingleton { ParamManager manager; explicit ParamManagerSingleton(const std::string ¶m_name) { PType param; param.__DECLARE__(this); manager.set_name(param_name); } };
咱們來看下主函數:
int main(int argc, char *argv[]) { if (argc == 1) { printf("Usage: [key=value] ...\n"); return 0; } MyParam param; std::map<std::string, std::string> kwargs; for (int i = 0; i < argc; ++i) { char name[256], val[256]; if (sscanf(argv[i], "%[^=]=%[^\n]", name, val) == 2) { kwargs[name] = val; } } printf("Docstring\n---------\n%s", MyParam::__DOC__().c_str()); printf("start to set parameters ...\n"); param.Init(kwargs); printf("-----\n"); printf("param.num_hidden=%d\n", param.num_hidden); printf("param.learning_rate=%f\n", param.learning_rate); printf("param.name=%s\n", param.name.c_str()); printf("param.activation=%d\n", param.activation); return 0; }
這裏中最主要的就是param.Init(kwargs)
,這個是初始化這個變量,__MANAGER__
返回的正是上面生成的__make__MyparamParamManager__
,而後在RunInit
中對字典遍歷,出現的值就賦到相應的位置上,沒有出現的就用默認值,而後再檢查參數是否合法等,找相應該的位置是經過這個MyParam
的頭地址到相應參數的地址的offset來定位的。
template<typename Container> inline void Init(const Container &kwargs, parameter::ParamInitOption option = parameter::kAllowHidden) { PType::__MANAGER__()->RunInit(static_cast<PType*>(this), kwargs.begin(), kwargs.end(), NULL, option); }
在fully_connected.cc用如下的方法來註冊:
MXNET_REGISTER_OP_PROPERTY(FullyConnected, FullyConnectedProp) .describe(R"code(Applies a linear transformation: :math:`Y = XW^T + b`. If ``flatten`` is set to be true, then the shapes are: - **data**: `(batch_size, x1, x2, ..., xn)` - **weight**: `(num_hidden, x1 * x2 * ... * xn)` - **bias**: `(num_hidden,)` - **out**: `(batch_size, num_hidden)` If ``flatten`` is set to be false, then the shapes are: - **data**: `(x1, x2, ..., xn, input_dim)` - **weight**: `(num_hidden, input_dim)` - **bias**: `(num_hidden,)` - **out**: `(x1, x2, ..., xn, num_hidden)` The learnable parameters include both ``weight`` and ``bias``. If ``no_bias`` is set to be true, then the ``bias`` term is ignored. )code" ADD_FILELINE) .add_argument("data", "NDArray-or-Symbol", "Input data.") .add_argument("weight", "NDArray-or-Symbol", "Weight matrix.") .add_argument("bias", "NDArray-or-Symbol", "Bias parameter.") .add_arguments(FullyConnectedParam::__FIELDS__());
宏定義MXNET_REGISTER_OP_PROPERTY
以下:
#define MXNET_REGISTER_OP_PROPERTY(name, OperatorPropertyType) \ DMLC_REGISTRY_REGISTER(::mxnet::OperatorPropertyReg, OperatorPropertyReg, name) \ .set_body([]() { return new OperatorPropertyType(); }) \ .set_return_type("NDArray-or-Symbol") \ .check_name() #define DMLC_REGISTRY_REGISTER(EntryType, EntryTypeName, Name) \ static DMLC_ATTRIBUTE_UNUSED EntryType & __make_ ## EntryTypeName ## _ ## Name ## __ = \ ::dmlc::Registry<EntryType>::Get()->__REGISTER__(#Name) \
第二個宏的一樣有關鍵字static
,說明註冊只發生一次。咱們只要看一下::dmlc::Registry<EntryType>::Get()->__REGISTER__(#Name)
這個函數,函數Get()
在如下的宏被定義,這個宏在operator.ccDMLC_REGISTRY_ENABLE(::mxnet::OperatorPropertyReg)
運行了。能夠看到這個宏裏一樣有關鍵字static
說明生成的獲得的Registry
是同一個。
#define DMLC_REGISTRY_ENABLE(EntryType) \ template<> \ Registry<EntryType > *Registry<EntryType >::Get() { \ static Registry<EntryType > inst; \ return &inst; \ }
再來看__REGISTER__(#Name)
,這個函數是向獲得的同一個Registry
的成員變量fmap_
寫入名字,並返回一個相關對象。這樣就向內核中註冊了一個函數,能夠看到在上一篇博客——mxnet的訓練過程——從python到C++提到的動態加載函數,就是經過遍歷Registry
中的成員來獲取全部的函數。
inline EntryType &__REGISTER__(const std::string& name) { CHECK_EQ(fmap_.count(name), 0U) << name << " already registered"; EntryType *e = new EntryType(); e->name = name; fmap_[name] = e; const_list_.push_back(e); entry_list_.push_back(e); return *e; }
【防止爬蟲轉載而致使的格式問題——連接】:
http://www.cnblogs.com/heguanyou/p/7613191.html