BERT模型從訓練到部署

BERT模型從訓練到部署全流程

標籤: BERT 訓練 部署html

緣起

在羣裏看到許多朋友在使用BERT模型,網上多數文章只提到了模型的訓練方法,後面的生產部署及調用並無說明。 這段時間使用BERT模型完成了從數據準備到生產部署的全流程,在這裏整理出來,方便你們參考。python

在下面我將以一個「手機評論的情感分類」爲例子,簡要說明從訓練到部署的所有流程。最終完成後可使用一個網頁進行交互,實時地對輸入的評論語句進行分類判斷。git

基本架構

基本架構爲:github

graph LR
A(BERT模型服務端) --> B(API服務端)
B-->A
B --> C(應用端)
C-->B
+-------------------+
|   應用端(HTML)     | 
+-------------------+
         ^^
         ||
         VV
+-------------------+
|     API服務端      | 
+-------------------+
         ^^
         ||
         VV
+-------------------+
|  BERT模型服務端    | 
+-------------------+

架構說明: BERT模型服務端     加載模型,進行實時預測的服務;     使用的是 BERT-BiLSTM-CRF-NER ajax

API服務端      調用實時預測服務,爲應用提供API接口的服務,用flask編寫; chrome

應用端     最終的應用端;     我這裏使用一個HTML網頁來實現;shell

本項目完整源碼地址:BERT從訓練到部署git源碼 項目博客地址: BERT從訓練到部署json

附件: 本例中訓練完成的模型文件.ckpt格式及.pb格式文件,因爲比較大,已放到網盤提供下載:flask

連接:https://pan.baidu.com/s/1DgVjRK7zicbTlAAkFp7nWw 
提取碼:8iaw 

若是你想跳過前面模型的訓練過程,能夠直接使用訓練好的模型,來完成後面的部署。api

關鍵節點

主要包括如下關鍵節點:

  • 數據準備
  • 模型訓練
  • 模型格式轉化
  • 服務端部署與啓動
  • API服務編寫與部署
  • 客戶端(網頁端的編寫與部署)

數據準備

這裏用的數據是手機的評論,數據比較簡單,三個分類: -1,0,1 表示負面,中性與正面情感 數據格式以下:

1    手機很好,漂亮時尚,贈品通常
1    手機很好。包裝也很完美,贈品也是收到貨後立刻就發貨了
1    第一次在第三方買的手機 開始很擔憂 不過查一下是正品 很滿意
1    很不錯 續航好 系統流暢
1    不知道真假,相信店家吧
1    快遞挺快的,榮耀10手感仍是不錯的,玩了會王者還不錯,就是先後玻璃,
1    流很快,手機到手感受很酷,白色適合女士,很驚豔!常好,運行速度快,流暢!
1    用了一天才來評價,都還能夠,很滿意
1    幻影藍很好看啊,炫彩系列時尚時尚最時尚,速度快,配送運行?作活動優惠買的,開心?
1    快遞速度快,很贊!軟件更新到最新版。安裝上軟膠保護套拿手上不容易滑落。
0    手機出廠貼膜好薄啊,感受像塑料膜。其餘不能發表
0    用了一段時間,除了手機續航其它還不錯。
0    作工通常
1    挺好的,贊一個,手機很好,很喜歡
0    手機還行,可是手機剛開箱時屏幕和背面有不少指紋痕跡,手機殼跟**在地上磨過似的,好幾條印子。要不是看在能把這些痕跡擦掉,和閒退貨麻煩,就給退了。就不能規規矩矩作生意麼。還有送的都是什麼吊東西,運動手環垃圾一比,貼在手機後面的固定手環還**是塑料的渡了一層銀色,耳機也和圖片描述不符,碎屏險已經註冊,不知道怎麼樣。講真的,要不就別送或者少送,要不,就規規矩矩的,否則到最後還讓人以爲不舒服。其餘沒什麼。
-1    手機總體還能夠,拍照也很清楚,也很流暢支持華爲。給一星是由於有缺陷,送的耳機是壞的!評論區好評太多,須要一些差評來提醒下,之後更加註意細節,提高質量。
0    前天剛買的,  看着還行, 指紋解鎖反應不錯。
1    高端大氣上檔次。
-1    各位小主,注意啦,耳機是沒有的,須要單獨買
0    外觀不錯,感受很耗電啊,在使用段時間評價
1    手機很是好,很好用
-1    沒有發票,圖片與實物不一致
1    習慣在京東採購物品,方便快捷,及時開發票進行報銷,配送員服務也很周到!就是手機收到時沒有電,感受不大正常
1    高端大氣上檔次啊!看電影玩遊戲估計很爽!屏幕夠大!

