通過本文,你將瞭解如何基於訓練好的模型,編寫一個 REST 風格的命名實體提取接口:傳入一個句子,接口會提取出句子中的人名、地址、組織、公司、產品、時間信息並返回。
# 加載實體識別模型 def person_model_init(): ... # 預測句子中的實體 def predict(sentence, labels_config, model_dir, batch_size, id2label, label_list, graph, input_ids_p, input_mask_p, pred_ids, tokenizer, sess, max_seq_length): ...
# -*- coding: utf-8 -*-
"""
Model-based entity (address) extraction.

``model_init`` restores a trained checkpoint into a TensorFlow session once;
``predict`` then tags a single sentence and groups the tagged tokens by
entity label.
"""
__author__ = '程序員一一滌生'

import codecs
import os
import pickle
from datetime import datetime
from pprint import pprint

import numpy as np
import tensorflow as tf
from bert_base.bert import tokenization, modeling
from bert_base.train.models import create_model, InputFeatures
from bert_base.train.train_helper import get_args_parser

args = get_args_parser()


def convert(line, model_dir, label_list, tokenizer, batch_size, max_seq_length):
    """Turn one tokenized sentence into the four model input arrays.

    :param line: list of tokens (already tokenized)
    :param model_dir: directory where ``label2id.pkl`` is looked up / written
    :param label_list: list of label names
    :param tokenizer: tokenizer providing ``convert_tokens_to_ids``
    :param batch_size: first dimension of the returned arrays
    :param max_seq_length: second dimension of the returned arrays
    :return: ``(input_ids, input_mask, segment_ids, label_ids)``, each reshaped
        to ``(batch_size, max_seq_length)``
    """
    feature = convert_single_example(model_dir, 0, line, label_list, max_seq_length, tokenizer, 'p')
    # A single feature holds max_seq_length values, so this reshape only
    # succeeds when batch_size == 1.
    shape = (batch_size, max_seq_length)
    input_ids = np.reshape([feature.input_ids], shape)
    input_mask = np.reshape([feature.input_mask], shape)
    segment_ids = np.reshape([feature.segment_ids], shape)
    label_ids = np.reshape([feature.label_ids], shape)
    return input_ids, input_mask, segment_ids, label_ids


def predict(sentence, labels_config, model_dir, batch_size, id2label, label_list, graph, input_ids_p, input_mask_p,
            pred_ids, tokenizer, sess, max_seq_length):
    """Tag ``sentence`` with the restored model and collect the entities.

    All arguments after ``labels_config`` are the values returned by
    ``model_init``.

    :param sentence: raw input text
    :param labels_config: entity labels to report in the result
    :return: ``(result, pred_label_result)`` where ``result`` maps each label in
        ``labels_config`` to a list of extracted words and ``pred_label_result``
        is the list of per-row label sequences
    """
    with graph.as_default():
        start = datetime.now()
        sentence = tokenizer.tokenize(sentence)
        input_ids, input_mask, segment_ids, label_ids = convert(sentence, model_dir, label_list, tokenizer,
                                                                batch_size, max_seq_length)

        feed_dict = {input_ids_p: input_ids,
                     input_mask_p: input_mask}
        # Run the session for the current feed_dict.
        pred_ids_result = sess.run([pred_ids], feed_dict)
        pred_label_result = convert_id_to_label(pred_ids_result, id2label, batch_size)
        print(pred_label_result)
        # TODO: combination strategy
        result = strage_combined(sentence, pred_label_result[0], labels_config)
        print('time used: {} sec'.format((datetime.now() - start).total_seconds()))
        return result, pred_label_result


def convert_id_to_label(pred_ids_result, idx2label, batch_size):
    """Convert predicted label ids into label-name sequences.

    :param pred_ids_result: the one-element list returned by
        ``sess.run([pred_ids], ...)``; element 0 holds one id sequence per row
    :param idx2label: mapping from label id to label name
    :param batch_size: number of rows to convert
    :return: one list of label names per row; conversion of a row stops at the
        first id ``0`` and ``[CLS]`` / ``[SEP]`` labels are left out
    """
    # The fetch list has exactly one entry, so the batch lives at index 0.
    # (Indexing the fetch list by row only ever worked for batch_size == 1.)
    batch_pred_ids = pred_ids_result[0]
    result = []
    for row in range(batch_size):
        curr_seq = []
        for ids in batch_pred_ids[row]:
            if ids == 0:
                break
            curr_label = idx2label[ids]
            if curr_label in ('[CLS]', '[SEP]'):
                continue
            curr_seq.append(curr_label)
        result.append(curr_seq)
    return result


