Named Entity Recognition with Deep Learning (Part 5): Using the Model

In this article, you will learn how to build a REST-style named-entity extraction API on top of the trained model: given a sentence, the API extracts the person names, locations, organizations, companies, products, and time expressions it contains and returns them.

Core module: entity_extractor.py

Key functions
# Load the entity recognition model
def person_model_init():
   ...
   
# Predict the entities in a sentence
def predict(sentence, labels_config, model_dir, batch_size, id2label, label_list, graph, input_ids_p, input_mask_p,
            pred_ids,
            tokenizer,
            sess, max_seq_length):
    ...
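These two functions are wired together the same way as the __main__ block at the end of the full code below: initialize once at startup, then call predict for each incoming sentence. A minimal usage sketch (the example sentence and the printed result are illustrative only; the real output depends on your trained model):

# Minimal usage sketch: initialize once, then call predict per sentence.
from entity_extractor import person_model_init, predict

PERSON_LABELS = ["TIME", "LOCATION", "PERSON_NAME", "ORG_NAME", "COMPANY_NAME", "PRODUCT_NAME"]

# Load the graph, session, tokenizer and input placeholders once.
(model_dir, batch_size, id2label, label_list, graph, input_ids_p, input_mask_p,
 pred_ids, tokenizer, sess, max_seq_length) = person_model_init()

# Predict the entities in one sentence (sentence and output below are illustrative).
result, raw_tags = predict("小明昨天在北京參觀了故宮。", PERSON_LABELS, model_dir, batch_size,
                           id2label, label_list, graph, input_ids_p, input_mask_p,
                           pred_ids, tokenizer, sess, max_seq_length)
print(result)  # e.g. {'PERSON_NAME': ['小明'], 'TIME': ['昨天'], 'LOCATION': ['北京', '故宮'], ...}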
Full code
# -*- coding: utf-8 -*-

"""
Model-based entity extraction
"""
__author__ = '程序員一一滌生'

import codecs
import os
import pickle
from datetime import datetime
from pprint import pprint
import numpy as np
import tensorflow as tf
from bert_base.bert import tokenization, modeling
from bert_base.train.models import create_model, InputFeatures
from bert_base.train.train_helper import get_args_parser

args = get_args_parser()

def convert(line, model_dir, label_list, tokenizer, batch_size, max_seq_length):
    feature = convert_single_example(model_dir, 0, line, label_list, max_seq_length, tokenizer, 'p')
    input_ids = np.reshape([feature.input_ids], (batch_size, max_seq_length))
    input_mask = np.reshape([feature.input_mask], (batch_size, max_seq_length))
    segment_ids = np.reshape([feature.segment_ids], (batch_size, max_seq_length))
    label_ids = np.reshape([feature.label_ids], (batch_size, max_seq_length))
    return input_ids, input_mask, segment_ids, label_ids

def predict(sentence, labels_config, model_dir, batch_size, id2label, label_list, graph, input_ids_p, input_mask_p,
            pred_ids,
            tokenizer,
            sess, max_seq_length):
    with graph.as_default():
        start = datetime.now()
        # print(id2label)
        sentence = tokenizer.tokenize(sentence)
        # print('your input is:{}'.format(sentence))
        input_ids, input_mask, segment_ids, label_ids = convert(sentence, model_dir, label_list, tokenizer, batch_size,
                                                                max_seq_length)

        feed_dict = {input_ids_p: input_ids,
                     input_mask_p: input_mask}
        # run the session to get the prediction for the current feed_dict
        pred_ids_result = sess.run([pred_ids], feed_dict)
        pred_label_result = convert_id_to_label(pred_ids_result, id2label, batch_size)
        # print(pred_ids_result)
        print(pred_label_result)
        # todo: combination strategy (merge the predicted tags into entities)
        result = strage_combined(sentence, pred_label_result[0], labels_config)
        print('time used: {} sec'.format((datetime.now() - start).total_seconds()))
    return result, pred_label_result

def convert_id_to_label(pred_ids_result, idx2label, batch_size):
    """
    Convert id-based predictions back to their label sequences.
    :param pred_ids_result:
    :param idx2label:
    :return:
    """
    result = []
    for row in range(batch_size):
        curr_seq = []
        for ids in pred_ids_result[row][0]:
            if ids == 0:
                break
            curr_label = idx2label[ids]
            if curr_label in ['[CLS]', '[SEP]']:
                continue
            curr_seq.append(curr_label)
        result.append(curr_seq)
    return result

def strage_combined(tokens, tags, labels_config):
    """
    Combine the predicted tag sequence with the tokens and group the extracted entities by type.
    :param tokens:
    :param tags:
    :param labels_config:
    :return:
    """
    def get_output(rs, data, type):
        words = []
        for i in data:
            words.append(str(i.word).replace("#", ""))
            # words.append(i.word)
        rs[type] = words
        return rs
    eval = Result(labels_config)
    if len(tokens) > len(tags):
        tokens = tokens[:len(tags)]
    labels_dict = eval.get_result(tokens, tags)
    arr = []
    for k, v in labels_dict.items():
        arr.append((k, v))
    rs = {}
    for item in arr:
        rs = get_output(rs, item[1], item[0])
    return rs

def convert_single_example(model_dir, ex_index, example, label_list, max_seq_length, tokenizer, mode):
    """
    Analyze one example: convert its characters to ids and its labels to ids, then pack them into an InputFeatures object.
    :param ex_index: index
    :param example: a single example
    :param label_list: list of labels
    :param max_seq_length:
    :param tokenizer:
    :param mode:
    :return:
    """
    label_map = {}
    # index labels starting from 1 (0 is left for padding)
    for (i, label) in enumerate(label_list, 1):
        label_map[label] = i
    # save the label->index map
    if not os.path.exists(os.path.join(model_dir, 'label2id.pkl')):
        with codecs.open(os.path.join(model_dir, 'label2id.pkl'), 'wb') as w:
            pickle.dump(label_map, w)
    tokens = example
    # tokens = tokenizer.tokenize(example.text)
    # truncate the sequence if it is too long
    if len(tokens) >= max_seq_length - 1:
        tokens = tokens[0:(max_seq_length - 2)]  # -2 because the sequence needs a start-of-sentence and an end-of-sentence marker
    ntokens = []
    segment_ids = []
    label_ids = []
    ntokens.append("[CLS]")  # 句子開始設置CLS 標誌
    segment_ids.append(0)
    # append("O") or append("[CLS]") not sure!
    label_ids.append(label_map["[CLS]"])  # using O or [CLS] here makes no real difference; O would keep the label set smaller, but marking the sentence start and end with their own tags works just as well
    for i, token in enumerate(tokens):
        ntokens.append(token)
        segment_ids.append(0)
        label_ids.append(0)
    ntokens.append("[SEP]")  # 句尾添加[SEP] 標誌
    segment_ids.append(0)
    # append("O") or append("[SEP]") not sure!
    label_ids.append(label_map["[SEP]"])
    input_ids = tokenizer.convert_tokens_to_ids(ntokens)  # convert the tokens (ntokens) to ids
    input_mask = [1] * len(input_ids)
    # pad up to max_seq_length
    while len(input_ids) < max_seq_length:
        input_ids.append(0)
        input_mask.append(0)
        segment_ids.append(0)
        # we are not concerned about these padded positions
        label_ids.append(0)
        ntokens.append("**NULL**")
        # label_mask.append(0)
    # print(len(input_ids))
    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length
    assert len(label_ids) == max_seq_length
    # assert len(label_mask) == max_seq_length
    # pack everything into an InputFeatures object
    feature = InputFeatures(
        input_ids=input_ids,
        input_mask=input_mask,
        segment_ids=segment_ids,
        label_ids=label_ids,
        # label_mask = label_mask
    )
    return feature

class Pair(object):
    def __init__(self, word, start, end, type, merge=False):
        self.__word = word
        self.__start = start
        self.__end = end
        self.__merge = merge
        self.__types = type

    @property
    def start(self):
        return self.__start

    @property
    def end(self):
        return self.__end

    @property
    def merge(self):
        return self.__merge

    @property
    def word(self):
        return self.__word

    @property
    def types(self):
        return self.__types

    @word.setter
    def word(self, word):
        self.__word = word

    @start.setter
    def start(self, start):
        self.__start = start

    @end.setter
    def end(self, end):
        self.__end = end

    @merge.setter
    def merge(self, merge):
        self.__merge = merge

    @types.setter
    def types(self, type):
        self.__types = type

    def __str__(self) -> str:
        line = []
        line.append('entity:{}'.format(self.__word))
        line.append('start:{}'.format(self.__start))
        line.append('end:{}'.format(self.__end))
        line.append('merge:{}'.format(self.__merge))
        line.append('types:{}'.format(self.__types))
        return '\t'.join(line)

class Result(object):
    def __init__(self, labels_config):
        self.others = []
        self.labels_config = labels_config
        self.labels = {}
        for la in self.labels_config:
            self.labels[la] = []

    def get_result(self, tokens, tags):
        # build the tagging result first
        self.result_to_json(tokens, tags)
        return self.labels

    def result_to_json(self, string, tags):
        """
        Combine the model's tag sequence with the input sequence and turn them into entity results.
        :param string: input sequence
        :param tags: predicted tags
        :return:
        """
        item = {"entities": []}
        entity_name = ""
        entity_start = 0
        idx = 0
        last_tag = ''

        for char, tag in zip(string, tags):
            if tag[0] == "S":
                self.append(char, idx, idx + 1, tag[2:])
                item["entities"].append({"word": char, "start": idx, "end": idx + 1, "type": tag[2:]})
            elif tag[0] == "B":
                if entity_name != '':
                    self.append(entity_name, entity_start, idx, last_tag[2:])
                    item["entities"].append(
                        {"word": entity_name, "start": entity_start, "end": idx, "type": last_tag[2:]})
                    entity_name = ""
                entity_name += char
                entity_start = idx
            elif tag[0] == "I":
                entity_name += char
            elif tag[0] == "O":
                if entity_name != '':
                    self.append(entity_name, entity_start, idx, last_tag[2:])
                    item["entities"].append(
                        {"word": entity_name, "start": entity_start, "end": idx, "type": last_tag[2:]})
                    entity_name = ""
            else:
                entity_name = ""
                entity_start = idx
            idx += 1
            last_tag = tag
        if entity_name != '':
            self.append(entity_name, entity_start, idx, last_tag[2:])
            item["entities"].append({"word": entity_name, "start": entity_start, "end": idx, "type": last_tag[2:]})
        return item

    def append(self, word, start, end, tag):
        if tag in self.labels_config:
            self.labels[tag].append(Pair(word, start, end, tag))
        else:
            self.others.append(Pair(word, start, end, tag))

def person_model_init():
    return model_init("person")

def model_init(model_name):
    if os.name == 'nt':  # windows path config
        model_dir = 'E:/quickstart/deeplearning/nlp_demo/%s/model' % model_name
        bert_dir = 'E:/quickstart/deeplearning/nlp_demo/bert_model_info/chinese_L-12_H-768_A-12'
    else:  # linux path config
        model_dir = '/home/yjy/project/deeplearning/nlp_demo/%s/model' % model_name
        bert_dir = '/home/yjy/project/deeplearning/nlp_demo/bert_model_info/chinese_L-12_H-768_A-12'

    batch_size = 1
    max_seq_length = 500

    print('checkpoint path:{}'.format(os.path.join(model_dir, "checkpoint")))
    if not os.path.exists(os.path.join(model_dir, "checkpoint")):
        raise Exception("failed to get checkpoint. going to return ")

    # load the label->id dictionary
    with codecs.open(os.path.join(model_dir, 'label2id.pkl'), 'rb') as rf:
        label2id = pickle.load(rf)
        id2label = {value: key for key, value in label2id.items()}

    with codecs.open(os.path.join(model_dir, 'label_list.pkl'), 'rb') as rf:
        label_list = pickle.load(rf)
    num_labels = len(label_list) + 1

    gpu_config = tf.ConfigProto()
    gpu_config.gpu_options.allow_growth = True
    graph = tf.Graph()
    sess = tf.Session(graph=graph, config=gpu_config)

    with graph.as_default():
        print("going to restore checkpoint")
        # sess.run(tf.global_variables_initializer())
        input_ids_p = tf.placeholder(tf.int32, [batch_size, max_seq_length], name="input_ids")
        input_mask_p = tf.placeholder(tf.int32, [batch_size, max_seq_length], name="input_mask")

        bert_config = modeling.BertConfig.from_json_file(os.path.join(bert_dir, 'bert_config.json'))
        (total_loss, logits, trans, pred_ids) = create_model(
            bert_config=bert_config, is_training=False, input_ids=input_ids_p, input_mask=input_mask_p,
            segment_ids=None,
            labels=None, num_labels=num_labels, use_one_hot_embeddings=False, dropout_rate=1.0)

        saver = tf.train.Saver()
        saver.restore(sess, tf.train.latest_checkpoint(model_dir))

    tokenizer = tokenization.FullTokenizer(
        vocab_file=os.path.join(bert_dir, 'vocab.txt'), do_lower_case=args.do_lower_case)

    return model_dir, batch_size, id2label, label_list, graph, input_ids_p, input_mask_p, pred_ids, tokenizer, sess, max_seq_length


if __name__ == "__main__":
    _model_dir, _batch_size, _id2label, _label_list, _graph, _input_ids_p, _input_mask_p, _pred_ids, _tokenizer, _sess, _max_seq_length = person_model_init()
    PERSON_LABELS = ["TIME", "LOCATION", "PERSON_NAME", "ORG_NAME", "COMPANY_NAME", "PRODUCT_NAME"]
    while True:
        print('input the test sentence:')
        _sentence = str(input())
        pred_rs, pred_label_result = predict(_sentence, PERSON_LABELS, _model_dir, _batch_size, _id2label, _label_list,
                                             _graph,
                                             _input_ids_p,
                                             _input_mask_p, _pred_ids, _tokenizer, _sess, _max_seq_length)
        pprint(pred_rs)

Writing the REST-style API

We will use Python's Flask framework to provide the REST API.

First, create a new Python project and put the following directories and files under the project root (a possible layout is sketched after the list):

  • the bert_base and bert_model_info directories and their files can be found in the cloud-drive project shared in the previous article, Named Entity Recognition with Deep Learning (Part 4): Model Training;
  • the model directory under person contains the named entity recognition model we trained in the previous article, along with a few auxiliary files; it can be taken from the project's output directory.
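A plausible layout of the project root, assembled from the directories above and the files created in the rest of this article (the static directory is inferred from the Flask code below, and the exact file names under person/model depend on your training output):

nlp_demo/
├── bert_base/                      # library code from the Part 4 cloud-drive project
├── bert_model_info/
│   └── chinese_L-12_H-768_A-12/    # pre-trained Chinese BERT (vocab.txt, bert_config.json, ...)
├── person/
│   └── model/                      # trained NER checkpoint, label2id.pkl, label_list.pkl, ...
├── static/                         # static pages (index.html)
├── entity_extractor.py
├── flaskr.py
├── nlp_config.py
├── nlp_main.py
├── person_ner_resource.py
└── requirements.txt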
Then create the entry file nlp_main.py with the following contents:
# -*- coding: utf-8 -*-

"""
Flask entry point
"""
import os
import nlp_config as nc
from flaskr import create_app, loadProjContext

__author__ = '程序員一一滌生'

from flask import jsonify, make_response, redirect

# load the Flask configuration
# app = create_app('config.DevelopmentConfig')
app = create_app(nc.config['default'])
# load the project context (initialize the NER model)
loadProjContext()

@app.errorhandler(404)
def not_found(error):
    return make_response(jsonify({'error': 'Not found'}), 404)

@app.errorhandler(400)
def bad_request(error):
    return make_response(jsonify({'error': '400 Bad Request: invalid parameters or parameter content'}), 400)

@app.route('/')
def index_sf():
    # return render_template('index.html')
    return redirect('index.html')

if __name__ == '__main__':
    app.run('localhost', 5006, use_reloader=False)
Next, create the Flask initialization file flaskr.py, which pre-configures and loads information when the project starts. Its contents are as follows:
# -*- coding: utf-8 -*-
"""
Flask initialization
"""
from logging.config import dictConfig
from flask import Flask
from flask_cors import CORS
import person_ner_resource
from entity_extractor import person_model_init
from person_ner_resource import person

__author__ = '程序員一一滌生'

def create_app(config_type):
    dictConfig({
        'version': 1,
        'formatters': {'default': {
            'format': '[%(asctime)s] %(name)s %(levelname)s in %(module)s %(lineno)d: %(message)s',
        }},
        'handlers': {'wsgi': {
            'class': 'logging.StreamHandler',
            'stream': 'ext://flask.logging.wsgi_errors_stream',
            'formatter': 'default'
        }},
        'root': {
            'level': 'DEBUG',
            # 'level': 'WARN',
            # 'level': 'INFO',
            'handlers': ['wsgi']
        }
    })
    # create the Flask app and load its configuration
    app = Flask(__name__, static_folder='static', static_url_path='')
    # CORS(app, resources=r'/*', origins=['192.168.1.104'])  # restrict cross-origin access to a single host
    CORS(app, resources={r"/*": {"origins": '*'}})  # r'/*' matches every URL on this server; origins='*' allows cross-origin requests from any host
    app.config.from_object(config_type)
    app.register_blueprint(person, url_prefix='/person')
    # initialize the application context
    ctx = app.app_context()
    ctx.push()
    return app

def loadProjContext():
    # load the entity extraction model
    model_dir, batch_size, id2label, label_list, graph, input_ids_p, input_mask_p, pred_ids, tokenizer, sess, max_seq_length = person_model_init()
    person_ner_resource.model_dir = model_dir
    person_ner_resource.batch_size = batch_size
    person_ner_resource.id2label = id2label
    person_ner_resource.label_list = label_list
    person_ner_resource.graph = graph
    person_ner_resource.input_ids_p = input_ids_p
    person_ner_resource.input_mask_p = input_mask_p
    person_ner_resource.pred_ids = pred_ids
    person_ner_resource.tokenizer = tokenizer
    person_ner_resource.sess = sess
    person_ner_resource.max_seq_length = max_seq_length
Then create the configuration file nlp_config.py, used to switch among production, development, and testing environments. Its contents are as follows:
# -*- coding: utf-8 -*-

"""
Flask configuration module
"""
import os

__author__ = '程序員一一滌生'

basedir = os.path.abspath(os.path.dirname(__file__))

class BaseConfig:  # base configuration class
    SECRET_KEY = b'\xe4r\x04\xb5\xb2\x00\xf1\xadf\xa3\xf3V\x03\xc5\x9f\x82$^\xa25O\xf0R\xda'
    JSONIFY_MIMETYPE = 'application/json; charset=utf-8'  # the default JSONIFY_MIMETYPE does not include '; charset=utf-8'
    JSON_AS_ASCII = False  # if not disabled, Chinese characters in jsonify responses are escaped as Unicode sequences
    ENCODING = 'utf-8'

    # custom configuration items
    PERSON_LABELS = ["TIME", "LOCATION", "PERSON_NAME", "ORG_NAME", "COMPANY_NAME", "PRODUCT_NAME"]

class DevelopmentConfig(BaseConfig):
    ENV = 'development'
    DEBUG = True

class TestingConfig(BaseConfig):
    TESTING = True
    WTF_CSRF_ENABLED = False

class ProductionConfig(BaseConfig):
    DEBUG = False

config = {
    'testing': TestingConfig,
    'default': DevelopmentConfig
    # 'default': ProductionConfig
}
Next, create the named-entity recognition API file person_ner_resource.py with the following contents:
# -*- coding: utf-8 -*-

"""
Named entity recognition API
"""
from entity_extractor import predict

__author__ = '程序員一一滌生'

from flask import Blueprint, make_response, request, current_app
from flask import jsonify
person = Blueprint('person', __name__)

model_dir, batch_size, id2label, label_list, graph, input_ids_p, input_mask_p, pred_ids, tokenizer, sess, max_seq_length = None, None, None, None, None, None, None, None, None, None, None
@person.route('/extract', methods=['POST'])
def extract():
    params = request.get_json()
    if 't' not in params or params['t'] is None or len(params['t']) > 500 or len(params['t']) < 2:
        return make_response(jsonify({'error': 'Text length out of range; allowed length is 2 to 500 characters'}), 400)
    sentence = params['t']
    # make it a complete sentence (append a full stop if needed)
    sentence = sentence + "。" if not sentence.endswith((",", "。", "!", "?")) else sentence
    # extract entities with the model
    pred_rs, pred_label_result = predict(sentence, current_app.config['PERSON_LABELS'], model_dir, batch_size, id2label,
                                         label_list, graph, input_ids_p,
                                         input_mask_p,
                                         pred_ids, tokenizer, sess, max_seq_length)
    print(sentence)
    return jsonify(pred_rs)

if __name__ == '__main__':
    pass
Next, place a requirements.txt file in the project root with the following contents:
absl-py==0.7.0
astor==0.7.1
backcall==0.1.0
backports.weakref==1.0rc1
bleach==1.5.0
certifi==2016.2.28
click==6.7
colorama==0.4.1
colorful==0.5.0
decorator==4.3.2
defusedxml==0.5.0
entrypoints==0.3
Flask==1.0.2
Flask-Cors==3.0.3
gast==0.2.2
grpcio==1.18.0
h5py==2.9.0
html5lib==0.9999999
ipykernel==5.1.0
ipython==7.2.0
ipython-genutils==0.2.0
ipywidgets==7.4.2
itsdangerous==0.24
jedi==0.13.2
Jinja2==2.10
jsonschema==2.6.0
jupyter==1.0.0
jupyter-client==5.2.4
jupyter-console==6.0.0
jupyter-core==4.4.0
Keras-Applications==1.0.6
Keras-Preprocessing==1.0.5
Markdown==3.0.1
MarkupSafe==1.1.0
mistune==0.8.4
mock==3.0.5
nbconvert==5.4.0
nbformat==4.4.0
notebook==5.7.4
numpy==1.16.0
pandocfilters==1.4.2
parso==0.3.2
pickleshare==0.7.5
prettyprinter==0.17.0
prometheus-client==0.5.0
prompt-toolkit==2.0.8
protobuf==3.6.1
Pygments==2.3.1
python-dateutil==2.7.5
pywinpty==0.5.5
pyzmq==17.1.2
qtconsole==4.4.3
Send2Trash==1.5.0
six==1.12.0
tensorboard==1.13.1
tensorflow==1.13.1
tensorflow-estimator==1.13.0
termcolor==1.1.0
terminado==0.8.1
testpath==0.4.2
tornado==5.1.1
traitlets==4.3.2
wcwidth==0.1.7
Werkzeug==0.14.1
widgetsnbextension==3.4.2
wincertstore==0.2
Then run the following command to install the packages listed in requirements.txt:
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple -r requirements.txt

Once the above steps are complete, we can try to start the project.

Starting the project

Run the following command to start the Flask project:

python nlp_main.py

Calling the API

This article uses Postman to call the named-entity extraction API. The endpoint is:

http://localhost:5006/person/extract

What the call looks like (the original article shows a Postman screenshot of the request and the returned entities here):
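If you prefer a script over Postman, a minimal client sketch using the requests library looks like this (the sentence and the returned entities are illustrative only; the actual output depends on your trained model):

# -*- coding: utf-8 -*-
# Minimal client sketch for the /person/extract endpoint.
import requests

resp = requests.post('http://localhost:5006/person/extract',
                     json={'t': '小明昨天在北京參觀了故宮。'})
print(resp.status_code)   # 200 on success, 400 if the text length check fails
print(resp.json())
# The response is a JSON object keyed by label, for example:
# {"PERSON_NAME": ["小明"], "TIME": ["昨天"], "LOCATION": ["北京", "故宮"],
#  "ORG_NAME": [], "COMPANY_NAME": [], "PRODUCT_NAME": []}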

Note that a single prediction takes roughly 2 to 3 seconds on a CPU. If the project is deployed on a machine with a GPU suitable for deep learning, the API will respond much faster; just remember to install tensorflow-gpu instead of tensorflow, as shown below.
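For example, the swap might look like this (1.13.1 matches the tensorflow version pinned in requirements.txt above; pick the build that matches your CUDA setup):

pip uninstall tensorflow
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple tensorflow-gpu==1.13.1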

That's all for this article. At this point we have used deep learning to build a project that can extract person names, locations, organizations, companies, products, and time expressions from natural language. Starting with the next article, we will introduce BERT-CRF, the deep learning algorithm used in this project; understanding the algorithm will make it much clearer why the model can accurately extract the entities we want from a sentence.

OK, that's it for now. Thanks for reading O(∩_∩)O, bye~

This post comes from the WeChat official account 「程序員一一滌生」; you are welcome to follow it o(∩_∩)o
