標籤: BERT 訓練 部署html
在羣裏看到許多朋友在使用BERT模型,網上多數文章只提到了模型的訓練方法,後面的生產部署及調用並無說明。 這段時間使用BERT模型完成了從數據準備到生產部署的全流程,在這裏整理出來,方便你們參考。python
在下面我將以一個「手機評論的情感分類」爲例子,簡要說明從訓練到部署的所有流程。最終完成後可使用一個網頁進行交互,實時地對輸入的評論語句進行分類判斷。git
基本架構爲:github
graph LR A(BERT模型服務端) --> B(API服務端) B-->A B --> C(應用端) C-->B
+-------------------+ | 應用端(HTML) | +-------------------+ ^^ || VV +-------------------+ | API服務端 | +-------------------+ ^^ || VV +-------------------+ | BERT模型服務端 | +-------------------+
架構說明: BERT模型服務端 加載模型,進行實時預測的服務; 使用的是 BERT-BiLSTM-CRF-NER ajax
API服務端 調用實時預測服務,爲應用提供API接口的服務,用flask編寫; chrome
應用端 最終的應用端; 我這裏使用一個HTML網頁來實現;shell
本項目完整源碼地址:BERT從訓練到部署git源碼 項目博客地址: BERT從訓練到部署json
附件: 本例中訓練完成的模型文件.ckpt格式及.pb格式文件,因爲比較大,已放到網盤提供下載:flask
連接:https://pan.baidu.com/s/1DgVjRK7zicbTlAAkFp7nWw 提取碼:8iaw
若是你想跳過前面模型的訓練過程,能夠直接使用訓練好的模型,來完成後面的部署。api
主要包括如下關鍵節點:
這裏用的數據是手機的評論,數據比較簡單,三個分類: -1,0,1 表示負面,中性與正面情感 數據格式以下:
1 手機很好,漂亮時尚,贈品通常 1 手機很好。包裝也很完美,贈品也是收到貨後立刻就發貨了 1 第一次在第三方買的手機 開始很擔憂 不過查一下是正品 很滿意 1 很不錯 續航好 系統流暢 1 不知道真假,相信店家吧 1 快遞挺快的,榮耀10手感仍是不錯的,玩了會王者還不錯,就是先後玻璃, 1 流很快,手機到手感受很酷,白色適合女士,很驚豔!常好,運行速度快,流暢! 1 用了一天才來評價,都還能夠,很滿意 1 幻影藍很好看啊,炫彩系列時尚時尚最時尚,速度快,配送運行?作活動優惠買的,開心? 1 快遞速度快,很贊!軟件更新到最新版。安裝上軟膠保護套拿手上不容易滑落。 0 手機出廠貼膜好薄啊,感受像塑料膜。其餘不能發表 0 用了一段時間,除了手機續航其它還不錯。 0 作工通常 1 挺好的,贊一個,手機很好,很喜歡 0 手機還行,可是手機剛開箱時屏幕和背面有不少指紋痕跡,手機殼跟**在地上磨過似的,好幾條印子。要不是看在能把這些痕跡擦掉,和閒退貨麻煩,就給退了。就不能規規矩矩作生意麼。還有送的都是什麼吊東西,運動手環垃圾一比,貼在手機後面的固定手環還**是塑料的渡了一層銀色,耳機也和圖片描述不符,碎屏險已經註冊,不知道怎麼樣。講真的,要不就別送或者少送,要不,就規規矩矩的,否則到最後還讓人以爲不舒服。其餘沒什麼。 -1 手機總體還能夠,拍照也很清楚,也很流暢支持華爲。給一星是由於有缺陷,送的耳機是壞的!評論區好評太多,須要一些差評來提醒下,之後更加註意細節,提高質量。 0 前天剛買的, 看着還行, 指紋解鎖反應不錯。 1 高端大氣上檔次。 -1 各位小主,注意啦,耳機是沒有的,須要單獨買 0 外觀不錯,感受很耗電啊,在使用段時間評價 1 手機很是好,很好用 -1 沒有發票,圖片與實物不一致 1 習慣在京東採購物品,方便快捷,及時開發票進行報銷,配送員服務也很周到!就是手機收到時沒有電,感受不大正常 1 高端大氣上檔次啊!看電影玩遊戲估計很爽!屏幕夠大!
數據總共8097條,按6:2:2的比例拆分紅train.tsv,test.tsv ,dev.tsv三個數據文件
訓練模型就直接使用BERT的分類方法,把原來的run_classifier.py
複製出來並修改成 run_mobile.py
。關於訓練的代碼網上不少,就不展開說明了,主要有如下方法:
#----------------------------------------- #手機評論情感分類數據處理 2019/3/12 #labels: -1負面 0中性 1正面 class SetimentProcessor(DataProcessor): def get_train_examples(self, data_dir): """See base class.""" return self._create_examples( self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") def get_dev_examples(self, data_dir): """See base class.""" return self._create_examples( self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") def get_test_examples(self, data_dir): """See base class.""" return self._create_examples( self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") def get_labels(self): """See base class.""" """ if not os.path.exists(os.path.join(FLAGS.output_dir, 'label_list.pkl')): with codecs.open(os.path.join(FLAGS.output_dir, 'label_list.pkl'), 'wb') as fd: pickle.dump(self.labels, fd) """ return ["-1", "0", "1"] def _create_examples(self, lines, set_type): """Creates examples for the training and dev sets.""" examples = [] for (i, line) in enumerate(lines): if i == 0: continue guid = "%s-%s" % (set_type, i) #debug (by xmxoxo) #print("read line: No.%d" % i) text_a = tokenization.convert_to_unicode(line[1]) if set_type == "test": label = "0" else: label = tokenization.convert_to_unicode(line[0]) examples.append( InputExample(guid=guid, text_a=text_a, label=label)) return examples #-----------------------------------------
而後添加一個方法:
processors = { "cola": ColaProcessor, "mnli": MnliProcessor, "mrpc": MrpcProcessor, "xnli": XnliProcessor, "setiment": SetimentProcessor, #2019/3/27 add by Echo }
特別說明,這裏有一點要注意,在後期部署的時候,須要一個label2id的字典,因此要在訓練的時候就保存起來,在convert_single_example
這個方法裏增長一段:
#--- save label2id.pkl --- #在這裏輸出label2id.pkl , add by xmxoxo 2019/2/27 output_label2id_file = os.path.join(FLAGS.output_dir, "label2id.pkl") if not os.path.exists(output_label2id_file): with open(output_label2id_file,'wb') as w: pickle.dump(label_map,w) #--- Add end ---
這樣訓練後就會生成這個文件了。
使用如下命令訓練模型,目錄參數請根據各自的狀況修改:
cd /mnt/sda1/transdat/bert-demo/bert/ export BERT_BASE_DIR=/mnt/sda1/transdat/bert-demo/bert/chinese_L-12_H-768_A-12 export GLUE_DIR=/mnt/sda1/transdat/bert-demo/bert/data export TRAINED_CLASSIFIER=/mnt/sda1/transdat/bert-demo/bert/output export EXP_NAME=mobile_0 sudo python run_mobile.py \ --task_name=setiment \ --do_train=true \ --do_eval=true \ --data_dir=$GLUE_DIR/$EXP_NAME \ --vocab_file=$BERT_BASE_DIR/vocab.txt \ --bert_config_file=$BERT_BASE_DIR/bert_config.json \ --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \ --max_seq_length=128 \ --train_batch_size=32 \ --learning_rate=2e-5 \ --num_train_epochs=5.0 \ --output_dir=$TRAINED_CLASSIFIER/$EXP_NAME
因爲數據比較小,訓練是比較快的,訓練完成後,能夠在輸出目錄獲得模型文件,這裏的模型文件格式是.ckpt的。 訓練結果:
eval_accuracy = 0.861643 eval_f1 = 0.9536328 eval_loss = 0.56324786 eval_precision = 0.9491279 eval_recall = 0.9581805 global_step = 759 loss = 0.5615213
可使用如下語句來進行預測:
sudo python run_mobile.py \ --task_name=setiment \ --do_predict=true \ --data_dir=$GLUE_DIR/$EXP_NAME \ --vocab_file=$BERT_BASE_DIR/vocab.txt \ --bert_config_file=$BERT_BASE_DIR/bert_config.json \ --init_checkpoint=$TRAINED_CLASSIFIER/$EXP_NAME \ --max_seq_length=128 \ --output_dir=$TRAINED_CLASSIFIER/$EXP_NAME
到這裏咱們已經訓練獲得了模型,但這個模型是.ckpt的文件格式,文件比較大,而且有三個文件:
-rw-r--r-- 1 root root 1227239468 Apr 15 17:46 model.ckpt-759.data-00000-of-00001 -rw-r--r-- 1 root root 22717 Apr 15 17:46 model.ckpt-759.index -rw-r--r-- 1 root root 3948381 Apr 15 17:46 model.ckpt-759.meta
能夠看到,模板文件很是大,大約有1.17G。 後面使用的模型服務端,使用的是.pb格式的模型文件,因此須要把生成的ckpt格式模型文件轉換成.pb格式的模型文件。 我這裏提供了一個轉換工具:freeze_graph.py
,使用以下:
usage: freeze_graph.py [-h] -bert_model_dir BERT_MODEL_DIR -model_dir MODEL_DIR [-model_pb_dir MODEL_PB_DIR] [-max_seq_len MAX_SEQ_LEN] [-num_labels NUM_LABELS] [-verbose]
這裏要注意的參數是:
model_dir
就是訓練好的.ckpt文件所在的目錄max_seq_len
要與原來一致;num_labels
是分類標籤的個數,本例中是3個python freeze_graph.py \ -bert_model_dir $BERT_BASE_DIR \ -model_dir $TRAINED_CLASSIFIER/$EXP_NAME \ -max_seq_len 128 \ -num_labels 3
執行成功後能夠看到在model_dir
目錄會生成一個classification_model.pb
文件。 轉爲.pb格式的模型文件,同時也能夠縮小模型文件的大小,能夠看到轉化後的模型文件大約是390M。
-rw-rw-r-- 1 hexi hexi 409326375 Apr 15 17:58 classification_model.pb
如今能夠安裝服務端了,使用的是 bert-base, 來自於項目BERT-BiLSTM-CRF-NER
, 服務端只是該項目中的一個部分。 項目地址:https://github.com/macanv/BERT-BiLSTM-CRF-NER ,感謝Macanv同窗提供這麼好的項目。
這裏要說明一下,咱們常常會看到bert-as-service 這個項目的介紹,它只能加載BERT的預訓練模型,輸出文本向量化的結果。 而若是要加載fine-turing後的模型,就要用到 bert-base 了,詳請請見: 基於BERT預訓練的中文命名實體識別TensorFlow實現
下載代碼並安裝 :
pip install bert-base==0.0.7 -i https://pypi.python.org/simple
或者
git clone https://github.com/macanv/BERT-BiLSTM-CRF-NER cd BERT-BiLSTM-CRF-NER/ python3 setup.py install
使用 bert-base 有三種運行模式,分別支持三種模型,使用參數-mode
來指定:
之因此要分紅不一樣的運行模式,是由於不一樣模型對輸入內容的預處理是不一樣的,命名實體識別NER是要進行序列標註; 而分類模型只要返回label就能夠了。
安裝完後運行服務,同時指定監聽 HTTP 8091端口,並使用GPU 1來跑;
cd /mnt/sda1/transdat/bert-demo/bert/bert_svr export BERT_BASE_DIR=/mnt/sda1/transdat/bert-demo/bert/chinese_L-12_H-768_A-12 export TRAINED_CLASSIFIER=/mnt/sda1/transdat/bert-demo/bert/output export EXP_NAME=mobile_0 bert-base-serving-start \ -model_dir $TRAINED_CLASSIFIER/$EXP_NAME \ -bert_model_dir $BERT_BASE_DIR \ -model_pb_dir $TRAINED_CLASSIFIER/$EXP_NAME \ -mode CLASS \ -max_seq_len 128 \ -http_port 8091 \ -port 5575 \ -port_out 5576 \ -device_map 1
注意:port 和 port_out 這兩個參數是API調用的端口號, 默認是5555和5556,若是你準備部署多個模型服務實例,那必定要指定本身的端口號,避免衝突。 我這裏是改成: 5575 和 5576
若是報錯沒運行起來,多是有些模塊沒裝上,都是 bert_base/server/http.py裏引用的,裝上就行了:
sudo pip install flask sudo pip install flask_compress sudo pip install flask_cors sudo pip install flask_json
我這裏的配置是2個GTX 1080 Ti,這個時候雙卡的優點終於發揮出來了,GPU 1用於預測,GPU 0還能夠繼續訓練模型。
運行服務後會自動生成不少臨時的目錄和文件,爲了方便管理與啓動,可創建一個工做目錄,並把啓動命令寫成一個shell腳本。 這裏建立的是mobile_svr\bertsvr.sh
,這樣能夠比較方便地設置服務器啓動時自動啓動服務,另外增長了每次啓動時自動清除臨時文件
代碼以下:
#!/bin/bash #chkconfig: 2345 80 90 #description: 啓動BERT分類模型 echo '正在啓動 BERT mobile svr...' cd /mnt/sda1/transdat/bert-demo/bert/mobile_svr sudo rm -rf tmp* export BERT_BASE_DIR=/mnt/sda1/transdat/bert-demo/bert/chinese_L-12_H-768_A-12 export TRAINED_CLASSIFIER=/mnt/sda1/transdat/bert-demo/bert/output export EXP_NAME=mobile_0 bert-base-serving-start \ -model_dir $TRAINED_CLASSIFIER/$EXP_NAME \ -bert_model_dir $BERT_BASE_DIR \ -model_pb_dir $TRAINED_CLASSIFIER/$EXP_NAME \ -mode CLASS \ -max_seq_len 128 \ -http_port 8091 \ -port 5575 \ -port_out 5576 \ -device_map 1
補充說明一下內存的使用狀況: BERT在訓練時須要加載完整的模型數據,要用的內存是比較多的,差很少要10G,我這裏用的是GTX 1080 Ti 11G。 但在訓練完後,按上面的方式部署加載pb模型文件時,就不須要那麼大了,上面也能夠看到pb模型文件就是390M。 其實只要你使用的是BERT base 預訓練模型,最終的獲得的pb文件大小都是差很少的。
還有同窗問到能不能用CPU來部署,我這裏沒嘗試過,但我想確定是能夠的,只是在計算速度上跟GPU會有差異。
我這裏使用GPU 1來實時預測,同時加載了2個BERT模型,截圖以下:
模型服務端部署完成了,可使用curl命令來測試一下它的運行狀況。
curl -X POST http://192.168.15.111:8091/encode \ -H 'content-type: application/json' \ -d '{"id": 111,"texts": ["總的來講,這款手機性價比是特別高的。","槽糕的售後服務!!!店大欺客"], "is_tokenized": false}'
執行結果:
> -H 'content-type: application/json' \ > -d '{"id": 111,"texts": ["總的來講,這款手機性價比是特別高的。","槽糕的售後服務!!!店大欺客"], "is_tokenized": false}' {"id":111,"result":[{"pred_label":["1","-1"],"score":[0.9974544644355774,0.9961422085762024]}],"status":200}
能夠看到對應的兩個評論,預測結果一個是1,另外一個是-1,計算的速度仍是很是很快的。 經過這種方式來調用仍是不太方便,知道了這個通信方式,咱們能夠用flask編寫一個API服務, 爲全部的應用統一提供服務。
爲了方便客戶端的調用,同時也爲了能夠對多個語句進行預測,咱們用flask編寫一個API服務端,使用更簡潔的方式來與客戶端(應用)來通信。 整個API服務端放在獨立目錄/mobile_apisvr/
目錄下。
用flask建立服務端並調用主方法,命令行參數以下:
def main_cli (): pass parser = argparse.ArgumentParser(description='API demo server') parser.add_argument('-ip', type=str, default="0.0.0.0", help='chinese google bert model serving') parser.add_argument('-port', type=int, default=8910, help='listen port,default:8910') args = parser.parse_args() flask_server(args)
主方法裏建立APP對象:
app.run( host = args.ip, #'0.0.0.0', port = args.port, #8910, debug = True )
這裏的接口簡單規劃爲/api/v0.1/query
, 使用POST方法,參數名爲'text',使用JSON返回結果; 路由配置:
@app.route('/api/v0.1/query', methods=['POST'])
API服務端的核心方法,是與BERT-Serving進行通信,須要建立一個客戶端BertClient:
#對句子進行預測識別 def class_pred(list_text): #文本拆分紅句子 #list_text = cut_sent(text) print("total setance: %d" % (len(list_text)) ) with BertClient(ip='192.168.15.111', port=5575, port_out=5576, show_server_config=False, check_version=False, check_length=False,timeout=10000 , mode='CLASS') as bc: start_t = time.perf_counter() rst = bc.encode(list_text) print('result:', rst) print('time used:{}'.format(time.perf_counter() - start_t)) #返回結構爲: # rst: [{'pred_label': ['0', '1', '0'], 'score': [0.9983683228492737, 0.9988993406295776, 0.9997349381446838]}] #抽取出標註結果 pred_label = rst[0]["pred_label"] result_txt = [ [pred_label[i],list_text[i] ] for i in range(len(pred_label))] return result_txt
注意:這裏的IP,端口要與服務端的對應。
運行API 服務端:
python api_service.py
在代碼中的debug設置爲True,這樣只要更新文件,服務就會自動從新啓動,比較方便調試。 運行截圖以下:
到這一步也可使用curl或者其它工具進行測試,也能夠等完成網頁客戶端後一併調試。 我這裏使用chrome插件 API-debug來進行測試,以下圖:
這裏使用一個HTML頁面來模擬客戶端,在實際項目中多是具體的應用。 爲了方便演示就把網頁模板與API服務端合併在一塊兒了,在網頁端使用AJAX來與API服務端通信。
建立模板目錄templates
,使用模板來加載一個HTML,模板文件名爲index.html
。 在HTML頁面裏使用AJAX來調用接口,因爲是在同一個服務器,同一個端口,地址直接用/api/v0.1/query
就能夠了, 在實際項目中,客戶應用端與API是分開的,則須要指定接口URL地址,同時還要注意數據安全性。 代碼以下:
function UrlPOST(txt,myfun){ if (txt=="") { return "error parm"; } var httpurl = "/api/v0.1/query"; $.ajax({ type: "POST", data: "text="+txt, url: httpurl, //async:false, success: function(data) { myfun(data); } }); }
啓動API服務端後,可使用IP+端口
來訪問了,這裏的地址是http://192.168.15.111:8910/
運行界面截圖以下:
能夠看到請求的用時時間爲37ms,速度仍是很快的,固然這個速度跟硬件配置有關。
歡迎批評指正,聯繫郵箱(xmxoxo@qq.com)