本文記錄使用BERT預訓練模型,修改最頂層softmax層,微調幾個epoch,進行文本分類任務。html
BERT源碼
首先BERT源碼來自谷歌官方tensorflow版:https://github.com/google-research/bertpython
注意,這是tensorflow 1.x 版本的。git
BERT預訓練模型
預訓練模型採用哈工大訊飛聯合實驗室推出的WWM(Whole Word Masking)全詞覆蓋預訓練模型,主要考量是BERT對於中文模型來講,是按照字符進行切割,可是注意到BERT隨機mask掉15%的詞,這裏是徹底隨機的,對於中文來講,頗有可能一個詞的某些字被mask掉了,好比說讓我預測這樣一句話:github
原話: 」我今天早上去打羽毛球了,而後又去蒸了桑拿,感受身心愉悅「json
MASK:」我[MASK]天早上去打[MASK]毛球了,而後[MASK]去蒸了[MASK]拿,感受身心[MASK]悅「centos
雖說從統計學意義上來說這樣作依然能夠學得其特徵,但這樣實際上破壞了中文特有的詞結構,那麼全詞覆蓋主要就是針對這個問題,提出一種機制保證在MASK的時候要麼整個詞都不MASK,要麼MASK掉整個詞。服務器
WWM MASK:」我今天早上去打[MASK][MASK][MASK]了,而後又去蒸了[MASK][MASK],感受身心愉悅「app
例子可能舉得不是很恰當,但大概是這個意思,能夠參考這篇文章:函數
https://www.jiqizhixin.com/articles/2019-06-21-01學習
修改源碼
首先看到下下來的項目結構:
能夠看到run_classifier.py文件,這個是咱們須要用的。另外,chinese開頭的文件是咱們的模型地址,data文件是咱們的數據地址,這個每一個人能夠本身設置。
在run_classifier.py文件中,有一個基類DataProcessor類,這個是咱們須要繼承並重寫的:
class DataProcessor(object): """Base class for data converters for sequence classification data sets.""" def get_train_examples(self, data_dir): """Gets a collection of `InputExample`s for the train set.""" raise NotImplementedError() def get_dev_examples(self, data_dir): """Gets a collection of `InputExample`s for the dev set.""" raise NotImplementedError() def get_test_examples(self, data_dir): """Gets a collection of `InputExample`s for prediction.""" raise NotImplementedError() def get_labels(self): """Gets the list of labels for this data set.""" raise NotImplementedError() @classmethod def _read_tsv(cls, input_file, quotechar=None): """Reads a tab separated value file.""" with tf.gfile.Open(input_file, "r") as f: reader = csv.reader(f, delimiter="\t", quotechar=quotechar) lines = [] for line in reader: lines.append(line) return lines
能夠看到咱們須要實現得到訓練、驗證、測試數據接口,以及得到標籤的接口。
這裏我本身用的一個類。註釋比較詳細,就不解釋了,主要體現了只要能得到數據,不論咱們的文件格式是什麼樣的,均可以,因此不須要專門爲了這個項目去改本身的輸入數據格式。
class StatutesProcessor(DataProcessor): def _read_txt_(self, data_dir, x_file_name, y_file_name): # 定義咱們的讀取方式,個人工程中已經將x文本和y文本分別存入txt文件中,沒有分隔符 # 用gfile讀取,打開一個沒有線程鎖的的文件IO Wrapper # 基本上和python原生的open是同樣的,只是在某些方面更高效一點 with tf.gfile.Open(data_dir + x_file_name, 'r') as f: lines_x = [x.strip() for x in f.readlines()] with tf.gfile.Open(data_dir + y_file_name, 'r') as f: lines_y = [x.strip() for x in f.readlines()] return lines_x, lines_y def get_train_examples(self, data_dir): lines_x, lines_y = self._read_txt_(data_dir, 'train_x.txt', 'train_y.txt') examples = [] for (i, line) in enumerate(zip(lines_x, lines_y)): guid = 'train-%d' % i # 規範輸入編碼 text_a = tokenization.convert_to_unicode(line[0]) label = tokenization.convert_to_unicode(line[1]) # 這裏有一些特殊的任務,通常任務直接用上面的就行,下面的label操做能夠註釋掉 # 這裏由於y會有多個標籤,這裏按單標籤來作 label = label.strip().split()[0] # 這裏不作匹配任務,text_b爲None examples.append( InputExample(guid=guid, text_a=text_a, label=label) ) return examples def get_dev_examples(self, data_dir): lines_x, lines_y = self._read_txt_(data_dir, 'val_x.txt', 'val_y.txt') examples = [] for (i, line) in enumerate(zip(lines_x, lines_y)): guid = 'train-%d' % i # 規範輸入編碼 text_a = tokenization.convert_to_unicode(line[0]) label = tokenization.convert_to_unicode(line[1]) label = label.strip().split()[0] # 這裏不作匹配任務,text_b爲None examples.append( InputExample(guid=guid, text_a=text_a, label=label) ) return examples def get_test_examples(self, data_dir): lines_x, lines_y = self._read_txt_(data_dir, 'test_x.txt', 'test_y.txt') examples = [] for (i, line) in enumerate(zip(lines_x, lines_y)): guid = 'train-%d' % i # 規範輸入編碼 text_a = tokenization.convert_to_unicode(line[0]) label = tokenization.convert_to_unicode(line[1]) label = label.strip().split()[0] # 這裏不作匹配任務,text_b爲None examples.append( InputExample(guid=guid, text_a=text_a, label=label) ) return examples def get_labels(self): # 我事先統計了全部出現的y值,放在了vocab_y.txt裏 # 由於這裏沒有原生的接口,這裏暫時這麼作了,只要保證能讀到全部的類別就好了 with tf.gfile.Open('data/statutes_small/vocab_y.txt', 'r') as f: vocab_y = [x.strip() for x in f.readlines()] return vocab_y
寫好了以後須要更新一下processors列表,在main函數中,最下面一條就是我新加的。
執行訓練微調
python run_classifier.py --data_dir=data/statutes_small/ --task_name=cail2018 --vocab_file=chinese_wwm_ext_L-12_H-768_A-12/vocab.txt --bert_config_file=chinese_wwm_ext_L-12_H-768_A-12/bert_config.json --output_dir=output/ --do_train=true --do_eval=true --init_checkpoint=chinese_wwm_ext_L-12_H-768_A-12/bert_model.ckpt --max_seq_length=200 --train_batch_size=16 --learning_rate=5e-5 --num_train_epoch=3
相信我,寫在一行,這個會有不少小問題,在centos服務器上若是不能按上返回上一條命令,將會很痛苦。。具體參數含義就和參數名是一致的,不須要解釋。
另外,能夠稍稍修改一些東西來動態輸入訓練集上的loss,由於BERT源碼封裝的過高了,因此只能按照這篇文章:http://www.javashuo.com/article/p-oiwzmecc-hq.html裏面講的方法,每100個step輸出一次train loss(就是100個batch),這樣作雖然意義不大,可是能夠看在你的數據集上模型是否是在收斂,方便調整學習率。
在測試集上進行測試
默認test_batch_size = 8
python run_classifier.py --data_dir=data/statutes_small/ --task_name=cail2018 --vocab_file=chinese_wwm_ext_L-12_H-768_A-12/vocab.txt --bert_config_file=chinese_wwm_ext_L-12_H-768_A-12/bert_config.json --output_dir=output/ --do_predict=true --max_seq_length=200
須要注意的是,調用測試接口會在out路徑中生成一個test_results.tsv,這是一個以’\t’爲分隔符的文件,記錄了每一條輸入測試樣例,輸出的每個維度的值(維度數就是類別數目),須要手動作一點操做來獲得最終分類結果,以及計算指標等等。
# 計算測試結果 # 由於原生的predict生成一個test_results.tsv文件,給出了每個sample的每個維度的值 # 卻並無給出具體的類別預測以及指標,這裏再對這個「中間結果手動轉化一下」 def cal_accuracy(rst_file_dir, y_test_dir): rst_contents = pd.read_csv(rst_file_dir, sep='\t', header=None) # value_list: ndarray value_list = rst_contents.values pred = value_list.argmax(axis=1) labels = [] # 這一步是獲取y標籤到id,id到標籤的對應dict,每一個人獲取的方式應該不一致 y2id, id2y = get_y_to_id(vocab_y_dir='../data/statutes_small/vocab_y.txt') with open(y_test_dir, 'r', encoding='utf-8') as f: line = f.readline() while line: # 這裏由於y有多個標籤,我要取第一個標籤,因此要單獨作操做 label = line.strip().split()[0] labels.append(y2id[label]) line = f.readline() labels = np.asarray(labels) # 預測,pred,真實標籤,labels accuracy = metrics.accuracy_score(y_true=labels, y_pred=pred) # 這裏只舉例了accuracy,其餘的指標也相似計算 print(accuracy) def get_y_to_id(vocab_y_dir): # 這裏把全部的y標籤值存在了文件中 y_vocab = open(vocab_y_dir, 'r', encoding='utf-8').read().splitlines() y2idx = {token: idx for idx, token in enumerate(y_vocab)} idx2y = {idx: token for idx, token in enumerate(y_vocab)} return y2idx, idx2y
這部分代碼在classifier/cal_test_matrix.py中。
個人代碼地址:
參考:
https://github.com/google-research/bert
https://www.cnblogs.com/jiangxinyang/p/10241243.html
https://www.jiqizhixin.com/articles/2019-06-21-01
https://arxiv.org/abs/1906.08101