def strage_combined(tokens, tags, labels_config):
    """Combine tokens and predicted tags into ``{label: [word, ...]}``.

    :param tokens: list of tokens
    :param tags: list of predicted label names
    :param labels_config: labels to report
    :return: dict with one key per label in ``labels_config``; every ``#`` is
        stripped from the extracted words
    """
    collector = Result(labels_config)
    # Drop tokens that have no predicted tag.
    if len(tokens) > len(tags):
        tokens = tokens[:len(tags)]
    labels_dict = collector.get_result(tokens, tags)
    return {
        label: [str(pair.word).replace("#", "") for pair in pairs]
        for label, pairs in labels_dict.items()
    }


def convert_single_example(model_dir, ex_index, example, label_list, max_seq_length, tokenizer, mode):
    """Convert one tokenized sample into an ``InputFeatures`` object.

    :param model_dir: directory where ``label2id.pkl`` is written if missing
    :param ex_index: sample index (not used by this function)
    :param example: list of tokens
    :param label_list: list of label names; must contain ``[CLS]`` and ``[SEP]``
    :param max_seq_length: length every returned sequence is padded to
    :param tokenizer: tokenizer providing ``convert_tokens_to_ids``
    :param mode: not used by this function
    :return: ``InputFeatures`` whose four sequences all have length
        ``max_seq_length``
    """
    # Label ids start at 1; 0 is used for padding below.
    label_map = {label: i for i, label in enumerate(label_list, 1)}
    # Persist the label -> id map the first time it is needed.
    label2id_path = os.path.join(model_dir, 'label2id.pkl')
    if not os.path.exists(label2id_path):
        with open(label2id_path, 'wb') as w:
            pickle.dump(label_map, w)

    tokens = example
    # Truncate: two positions are reserved for the [CLS] and [SEP] markers.
    if len(tokens) >= max_seq_length - 1:
        tokens = tokens[0:(max_seq_length - 2)]

    ntokens = ["[CLS]"] + list(tokens) + ["[SEP]"]
    segment_ids = [0] * len(ntokens)
    # Real tokens carry label id 0 here; only the two markers get real ids.
    label_ids = [label_map["[CLS]"]] + [0] * len(tokens) + [label_map["[SEP]"]]
    input_ids = list(tokenizer.convert_tokens_to_ids(ntokens))
    input_mask = [1] * len(input_ids)

    # Zero-pad every sequence up to max_seq_length.
    padding = [0] * (max_seq_length - len(input_ids))
    input_ids += padding
    input_mask += padding
    segment_ids += padding
    label_ids += padding

    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length
    assert len(label_ids) == max_seq_length

    return InputFeatures(
        input_ids=input_ids,
        input_mask=input_mask,
        segment_ids=segment_ids,
        label_ids=label_ids,
    )


class Pair(object):
    """One extracted entity: its text, token span, label and a merge flag."""

    def __init__(self, word, start, end, type, merge=False):
        # Plain attributes: reading and assigning word/start/end/merge/types
        # works exactly as with the former pass-through properties.
        self.word = word
        self.start = start
        self.end = end
        self.merge = merge
        self.types = type

    def __str__(self) -> str:
        line = [
            'entity:{}'.format(self.word),
            'start:{}'.format(self.start),
            'end:{}'.format(self.end),
            'merge:{}'.format(self.merge),
            'types:{}'.format(self.types),
        ]
        return '\t'.join(line)