數據總共8097條,按6:2:2的比例拆分紅train.tsv,test.tsv ,dev.tsv三個數據文件

模型訓練

訓練模型就直接使用BERT的分類方法,把原來的run_classifier.py 複製出來並修改成 run_mobile.py。關於訓練的代碼網上不少,就不展開說明了,主要有如下方法:

#-----------------------------------------
#手機評論情感分類數據處理 2019/3/12 
#labels: -1負面 0中性 1正面
class SetimentProcessor(DataProcessor):
  def get_train_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

  def get_dev_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

  def get_test_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")

  def get_labels(self):
    """See base class."""

    """
    if not os.path.exists(os.path.join(FLAGS.output_dir, 'label_list.pkl')):
        with codecs.open(os.path.join(FLAGS.output_dir, 'label_list.pkl'), 'wb') as fd:
            pickle.dump(self.labels, fd)
    """
    return ["-1", "0", "1"]

  def _create_examples(self, lines, set_type):
    """Creates examples for the training and dev sets."""
    examples = []
    for (i, line) in enumerate(lines):
      if i == 0: 
        continue
      guid = "%s-%s" % (set_type, i)

      #debug (by xmxoxo)
      #print("read line: No.%d" % i)

      text_a = tokenization.convert_to_unicode(line[1])
      if set_type == "test":
        label = "0"
      else:
        label = tokenization.convert_to_unicode(line[0])
      examples.append(
          InputExample(guid=guid, text_a=text_a, label=label))
    return examples
#-----------------------------------------

而後添加一個方法:

  processors = {
      "cola": ColaProcessor,
      "mnli": MnliProcessor,
      "mrpc": MrpcProcessor,
      "xnli": XnliProcessor,
      "setiment": SetimentProcessor, #2019/3/27 add by Echo
  }

特別說明,這裏有一點要注意,在後期部署的時候,須要一個label2id的字典,因此要在訓練的時候就保存起來,在convert_single_example這個方法裏增長一段:

  #--- save label2id.pkl ---
  #在這裏輸出label2id.pkl , add by xmxoxo 2019/2/27
  output_label2id_file = os.path.join(FLAGS.output_dir, "label2id.pkl")
  if not os.path.exists(output_label2id_file):
    with open(output_label2id_file,'wb') as w:
      pickle.dump(label_map,w)

  #--- Add end ---

這樣訓練後就會生成這個文件了。

使用如下命令訓練模型,目錄參數請根據各自的狀況修改:

cd /mnt/sda1/transdat/bert-demo/bert/
export BERT_BASE_DIR=/mnt/sda1/transdat/bert-demo/bert/chinese_L-12_H-768_A-12
export GLUE_DIR=/mnt/sda1/transdat/bert-demo/bert/data
export TRAINED_CLASSIFIER=/mnt/sda1/transdat/bert-demo/bert/output
export EXP_NAME=mobile_0

sudo python run_mobile.py \
  --task_name=setiment \
  --do_train=true \
  --do_eval=true \
  --data_dir=$GLUE_DIR/$EXP_NAME \
  --vocab_file=$BERT_BASE_DIR/vocab.txt \
  --bert_config_file=$BERT_BASE_DIR/bert_config.json \
  --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
  --max_seq_length=128 \
  --train_batch_size=32 \
  --learning_rate=2e-5 \
  --num_train_epochs=5.0 \
  --output_dir=$TRAINED_CLASSIFIER/$EXP_NAME

因爲數據比較小,訓練是比較快的,訓練完成後,能夠在輸出目錄獲得模型文件,這裏的模型文件格式是.ckpt的。 訓練結果:

