在文章NLP(十五)讓模型來告訴你文本中的時間中,咱們已經學會了如何利用kashgari模塊來完成序列標註模型的訓練與預測,在本文中,咱們將會了解如何tensorflow-serving來部署模型。
在kashgari的官方文檔中,已經有如何利用tensorflow-serving來部署模型的說明了,網址爲:https://kashgari.bmio.net/advance-use/tensorflow-serving/ 。
下面,本文將介紹tensorflow-serving以及如何利用tensorflow-serving來部署kashgari的模型。python
TensorFlow Serving 是一個用於機器學習模型 serving 的高性能開源庫。它能夠將訓練好的機器學習模型部署到線上,使用 gRPC 做爲接口接受外部調用。更加讓人眼前一亮的是,它支持模型熱更新與自動模型版本管理。這意味着一旦部署 TensorFlow Serving 後,你不再須要爲線上服務操心,只須要關心你的線下模型訓練。
TensorFlow Serving能夠方便咱們部署TensorFlow模型,本文將使用TensorFlow Serving的Docker鏡像來使用TensorFlow Serving,安裝的命令以下:git
docker pull tensorflow/serving
本項目將演示如何利用tensorflow/serving來部署kashgari中的模型,項目結構以下:github
本項目的data來自以前筆者標註的時間數據集,即標註出文本中的時間,採用BIO標註系統。chinese_wwm_ext文件夾爲哈工大的預訓練模型文件。
model_train.py爲模型訓練的代碼,主要功能是完成時間序列標註模型的訓練,完整的代碼以下:web
# -*- coding: utf-8 -*- # time: 2019-09-12 # place: Huangcun Beijing import kashgari from kashgari import utils from kashgari.corpus import DataReader from kashgari.embeddings import BERTEmbedding from kashgari.tasks.labeling import BiLSTM_CRF_Model # 模型訓練 train_x, train_y = DataReader().read_conll_format_file('./data/time.train') valid_x, valid_y = DataReader().read_conll_format_file('./data/time.dev') test_x, test_y = DataReader().read_conll_format_file('./data/time.test') bert_embedding = BERTEmbedding('chinese_wwm_ext_L-12_H-768_A-12', task=kashgari.LABELING, sequence_length=128) model = BiLSTM_CRF_Model(bert_embedding) model.fit(train_x, train_y, valid_x, valid_y, batch_size=16, epochs=1) # Save model utils.convert_to_saved_model(model, model_path='saved_model/time_entity', version=1)
運行該代碼,模型訓練完後會生成saved_model文件夾,裏面含有模型訓練好後的文件,方便咱們利用tensorflow/serving進行部署。接着咱們利用tensorflow/serving來完成模型的部署,命令以下:docker
docker run -t --rm -p 8501:8501 -v "/Users/jclian/PycharmProjects/kashgari_tf_serving/saved_model:/models/" -e MODEL_NAME=time_entity tensorflow/serving
其中須要注意該模型所在的路徑,路徑須要寫完整路徑,以及模型的名稱(MODEL_NAME),這在訓練代碼(train.py)中已經給出(saved_model/time_entity)。json
接着咱們使用tornado來搭建HTTP服務,幫助咱們方便地進行模型預測,runServer.py的完整代碼以下:網絡
# -*- coding: utf-8 -*- import requests from kashgari import utils import numpy as np from model_predict import get_predict import json import tornado.httpserver import tornado.ioloop import tornado.options import tornado.web from tornado.options import define, options import traceback # tornado高併發 import tornado.web import tornado.gen import tornado.concurrent from concurrent.futures import ThreadPoolExecutor # 定義端口爲12333 define("port", default=16016, help="run on the given port", type=int) # 模型預測 class ModelPredictHandler(tornado.web.RequestHandler): executor = ThreadPoolExecutor(max_workers=5) # get 函數 @tornado.gen.coroutine def get(self): origin_text = self.get_argument('text') result = yield self.function(origin_text) self.write(json.dumps(result, ensure_ascii=False)) @tornado.concurrent.run_on_executor def function(self, text): try: text = text.replace(' ', '') x = [_ for _ in text] # Pre-processor data processor = utils.load_processor(model_path='saved_model/time_entity/1') tensor = processor.process_x_dataset([x]) # only for bert Embedding tensor = [{ "Input-Token:0": i.tolist(), "Input-Segment:0": np.zeros(i.shape).tolist() } for i in tensor] # predict r = requests.post("http://localhost:8501/v1/models/time_entity:predict", json={"instances": tensor}) preds = r.json()['predictions'] # Convert result back to labels labels = processor.reverse_numerize_label_sequences(np.array(preds).argmax(-1)) entities = get_predict('TIME', text, labels[0]) return entities except Exception: self.write(traceback.format_exc().replace('\n', '<br>')) # get請求 class HelloHandler(tornado.web.RequestHandler): def get(self): self.write('Hello from lmj from Daxing Beijing!') # 主函數 def main(): # 開啓tornado服務 tornado.options.parse_command_line() # 定義app app = tornado.web.Application( handlers=[(r'/model_predict', ModelPredictHandler), (r'/hello', HelloHandler), ], #網頁路徑控制 ) http_server = tornado.httpserver.HTTPServer(app) http_server.listen(options.port) tornado.ioloop.IOLoop.instance().start() main()
咱們定義了tornado封裝HTTP服務來進行模型預測,運行該腳本,啓動模型預測的HTTP服務。接着咱們再使用Python腳本才測試下模型的預測效果以及預測時間,預測的代碼腳本的完整代碼以下:併發
import time import json import requests t1 = time.time() texts = ['據《新聞聯播》報道,9月9日至11日,中央紀委書記趙樂際到河北調研。', '記者從國家發展改革委、商務部相關方面獲悉,日前美方已決定對擬於10月1日實施的中國輸美商品加徵關稅措施作出調整,中方支持相關企業從即日起按照市場化原則和WTO規則,自美採購必定數量大豆、豬肉等農產品,國務院關稅稅則委員會將對上述採購予以加徵關稅排除。', '據印度Zee新聞網站12日報道,亞洲新聞國際通信社援引印度軍方消息人士的話說,9月11日的對峙事件發生在靠近班公錯北岸的實際控制線一帶。', '儋州市決定,從9月開始,對城市低保、農村低保、特困供養人員、優撫對象、領取失業保險金人員、建檔立卡未脫貧人口等低收入羣體共3萬多人,發放豬肉價格補貼,每人每個月發放不低於100元補貼,之後發放標準,將根據豬肉價波動狀況進行動態調整。', '9月11日,華爲心聲社區發佈美國經濟學家托馬斯.弗裏德曼在《紐約時報》上的專欄內容,弗裏德曼透露,在與華爲創始人任正非最近一次採訪中,任正非表示華爲願意與美國司法部展開話題不設限的討論。', '造血幹細胞移植治療白血病技術已日益成熟,然而,經過該方法同時治癒艾滋病目前仍是一道全球尚在攻克的難題。', '英國航空事故調查局(AAIB)近日披露,今年2月6日一趟由德國法蘭克福飛往墨西哥坎昆的航班上,因飛行員打翻咖啡使操做面板冒煙,致使飛機折返迫降愛爾蘭。', '當地時間週四(9月12日),印度尼西亞財政部長英卓華(Sri Mulyani Indrawati)明確表示:特朗普的推特是風險之一。', '華中科技大學9月12日經過其官方網站發佈通報稱,9月2日,我校一碩士研究生不幸墜樓身亡。', '微博用戶@ooooviki 9月12日下午公佈發生在本身身上的驚悚遭遇:一個自稱網警、名叫鄭洋的人利用職務之便,查到她的完備的我的信息,包括但不限於身份證號、家庭地址、電話號碼、戶籍變更狀況等,要求她作他女友。', '今天,貴陽取消了汽車限購,成爲目前全國實行限購政策的9個省市中,首個取消限購的城市。', '據悉,與全球同步,中國區這次將於9月13日於iPhone官方渠道和京東正式開啓預售,京東成Apple中國區惟一官方受權預售渠道。', '根據央行公佈的數據,截至2019年6月末,存款類金融機構住戶部門短時間消費貸款規模爲9.11萬億元,2019年上半年該項淨增3293.19億元,上半年增量看起來並不樂觀。', '9月11日,一段拍攝浙江萬里學院學生食堂的視頻走紅網絡,視頻顯示該學校食堂不只在用餐區域設置了能夠看電影、比賽的大屏幕,還推出了「一人食」餐位。', '當日,在北京舉行的2019年國際籃聯籃球世界盃半決賽中,西班牙隊對陣澳大利亞隊。', ] print(len(texts)) for text in texts: url = 'http://localhost:16016/model_predict?text=%s' % text req = requests.get(url) print(json.loads(req.content)) t2 = time.time() print(round(t2-t1, 4))
運行該代碼,輸出的結果以下:(預測文本中的時間)app
一共預測15個句子。 ['9月9日至11日'] ['日前', '10月1日', '即日'] ['12日', '9月11日'] ['9月'] ['9月11日'] [] ['近日', '今年2月6日'] ['當地時間週四(9月12日)'] ['9月12日', '9月2日'] ['9月12日下午'] ['今天', '目前'] ['9月13日'] ['2019年6月末', '2019年上半年', '上半年'] ['9月11日'] ['當日', '2019年'] 預測耗時: 15.1085s.
模型預測的效果仍是不錯的,但平均每句話的預測時間爲1秒多,模型預測時間仍是稍微偏長,後續筆者將會研究如何縮短模型預測的時間。機器學習
本項目主要是介紹瞭如何利用tensorflow-serving部署kashgari模型,該項目已經上傳至github,地址爲:https://github.com/percent4/tensorflow-serving_4_kashgari 。 至於如何縮短模型預測的時間,筆者還須要再繼續研究,歡迎你們關注~