Preparation: set up the Python environment and download a BERT language model
- Python 3.6 environment
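- Install Kashgari, choosing the package version that matches your backend (the code in this post uses the 1.x API, i.e. the TensorFlow 1.14+ row):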
Backend | pypi version | Description |
---|---|---|
TensorFlow 2.x | `pip install 'kashgari>=2.0.0'` | coming soon |
TensorFlow 1.14+ | `pip install 'kashgari>=1.0.0,<2.0.0'` | current version |
Keras | `pip install 'kashgari<1.0.0'` | legacy version |
- BERT Chinese model
I chose the BERT-wwm-ext model from HIT (Harbin Institute of Technology).
Many thanks to its authors.
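`BERTEmbedding` is later given the path of the unzipped model folder, so it can help to check that the download is complete before training. The helper below is a hypothetical sketch, not part of Kashgari; the expected file names are an assumption based on the standard TensorFlow BERT release layout.

```python
import os

# Files commonly found in a TensorFlow BERT checkpoint folder (assumption based on
# the standard Google BERT release layout; adjust if your download differs)
EXPECTED_FILES = ["bert_config.json", "vocab.txt", "bert_model.ckpt.index"]


def missing_bert_files(model_dir):
    """Hypothetical helper: list expected files that are absent from model_dir."""
    return [name for name in EXPECTED_FILES
            if not os.path.isfile(os.path.join(model_dir, name))]
```

For example, `missing_bert_files('chinese_wwm_ext_L-12_H-768_A-12')` should return an empty list when the model has been unpacked into that folder.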
Dataset preparation
```python
from kashgari.corpus import ChineseDailyNerCorpus

train_x, train_y = ChineseDailyNerCorpus.load_data('train')
valid_x, valid_y = ChineseDailyNerCorpus.load_data('validate')
test_x, test_y = ChineseDailyNerCorpus.load_data('test')

print(f"train data count: {len(train_x)}")
print(f"validate data count: {len(valid_x)}")
print(f"test data count: {len(test_x)}")
```

Output:

```
train data count: 20864
validate data count: 2318
test data count: 4636
```
The data is the annotated People's Daily corpus, in BIO format with one character and its tag per line:
```
海 O
釣 O
比 O
賽 O
地 O
點 O
在 O
廈 B-LOC
門 I-LOC
與 O
金 B-LOC
門 I-LOC
之 O
間 O
的 O
海 O
域 O
。 O
```
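`ChineseDailyNerCorpus.load_data` already returns this data as parallel lists: each entry of `train_x` is a list of characters and the matching entry of `train_y` is its list of tags. If you want to load your own file in the same format, a minimal sketch could look like this. It assumes whitespace-separated "character tag" pairs with a blank line between sentences (an assumption about the raw file layout), and `parse_char_tag_lines` is a hypothetical name.

```python
def parse_char_tag_lines(lines):
    """Parse 'char tag' lines into (sentences, tag_sequences); blank lines end a sentence."""
    sentences, tag_seqs = [], []
    chars, tags = [], []
    for line in lines:
        parts = line.split()
        if not parts:
            # Blank line: close the current sentence, if any
            if chars:
                sentences.append(chars)
                tag_seqs.append(tags)
                chars, tags = [], []
            continue
        chars.append(parts[0])   # the character
        tags.append(parts[-1])   # its BIO tag
    if chars:  # last sentence without a trailing blank line
        sentences.append(chars)
        tag_seqs.append(tags)
    return sentences, tag_seqs


sample = """海 O
釣 O
在 O
廈 B-LOC
門 I-LOC

金 B-LOC
門 I-LOC
"""
x, y = parse_char_tag_lines(sample.splitlines())
print(x)  # [['海', '釣', '在', '廈', '門'], ['金', '門']]
print(y)  # [['O', 'O', 'O', 'B-LOC', 'I-LOC'], ['B-LOC', 'I-LOC']]
```

The output has the same shape as `train_x` / `train_y`, so it could be passed to `model.fit` in the same way.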
Create the BERT embedding
```python
import kashgari
from kashgari.embeddings import BERTEmbedding

# Path to the unzipped BERT-wwm-ext model folder
bert_embed = BERTEmbedding('chinese_wwm_ext_L-12_H-768_A-12',
                           task=kashgari.LABELING,
                           sequence_length=100)
```
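`sequence_length=100` fixes the length each sentence is padded or truncated to. One rough way to judge whether 100 is enough is to look at the sentence-length distribution of the training data. The two helpers below are a hypothetical sketch that works on `train_x`-style lists of character lists; they are not Kashgari functions.

```python
import math


def length_coverage(sequences, seq_len):
    """Fraction of sequences with len(seq) <= seq_len, i.e. not truncated."""
    if not sequences:
        return 0.0
    kept = sum(1 for seq in sequences if len(seq) <= seq_len)
    return kept / len(sequences)


def percentile_length(sequences, pct):
    """Nearest-rank percentile of sequence lengths (pct in 0..100)."""
    lengths = sorted(len(seq) for seq in sequences)
    if not lengths:
        return 0
    rank = max(1, math.ceil(pct / 100 * len(lengths)))
    return lengths[rank - 1]


# Toy data in the same shape as train_x: a list of character lists
toy = [list("海釣比賽"), list("地點在廈門與金門之間的海域"), list("你好")]
print(length_coverage(toy, 10))    # 2 of the 3 sentences fit in 10 characters
print(percentile_length(toy, 95))  # 13
```

On the real data you would call `length_coverage(train_x, 100)`; a value close to 1.0 suggests few sentences are being truncated.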
Build and train the model
```python
from kashgari.tasks.labeling import BiLSTM_CRF_Model

# You can also choose `CNN_LSTM_Model`, `BiLSTM_Model`, `BiGRU_Model` or `BiGRU_CRF_Model`
model = BiLSTM_CRF_Model(bert_embed)
model.fit(train_x, train_y,
          x_validate=valid_x, y_validate=valid_y,
          epochs=20, batch_size=512)
model.save('ner.h5')
```
Model evaluation

```python
model.evaluate(test_x, test_y)
```
Of the models I tried, BERT + BiLSTM-CRF performed best. Detailed scores on the test set:
Entity type | precision | recall | f1-score |
---|---|---|---|
LOC | 0.9208 | 0.9324 | 0.9266 |
ORG | 0.8728 | 0.8882 | 0.8804 |
PER | 0.9622 | 0.9633 | 0.9627 |
avg / total | 0.9169 | 0.9271 | 0.9220 |
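These are entity-level scores: a predicted entity only counts as correct when both its span and its type match the gold annotation. Kashgari 1.x appears to produce this report via seqeval (treat that as an assumption). The sketch below recomputes micro-averaged precision, recall and F1 from BIO tags in plain Python to show how such numbers arise; it uses strict BIO decoding, so it may differ from seqeval on malformed tag sequences, and the function names are hypothetical. Per-type rows such as LOC would restrict the counts to one entity type.

```python
def bio_entities(tags):
    """Collect (type, start, end) spans from a BIO tag sequence; end is exclusive."""
    entities = set()
    ent_type, start = None, None
    # A sentinel 'O' at the end closes an entity that runs to the last token
    for i, tag in enumerate(list(tags) + ['O']):
        if tag.startswith('I-') and tag[2:] == ent_type:
            continue  # the current entity continues
        if ent_type is not None:
            entities.add((ent_type, start, i))
            ent_type, start = None, None
        if tag.startswith('B-'):
            ent_type, start = tag[2:], i
        # A stray I- tag without a matching B- is ignored (strict BIO)
    return entities


def entity_prf(y_true, y_pred):
    """Micro-averaged entity-level precision, recall and F1 over all sentences."""
    tp = n_pred = n_true = 0
    for true_tags, pred_tags in zip(y_true, y_pred):
        gold = bio_entities(true_tags)
        pred = bio_entities(pred_tags)
        tp += len(gold & pred)  # exact span and type match
        n_pred += len(pred)
        n_true += len(gold)
    precision = tp / n_pred if n_pred else 0.0
    recall = tp / n_true if n_true else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


y_true = [['B-LOC', 'I-LOC', 'O', 'B-PER']]
y_pred = [['B-LOC', 'I-LOC', 'O', 'O']]
print(entity_prf(y_true, y_pred))  # precision 1.0, recall 0.5, F1 about 0.667
```

Missing the PER entity lowers recall but not precision, which is why the two numbers can diverge as they do for ORG in the table.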
Using the model
```python
# -*- coding: utf-8 -*-
import kashgari

# Load the model saved in the training step
loaded_model = kashgari.utils.load_model('ner.h5')


def cut_text(text, length):
    """Split text into chunks of at most `length` characters (no empty trailing chunk)."""
    return [text[i:i + length] for i in range(0, len(text), length)]


def extract_labels(text, ners):
    """Group predicted BIO tags into entities, e.g. {'LOC': [...], 'PER': [...]}."""
    # Flatten the per-chunk predictions back into one tag sequence aligned with text
    tags = [tag for chunk in ners for tag in chunk]

    labels = {}
    current_type, current_chars = None, []
    for char, tag in zip(text, tags):
        if tag.startswith('B-'):
            # A new entity starts; close the previous one first
            if current_type:
                labels.setdefault(current_type, []).append(''.join(current_chars))
            current_type, current_chars = tag[2:], [char]
        elif tag.startswith('I-') and tag[2:] == current_type:
            # Continuation of the current entity
            current_chars.append(char)
        else:
            # 'O' (or an I- tag of another type) ends the current entity
            if current_type:
                labels.setdefault(current_type, []).append(''.join(current_chars))
            current_type, current_chars = None, []
    if current_type:
        labels.setdefault(current_type, []).append(''.join(current_chars))
    return labels


while True:
    text_input = input('sentence: ')
    if not text_input:  # empty input exits the loop
        break
    # The embedding was built with sequence_length=100, so predict in 100-char chunks
    texts = cut_text(text_input, 100)
    ners = loaded_model.predict([list(text) for text in texts])
    print(ners)
    print(extract_labels(text_input, ners))
```
References
- Chinese-BERT-wwm: https://github.com/ymcui/Chinese-BERT-wwm
- Kashgari: https://github.com/BrikerMan/Kashgari