class Result(object):
    """Collects tagged entities, grouped by the labels in ``labels_config``."""

    def __init__(self, labels_config):
        # Entities whose label is not in labels_config end up here.
        self.others = []
        self.labels_config = labels_config
        self.labels = {la: [] for la in self.labels_config}

    def get_result(self, tokens, tags):
        """Process the tagged sequence and return ``{label: [Pair, ...]}``."""
        self.result_to_json(tokens, tags)
        return self.labels

    def result_to_json(self, string, tags):
        """Combine the input sequence with the predicted tags.

        Tags are dispatched on their first character: ``S`` records a
        single-token entity, ``B`` starts a new entity (recording a pending one
        first), ``I`` extends the pending entity and ``O`` records and closes
        it. Any other tag discards the pending entity without recording it.
        The label of a recorded multi-token entity is taken from the tag that
        preceded the closing position.

        :param string: input token sequence
        :param tags: predicted tag sequence
        :return: ``{"entities": [{"word", "start", "end", "type"}, ...]}``
        """
        item = {"entities": []}

        def emit(word, start, end, label):
            # Record the entity both in the grouped store and in the item.
            self.append(word, start, end, label)
            item["entities"].append({"word": word, "start": start, "end": end, "type": label})

        entity_name = ""
        entity_start = 0
        idx = 0
        last_tag = ''
        for char, tag in zip(string, tags):
            # tag[:1] rather than tag[0]: an empty tag falls through to the
            # "other" branch instead of raising IndexError.
            head = tag[:1]
            if head == "S":
                emit(char, idx, idx + 1, tag[2:])
            elif head == "B":
                if entity_name != '':
                    emit(entity_name, entity_start, idx, last_tag[2:])
                entity_name = char
                entity_start = idx
            elif head == "I":
                entity_name += char
            elif head == "O":
                if entity_name != '':
                    emit(entity_name, entity_start, idx, last_tag[2:])
                entity_name = ""
            else:
                entity_name = ""
                entity_start = idx
            idx += 1
            last_tag = tag
        # Flush an entity that runs to the end of the sequence.
        if entity_name != '':
            emit(entity_name, entity_start, idx, last_tag[2:])
        return item

    def append(self, word, start, end, tag):
        """Store a ``Pair`` under its label, or in ``others`` if unconfigured."""
        if tag in self.labels_config:
            self.labels[tag].append(Pair(word, start, end, tag))
        else:
            self.others.append(Pair(word, start, end, tag))


def person_model_init():
    """Load the model stored under the ``person`` model name."""
    return model_init("person")


def model_init(model_name):
    """Restore the latest checkpoint of ``model_name`` into a new session.

    :param model_name: sub-directory name of the model under the project root
    :return: ``(model_dir, batch_size, id2label, label_list, graph,
        input_ids_p, input_mask_p, pred_ids, tokenizer, sess, max_seq_length)``
    :raises FileNotFoundError: if the model directory has no ``checkpoint`` file
    """
    if os.name == 'nt':  # windows path config
        model_dir = 'E:/quickstart/deeplearning/nlp_demo/%s/model' % model_name
        bert_dir = 'E:/quickstart/deeplearning/nlp_demo/bert_model_info/chinese_L-12_H-768_A-12'
    else:  # linux path config
        model_dir = '/home/yjy/project/deeplearning/nlp_demo/%s/model' % model_name
        bert_dir = '/home/yjy/project/deeplearning/nlp_demo/bert_model_info/chinese_L-12_H-768_A-12'

    batch_size = 1
    max_seq_length = 500

    checkpoint_path = os.path.join(model_dir, "checkpoint")
    print('checkpoint path:{}'.format(checkpoint_path))
    if not os.path.exists(checkpoint_path):
        # FileNotFoundError is still an Exception, so existing handlers keep working.
        raise FileNotFoundError("failed to get checkpoint. going to return ")

    # Load the label -> id dictionary.
    # NOTE(security): pickle.load can execute code embedded in the file; only
    # load artifacts produced by a trusted training run.
    with codecs.open(os.path.join(model_dir, 'label2id.pkl'), 'rb') as rf:
        label2id = pickle.load(rf)
        id2label = {value: key for key, value in label2id.items()}

    with codecs.open(os.path.join(model_dir, 'label_list.pkl'), 'rb') as rf:
        label_list = pickle.load(rf)
    # +1 because label ids start at 1 and 0 is the padding id.
    num_labels = len(label_list) + 1

    gpu_config = tf.ConfigProto()
    gpu_config.gpu_options.allow_growth = True
    graph = tf.Graph()
    sess = tf.Session(graph=graph, config=gpu_config)

    with graph.as_default():
        print("going to restore checkpoint")
        input_ids_p = tf.placeholder(tf.int32, [batch_size, max_seq_length], name="input_ids")
        input_mask_p = tf.placeholder(tf.int32, [batch_size, max_seq_length], name="input_mask")

        bert_config = modeling.BertConfig.from_json_file(os.path.join(bert_dir, 'bert_config.json'))
        (total_loss, logits, trans, pred_ids) = create_model(
            bert_config=bert_config, is_training=False, input_ids=input_ids_p, input_mask=input_mask_p,
            segment_ids=None,
            labels=None, num_labels=num_labels, use_one_hot_embeddings=False, dropout_rate=1.0)

        saver = tf.train.Saver()
        saver.restore(sess, tf.train.latest_checkpoint(model_dir))

    tokenizer = tokenization.FullTokenizer(
        vocab_file=os.path.join(bert_dir, 'vocab.txt'), do_lower_case=args.do_lower_case)

    return (model_dir, batch_size, id2label, label_list, graph, input_ids_p, input_mask_p, pred_ids, tokenizer,
            sess, max_seq_length)