eval_accuracy = 0.861643
eval_f1 = 0.9536328
eval_loss = 0.56324786
eval_precision = 0.9491279
eval_recall = 0.9581805
global_step = 759
loss = 0.5615213

可使用如下語句來進行預測:

sudo python run_mobile.py \
  --task_name=setiment \
  --do_predict=true \
  --data_dir=$GLUE_DIR/$EXP_NAME \
  --vocab_file=$BERT_BASE_DIR/vocab.txt \
  --bert_config_file=$BERT_BASE_DIR/bert_config.json \
  --init_checkpoint=$TRAINED_CLASSIFIER/$EXP_NAME \
  --max_seq_length=128 \
  --output_dir=$TRAINED_CLASSIFIER/$EXP_NAME

模型格式轉化

到這裏咱們已經訓練獲得了模型,但這個模型是.ckpt的文件格式,文件比較大,而且有三個文件:

-rw-r--r-- 1 root root 1227239468 Apr 15 17:46 model.ckpt-759.data-00000-of-00001
-rw-r--r-- 1 root root      22717 Apr 15 17:46 model.ckpt-759.index
-rw-r--r-- 1 root root    3948381 Apr 15 17:46 model.ckpt-759.meta

能夠看到,模板文件很是大,大約有1.17G。 後面使用的模型服務端,使用的是.pb格式的模型文件,因此須要把生成的ckpt格式模型文件轉換成.pb格式的模型文件。 我這裏提供了一個轉換工具:freeze_graph.py,使用以下:

usage: freeze_graph.py [-h] -bert_model_dir BERT_MODEL_DIR -model_dir
                       MODEL_DIR [-model_pb_dir MODEL_PB_DIR]
                       [-max_seq_len MAX_SEQ_LEN] [-num_labels NUM_LABELS]
                       [-verbose]

這裏要注意的參數是:

  • model_dir 就是訓練好的.ckpt文件所在的目錄
  • max_seq_len 要與原來一致;
  • num_labels 是分類標籤的個數,本例中是3個
python freeze_graph.py \
    -bert_model_dir $BERT_BASE_DIR \
    -model_dir $TRAINED_CLASSIFIER/$EXP_NAME \
    -max_seq_len 128 \
    -num_labels 3

執行成功後能夠看到在model_dir目錄會生成一個classification_model.pb 文件。 轉爲.pb格式的模型文件,同時也能夠縮小模型文件的大小,能夠看到轉化後的模型文件大約是390M。

-rw-rw-r-- 1 hexi hexi 409326375 Apr 15 17:58 classification_model.pb

服務端部署與啓動

如今能夠安裝服務端了,使用的是 bert-base, 來自於項目BERT-BiLSTM-CRF-NER, 服務端只是該項目中的一個部分。 項目地址:https://github.com/macanv/BERT-BiLSTM-CRF-NER ,感謝Macanv同窗提供這麼好的項目。

這裏要說明一下,咱們常常會看到bert-as-service 這個項目的介紹,它只能加載BERT的預訓練模型,輸出文本向量化的結果。 而若是要加載fine-turing後的模型,就要用到 bert-base 了,詳請請見: 基於BERT預訓練的中文命名實體識別TensorFlow實現

下載代碼並安裝 :

pip install bert-base==0.0.7 -i https://pypi.python.org/simple

或者 

git clone https://github.com/macanv/BERT-BiLSTM-CRF-NER
cd BERT-BiLSTM-CRF-NER/
python3 setup.py install

使用 bert-base 有三種運行模式,分別支持三種模型,使用參數-mode 來指定:

  • NER      序列標註類型,好比命名實體識別;
  • CLASS    分類模型,就是本文中使用的模型
  • BERT     這個就是跟bert-as-service 同樣的模式了

之因此要分紅不一樣的運行模式,是由於不一樣模型對輸入內容的預處理是不一樣的,命名實體識別NER是要進行序列標註; 而分類模型只要返回label就能夠了。

安裝完後運行服務,同時指定監聽 HTTP 8091端口,並使用GPU 1來跑;

