https://blog.csdn.net/caroline_wendy/article/details/80494120算法
Gluon是MXNet的高層封裝,網絡設計簡單易用,與Keras相似。隨着深度學習技術的普及,相似於Gluon這種,高層封裝的深度學習框架,被愈來愈多的開發者接受和使用。json
在開發深度學習算法時,必然會涉及到網絡(symbol)和參數(params)的存儲與加載,Gluon模型的存取接口,與MXNet略有不一樣。在MXNet體系中,網絡與參數是分離的,這樣的設計,有利於遷移學習(Transfer Learning)中的參數複用。網絡
本文分別介紹MXNet和Gluon中網絡和參數的存取方式。框架
在MXNet體系中,net = symbol + params。
本文地址:https://blog.csdn.net/caroline_wendy/article/details/80494120函數
MXNet學習
MXNet中網絡和參數是分離的,這兩部分須要分別存儲和讀取。.net
網絡設計
MXNet的網絡(symbol)使用json格式存儲:orm
建立填充變量data,即mx.sym.var('data');
將填充變量置入網絡,即net_triplet(vd);
獲取填充以後的網絡結構,轉換爲json對象,即vnet.tojson();
將json對象寫入文件,即write_line(json_file, sym_json)。
則,最終的json文件就是MXNet的網絡結構。對象
實現:
vd = mx.sym.var('data')
vnet = net_triplet(vd)
sym_json = vnet.tojson()
json_file = os.path.join(ROOT_DIR, 'experiments', 'sym.json')
write_line(json_file, sym_json)
1
2
3
4
5
或
sym_json = net_triplet(mx.sym.var('data')).tojson()
json_file = os.path.join(ROOT_DIR, 'experiments', 'sym.json')
write_line(json_file, sym_json)
1
2
3
這種存儲網絡的方式,同時適用於MXNet和Gluon網絡。
參數
MXNet的參數(params)存儲比較簡單:
在訓練過程當中,自動調整網絡的參數;
在訓練過程當中,調用網絡的save_params()函數,便可保存參數。
在參數的文件名中,加入epoch和準確率,有利於參數選擇。
實現:
params_path = os.path.join(
ROOT_DIR, self.config.cp_dir,
"triplet_loss_model_%s_%s.params" % (epoch, '%0.4f' % dist_acc)
) # 模型文件位置
self.model.save_params(params_path) # 存儲模型
1
2
3
4
5
讀取
MXNet網絡和參數的加載方式:
網絡:調用SymbolBlock()建立網絡,output是已加載的Json結構,input是填充的data變量;
參數:調用load_params()加載參數,params是參數路徑,ctx是上下文,即CPU或GPU環境。
實現:
sym = os.path.join(ROOT_DIR, self.config.cp_dir, "sym.json")
params = os.path.join(ROOT_DIR, self.config.cp_dir, "triplet_loss_model_88_0.9934.params")
self.model = gluon.nn.SymbolBlock(outputs=mx.sym.load(sym), inputs=mx.sym.var('data'))
self.model.load_params(params, ctx=ctx)
1
2
3
4
Gluon
Gluon對比與MXNet,提供更加高層的存取方法,簡單高效。
存儲
除了MXNet的存儲方式以外,Gluon網絡提供特定的export()方法,同時支持導出網絡和參數:
輸入:path是文件前綴;epoch是epoch數,支持訓練中屢次保存;
輸出:[前綴]-symbol.json的網絡;[前綴]-[epoch 4位].params的參數;
實現:
symbol_file = os.path.join(ROOT_DIR, self.config.cp_dir, 'triplet-net')
self.model.export(path=symbol_file, epoch=epoch) # gluon的export
1
2
注意:export()方法只能位於訓練階段,不能位於設計階段。
讀取
Gluon支持經過文件前綴(即export()的輸出)的方式,加載網絡與參數:
load_checkpoint(),讀取前綴數據:
輸入:prefix是前綴,epoch是epoch數;
輸出:sym是網絡,arg_params是權重參數,aux_params是輔助狀態;
SymbolBlock(),設置網絡結構,與MXNet相似:
outputs:已加載的Json結構;
inputs:填充的data變量;
設置collect_params()參數,區分:
權重參數,arg_params;
輔助狀態,net_params;
當加載完成網絡和參數以後,就完成了Gluon模型的建立。
實現:
prefix = os.path.join(ROOT_DIR, self.config.cp_dir, "triplet-net") # export導出的前綴
sym, arg_params, aux_params = mx.model.load_checkpoint(prefix=prefix, epoch=5)
net = gluon.nn.SymbolBlock(outputs=sym, inputs=mx.sym.var('data')) # 加載網絡結構
# 設置網絡參數
net_params = net.collect_params()
for param in arg_params:
if param in net_params:
net_params[param]._load_init(arg_params[param], ctx=ctx)
for param in aux_params:
if param in net_params:
net_params[param]._load_init(aux_params[param], ctx=ctx)
1
2
3
4
5
6
7
8
9
10
11
錯誤
當出現以下錯誤時,即表示網絡與參數的前綴不一致:
AssertionError: Parameter 'net_conv0_weight' is missing in file 'xxxx.params',
which contains parameters: 'dense0_bias', ..., 'batchnorm2_gamma'.
Please make sure source and target networks have the same prefix.
1
2
3
也就是網絡中的單元名稱與參數中的單元名稱不一樣,前綴不一樣。
解決方案:按照參數中的前綴,統一設置prefix便可,沒有前綴則設置爲空字符串,如:
net_triplet = HybridSequential(prefix='')
1
由於,參數訓練較慢,而網絡容易修改,所以,優先修改網絡的參數名稱。
MXNet網絡的存取方式,也能夠用於Gluon網絡,即Gluon是兼容MXNet的。在MXNet的基礎上,Gluon還在不斷地迭代和完善中,期待更多簡潔的接口,下降深度學習的開發門檻,All with AI。