if __name__ == "__main__":
    (_model_dir, _batch_size, _id2label, _label_list, _graph, _input_ids_p, _input_mask_p, _pred_ids, _tokenizer,
     _sess, _max_seq_length) = person_model_init()
    PERSON_LABELS = ["TIME", "LOCATION", "PERSON_NAME", "ORG_NAME", "COMPANY_NAME", "PRODUCT_NAME"]
    while True:
        print('input the test sentence:')
        try:
            _sentence = str(input())
        except (EOFError, KeyboardInterrupt):
            # End of input / Ctrl-C: leave the loop instead of crashing.
            break
        pred_rs, pred_label_result = predict(_sentence, PERSON_LABELS, _model_dir, _batch_size, _id2label,
                                             _label_list, _graph, _input_ids_p, _input_mask_p, _pred_ids,
                                             _tokenizer, _sess, _max_seq_length)
        pprint(pred_rs)
我們將採用 Python 的 Flask 框架來提供 REST 接口。
# -*- coding: utf-8 -*-
"""
Flask entry point: builds the app, loads the model and registers the
JSON error handlers and the index redirect.
"""
import os

import nlp_config as nc
from flaskr import create_app, loadProjContext

__author__ = '程序員一一滌生'

from flask import jsonify, make_response, redirect

# Load the Flask configuration.
# app = create_app('config.DevelopmentConfig')
app = create_app(nc.config['default'])

# Load the project context (restores the model).
loadProjContext()


@app.errorhandler(404)
def not_found(error):
    """Answer 404 errors with a JSON body."""
    return make_response(jsonify({'error': 'Not found'}), 404)


@app.errorhandler(400)
def bad_request(error):
    """Answer 400 errors with a JSON body.

    Named distinctly so it no longer shadows the 404 handler at module level.
    """
    return make_response(jsonify({'error': '400 Bad Request,參數或參數內容異常'}), 400)


@app.route('/')
def index_sf():
    """Redirect the root URL to the static index page."""
    # return render_template('index.html')
    return redirect('index.html')


if __name__ == '__main__':
    # Flask.run's third positional parameter is ``debug``; passing the app
    # object there silently forced debug mode on. Use keywords so debug is
    # taken from the loaded configuration instead.
    app.run(host='localhost', port=5006, use_reloader=False)
# -*- coding: utf-8 -*-
"""
Flask initialisation: application factory and model bootstrap.
"""
from logging.config import dictConfig

from flask import Flask
from flask_cors import CORS

import person_ner_resource
from entity_extractor import person_model_init
from person_ner_resource import person

__author__ = '程序員一一滌生'