cd /mnt/sda1/transdat/bert-demo/bert/bert_svr

export BERT_BASE_DIR=/mnt/sda1/transdat/bert-demo/bert/chinese_L-12_H-768_A-12
export TRAINED_CLASSIFIER=/mnt/sda1/transdat/bert-demo/bert/output
export EXP_NAME=mobile_0

bert-base-serving-start \
    -model_dir $TRAINED_CLASSIFIER/$EXP_NAME \
    -bert_model_dir $BERT_BASE_DIR \
    -model_pb_dir $TRAINED_CLASSIFIER/$EXP_NAME \
    -mode CLASS \
    -max_seq_len 128 \
    -http_port 8091 \
    -port 5575 \
    -port_out 5576 \
    -device_map 1 

注意:port 和 port_out 這兩個參數是API調用的端口號, 默認是5555和5556,若是你準備部署多個模型服務實例,那必定要指定本身的端口號,避免衝突。 我這裏是改成: 5575 和 5576

若是報錯沒運行起來,多是有些模塊沒裝上,都是 bert_base/server/http.py裏引用的,裝上就行了:

sudo pip install flask 
sudo pip install flask_compress
sudo pip install flask_cors
sudo pip install flask_json

我這裏的配置是2個GTX 1080 Ti,這個時候雙卡的優點終於發揮出來了,GPU 1用於預測,GPU 0還能夠繼續訓練模型。

運行服務後會自動生成不少臨時的目錄和文件,爲了方便管理與啓動,可創建一個工做目錄,並把啓動命令寫成一個shell腳本。 這裏建立的是mobile_svr\bertsvr.sh ,這樣能夠比較方便地設置服務器啓動時自動啓動服務,另外增長了每次啓動時自動清除臨時文件

代碼以下:

#!/bin/bash
#chkconfig: 2345 80 90
#description: 啓動BERT分類模型 

echo '正在啓動 BERT mobile svr...'
cd /mnt/sda1/transdat/bert-demo/bert/mobile_svr
sudo rm -rf tmp*

export BERT_BASE_DIR=/mnt/sda1/transdat/bert-demo/bert/chinese_L-12_H-768_A-12
export TRAINED_CLASSIFIER=/mnt/sda1/transdat/bert-demo/bert/output
export EXP_NAME=mobile_0

bert-base-serving-start \
    -model_dir $TRAINED_CLASSIFIER/$EXP_NAME \
    -bert_model_dir $BERT_BASE_DIR \
    -model_pb_dir $TRAINED_CLASSIFIER/$EXP_NAME \
    -mode CLASS \
    -max_seq_len 128 \
    -http_port 8091 \
    -port 5575 \
    -port_out 5576 \
    -device_map 1 

補充說明一下內存的使用狀況: BERT在訓練時須要加載完整的模型數據,要用的內存是比較多的,差很少要10G,我這裏用的是GTX 1080 Ti 11G。 但在訓練完後,按上面的方式部署加載pb模型文件時,就不須要那麼大了,上面也能夠看到pb模型文件就是390M。 其實只要你使用的是BERT base 預訓練模型,最終的獲得的pb文件大小都是差很少的。

還有同窗問到能不能用CPU來部署,我這裏沒嘗試過,但我想確定是能夠的,只是在計算速度上跟GPU會有差異。

我這裏使用GPU 1來實時預測,同時加載了2個BERT模型,截圖以下:

GPU截圖

端口測試

模型服務端部署完成了,可使用curl命令來測試一下它的運行狀況。

curl -X POST http://192.168.15.111:8091/encode \
  -H 'content-type: application/json' \
  -d '{"id": 111,"texts": ["總的來講,這款手機性價比是特別高的。","槽糕的售後服務!!!店大欺客"], "is_tokenized": false}'

執行結果:

>   -H 'content-type: application/json' \
>   -d '{"id": 111,"texts": ["總的來講,這款手機性價比是特別高的。","槽糕的售後服務!!!店大欺客"], "is_tokenized": false}'
{"id":111,"result":[{"pred_label":["1","-1"],"score":[0.9974544644355774,0.9961422085762024]}],"status":200}