def create_app(config_type):
    """Configure logging, build the Flask app and push an app context.

    :param config_type: configuration object handed to ``app.config.from_object``
    :return: the configured ``Flask`` application
    """
    log_format = '[%(asctime)s] %(name)s %(levelname)s in %(module)s %(lineno)d: %(message)s'
    wsgi_handler = {
        'class': 'logging.StreamHandler',
        'stream': 'ext://flask.logging.wsgi_errors_stream',
        'formatter': 'default'
    }
    logging_setup = {
        'version': 1,
        'formatters': {'default': {'format': log_format}},
        'handlers': {'wsgi': wsgi_handler},
        # Root logger level; alternatives previously noted: 'WARN', 'INFO'.
        'root': {'level': 'DEBUG', 'handlers': ['wsgi']},
    }
    dictConfig(logging_setup)

    # Build the app; static files are served from the URL root.
    app = Flask(__name__, static_folder='static', static_url_path='')
    # r"/*" matches every URL of this server and origins '*' admits every
    # origin, i.e. cross-origin requests are allowed everywhere.
    cors_rules = {r"/*": {"origins": '*'}}
    CORS(app, resources=cors_rules)
    app.config.from_object(config_type)
    app.register_blueprint(person, url_prefix='/person')

    # Initialise the application context.
    app.app_context().push()
    return app


def loadProjContext():
    """Load the entity model and publish its handles on ``person_ner_resource``."""
    attribute_names = (
        'model_dir', 'batch_size', 'id2label', 'label_list', 'graph',
        'input_ids_p', 'input_mask_p', 'pred_ids', 'tokenizer', 'sess',
        'max_seq_length',
    )
    # person_model_init returns its values in exactly this order.
    for attribute, value in zip(attribute_names, person_model_init()):
        setattr(person_ner_resource, attribute, value)
# -*- coding: utf-8 -*-
"""
Flask configuration module.
"""
import os

__author__ = '程序員一一滌生'

basedir = os.path.abspath(os.path.dirname(__file__))


class BaseConfig:  # base configuration class
    # NOTE(security): a key committed to source control is not secret. Set the
    # NLP_DEMO_SECRET_KEY environment variable in real deployments; the literal
    # below is only kept as a fallback so existing setups keep working.
    SECRET_KEY = os.environ.get(
        'NLP_DEMO_SECRET_KEY',
        b'\xe4r\x04\xb5\xb2\x00\xf1\xadf\xa3\xf3V\x03\xc5\x9f\x82$^\xa25O\xf0R\xda')
    # The stock JSONIFY_MIMETYPE carries no '; charset=utf-8' suffix.
    JSONIFY_MIMETYPE = 'application/json; charset=utf-8'
    # When left enabled, jsonify renders Chinese text as \u escapes.
    JSON_AS_ASCII = False
    ENCODING = 'utf-8'

    # Custom setting: entity labels reported by the extraction endpoint.
    PERSON_LABELS = ["TIME", "LOCATION", "PERSON_NAME", "ORG_NAME", "COMPANY_NAME", "PRODUCT_NAME"]


class DevelopmentConfig(BaseConfig):
    ENV = 'development'
    DEBUG = True


class TestingConfig(BaseConfig):
    TESTING = True
    WTF_CSRF_ENABLED = False


class ProductionConfig(BaseConfig):
    DEBUG = False


# Configuration lookup table used by the entry point.
config = {
    'testing': TestingConfig,
    'default': DevelopmentConfig
    # 'default': ProductionConfig
}
# -*- coding: utf-8 -*-
"""
Named-entity recognition endpoint.
"""
from entity_extractor import predict

__author__ = '程序員一一滌生'

from flask import Blueprint, make_response, request, current_app
from flask import jsonify

person = Blueprint('person', __name__)

# Model handles passed to predict(); they start as None and must be assigned
# from outside this module before the endpoint is called.
model_dir, batch_size, id2label, label_list, graph, input_ids_p, input_mask_p, pred_ids, tokenizer, sess, max_seq_length = None, None, None, None, None, None, None, None, None, None, None