能夠看到對應的兩個評論,預測結果一個是1,另外一個是-1,計算的速度仍是很是很快的。 經過這種方式來調用仍是不太方便,知道了這個通信方式,咱們能夠用flask編寫一個API服務, 爲全部的應用統一提供服務。

API服務編寫與部署

爲了方便客戶端的調用,同時也爲了能夠對多個語句進行預測,咱們用flask編寫一個API服務端,使用更簡潔的方式來與客戶端(應用)來通信。 整個API服務端放在獨立目錄/mobile_apisvr/目錄下。

用flask建立服務端並調用主方法,命令行參數以下:

def main_cli ():
    pass
    parser = argparse.ArgumentParser(description='API demo server')
    parser.add_argument('-ip', type=str, default="0.0.0.0",
                        help='chinese google bert model serving')
    parser.add_argument('-port', type=int, default=8910,
                        help='listen port,default:8910')

    args = parser.parse_args()

    flask_server(args)

主方法裏建立APP對象:

    app.run(
        host = args.ip,     #'0.0.0.0',
        port = args.port,   #8910,  
        debug = True 
    )

這裏的接口簡單規劃爲/api/v0.1/query, 使用POST方法,參數名爲'text',使用JSON返回結果; 路由配置:

@app.route('/api/v0.1/query', methods=['POST'])

API服務端的核心方法,是與BERT-Serving進行通信,須要建立一個客戶端BertClient:

#對句子進行預測識別
def class_pred(list_text):
    #文本拆分紅句子
    #list_text = cut_sent(text)
    print("total setance: %d" % (len(list_text)) )
    with BertClient(ip='192.168.15.111', port=5575, port_out=5576, show_server_config=False, check_version=False, check_length=False,timeout=10000 ,  mode='CLASS') as bc:
        start_t = time.perf_counter()
        rst = bc.encode(list_text)
        print('result:', rst)
        print('time used:{}'.format(time.perf_counter() - start_t))
    #返回結構爲:
    # rst: [{'pred_label': ['0', '1', '0'], 'score': [0.9983683228492737, 0.9988993406295776, 0.9997349381446838]}]
    #抽取出標註結果
    pred_label = rst[0]["pred_label"]
    result_txt = [ [pred_label[i],list_text[i] ] for i in range(len(pred_label))]
    return result_txt

注意:這裏的IP,端口要與服務端的對應。

運行API 服務端:

python api_service.py

在代碼中的debug設置爲True,這樣只要更新文件,服務就會自動從新啓動,比較方便調試。 運行截圖以下:

API服務端

到這一步也可使用curl或者其它工具進行測試,也能夠等完成網頁客戶端後一併調試。 我這裏使用chrome插件 API-debug來進行測試,以下圖:

API測試

客戶端(網頁端)

這裏使用一個HTML頁面來模擬客戶端,在實際項目中多是具體的應用。 爲了方便演示就把網頁模板與API服務端合併在一塊兒了,在網頁端使用AJAX來與API服務端通信。

建立模板目錄templates,使用模板來加載一個HTML,模板文件名爲index.html。 在HTML頁面裏使用AJAX來調用接口,因爲是在同一個服務器,同一個端口,地址直接用/api/v0.1/query就能夠了, 在實際項目中,客戶應用端與API是分開的,則須要指定接口URL地址,同時還要注意數據安全性。 代碼以下:

function UrlPOST(txt,myfun){
    if (txt=="")
    {
        return "error parm"; 
    }
    var httpurl = "/api/v0.1/query"; 
    $.ajax({
            type: "POST",
            data: "text="+txt,
            url: httpurl,
            //async:false,
            success: function(data)
            {   
                myfun(data);
            }
    });
}

啓動API服務端後,可使用IP+端口來訪問了,這裏的地址是http://192.168.15.111:8910/

運行界面截圖以下:

運行界面截圖

能夠看到請求的用時時間爲37ms,速度仍是很快的,固然這個速度跟硬件配置有關。

參考資料:

歡迎批評指正,聯繫郵箱(xmxoxo@qq.com)

相關文章
相關標籤/搜索