@person.route('/extract', methods=['POST'])
def extract():
    """Extract entities from the text sent in JSON field ``t``.

    :return: JSON mapping each configured label to the extracted words, or a
        400 response with an ``error`` message when the body is not a JSON
        object or ``t`` is not a string of 2 to 500 characters
    """
    params = request.get_json()
    # get_json() yields None for a non-JSON body, and the body may also be a
    # JSON list/number; without these checks the request died with a 500.
    text = params.get('t') if isinstance(params, dict) else None
    if not isinstance(text, str) or len(text) > 500 or len(text) < 2:
        return make_response(jsonify({'error': '文本長度不符合要求,長度限制:2~500'}), 400)
    sentence = text

    # Make it a complete sentence: append a full stop unless the text already
    # ends with punctuation (half-width or full-width).
    if not sentence.endswith((",", "。", "!", "?", ",", "!", "?")):
        sentence = sentence + "。"

    # Extract with the model.
    pred_rs, pred_label_result = predict(sentence, current_app.config['PERSON_LABELS'], model_dir, batch_size,
                                         id2label, label_list, graph, input_ids_p, input_mask_p, pred_ids,
                                         tokenizer, sess, max_seq_length)
    print(sentence)
    return jsonify(pred_rs)


if __name__ == '__main__':
    pass
absl-py==0.7.0 astor==0.7.1 backcall==0.1.0 backports.weakref==1.0rc1 bleach==1.5.0 certifi==2016.2.28 click==6.7 colorama==0.4.1 colorful==0.5.0 decorator==4.3.2 defusedxml==0.5.0 entrypoints==0.3 Flask==1.0.2 Flask-Cors==3.0.3 gast==0.2.2 grpcio==1.18.0 h5py==2.9.0 html5lib==0.9999999 ipykernel==5.1.0 ipython==7.2.0 ipython-genutils==0.2.0 ipywidgets==7.4.2 itsdangerous==0.24 jedi==0.13.2 Jinja2==2.10 jsonschema==2.6.0 jupyter==1.0.0 jupyter-client==5.2.4 jupyter-console==6.0.0 jupyter-core==4.4.0 Keras-Applications==1.0.6 Keras-Preprocessing==1.0.5 Markdown==3.0.1 MarkupSafe==1.1.0 mistune==0.8.4 mock==3.0.5 nbconvert==5.4.0 nbformat==4.4.0 notebook==5.7.4 numpy==1.16.0 pandocfilters==1.4.2 parso==0.3.2 pickleshare==0.7.5 prettyprinter==0.17.0 prometheus-client==0.5.0 prompt-toolkit==2.0.8 protobuf==3.6.1 Pygments==2.3.1 python-dateutil==2.7.5 pywinpty==0.5.5 pyzmq==17.1.2 qtconsole==4.4.3 Send2Trash==1.5.0 six==1.12.0 tensorboard==1.13.1 tensorflow==1.13.1 tensorflow-estimator==1.13.0 termcolor==1.1.0 terminado==0.8.1 testpath==0.4.2 tornado==5.1.1 traitlets==4.3.2 wcwidth==0.1.7 Werkzeug==0.14.1 widgetsnbextension==3.4.2 wincertstore==0.2
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple -r requirements.txt
以上步驟完成後,我們就可以嘗試啓動項目了。
運行如下命令,啓動該 Flask 項目:
python nlp_main.py
本文使用 Postman 來調用命名實體提取接口,接口地址:
http://localhost:5006/person/extract
調用效果展現:
(原文此處應爲接口調用效果的截圖,抓取時已缺失。)
注意,在 CPU 上使用模型的時間大概在 2 到 3 秒;而如果項目部署在搭載了支持深度學習的 GPU 的電腦上,接口的返回會快很多,當然不要忘記將 tensorflow 改爲安裝 tensorflow-gpu。
本篇就這麼多內容。到此,我們已經基於深度學習開發了一個可以從自然語言中提取出人名、地址、組織、公司、產品、時間的項目。從下一篇開始,我們將介紹本項目使用的深度學習算法 BERT 和 CRF;通過對算法的瞭解,我們將更好地理解爲什麼模型能夠準確地從句子中提取出我們想要的實體。
ok,本篇就這麼多內容啦~,感謝閱讀O(∩_∩)O,88~
本博客內容來自公衆號「程序員一一滌生」,歡迎掃碼關注 o(∩_∩)o