使用BERT預訓練模型+微調進行文本分類

 

本文記錄使用BERT預訓練模型,修改最頂層softmax層,微調幾個epoch,進行文本分類任務。html

BERT源碼

首先BERT源碼來自谷歌官方tensorflow版:https://github.com/google-research/bertpython

注意,這是tensorflow 1.x 版本的。git

BERT預訓練模型

預訓練模型採用哈工大訊飛聯合實驗室推出的WWM(Whole Word Masking)全詞覆蓋預訓練模型,主要考量是BERT對於中文模型來講,是按照字符進行切割,可是注意到BERT隨機mask掉15%的詞,這裏是徹底隨機的,對於中文來講,頗有可能一個詞的某些字被mask掉了,好比說讓我預測這樣一句話:github

原話: 」我今天早上去打羽毛球了,而後又去蒸了桑拿,感受身心愉悅「json

MASK:」我[MASK]天早上去打[MASK]毛球了,而後[MASK]去蒸了[MASK]拿,感受身心[MASK]悅「centos

雖說從統計學意義上來說這樣作依然能夠學得其特徵,但這樣實際上破壞了中文特有的詞結構,那麼全詞覆蓋主要就是針對這個問題,提出一種機制保證在MASK的時候要麼整個詞都不MASK,要麼MASK掉整個詞。服務器

WWM MASK:」我今天早上去打[MASK][MASK][MASK]了,而後又去蒸了[MASK][MASK],感受身心愉悅「app

例子可能舉得不是很恰當,但大概是這個意思,能夠參考這篇文章:函數

https://www.jiqizhixin.com/articles/2019-06-21-01學習

修改源碼

首先看到下下來的項目結構:

能夠看到run_classifier.py文件,這個是咱們須要用的。另外,chinese開頭的文件是咱們的模型地址,data文件是咱們的數據地址,這個每一個人能夠本身設置。

在run_classifier.py文件中,有一個基類DataProcessor類,這個是咱們須要繼承並重寫的:

class DataProcessor(object):
  """Base class for data converters for sequence classification data sets."""

  def get_train_examples(self, data_dir):
    """Gets a collection of `InputExample`s for the train set."""
    raise NotImplementedError()

  def get_dev_examples(self, data_dir):
    """Gets a collection of `InputExample`s for the dev set."""
    raise NotImplementedError()

  def get_test_examples(self, data_dir):
    """Gets a collection of `InputExample`s for prediction."""
    raise NotImplementedError()

  def get_labels(self):
    """Gets the list of labels for this data set."""
    raise NotImplementedError()

  @classmethod
  def _read_tsv(cls, input_file, quotechar=None):
    """Reads a tab separated value file."""
    with tf.gfile.Open(input_file, "r") as f:
      reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
      lines = []
      for line in reader:
        lines.append(line)
      return lines

  

能夠看到咱們須要實現得到訓練、驗證、測試數據接口,以及得到標籤的接口。

這裏我本身用的一個類。註釋比較詳細,就不解釋了,主要體現了只要能得到數據,不論咱們的文件格式是什麼樣的,均可以,因此不須要專門爲了這個項目去改本身的輸入數據格式。

class StatutesProcessor(DataProcessor):

    def _read_txt_(self, data_dir, x_file_name, y_file_name):
        # 定義咱們的讀取方式,個人工程中已經將x文本和y文本分別存入txt文件中,沒有分隔符
        # 用gfile讀取,打開一個沒有線程鎖的的文件IO Wrapper
        # 基本上和python原生的open是同樣的,只是在某些方面更高效一點
        with tf.gfile.Open(data_dir + x_file_name, 'r') as f:
            lines_x = [x.strip() for x in f.readlines()]
        with tf.gfile.Open(data_dir + y_file_name, 'r') as f:
            lines_y = [x.strip() for x in f.readlines()]
        return lines_x, lines_y

    def get_train_examples(self, data_dir):
        lines_x, lines_y = self._read_txt_(data_dir, 'train_x.txt', 'train_y.txt')
        examples = []
        for (i, line) in enumerate(zip(lines_x, lines_y)):
            guid = 'train-%d' % i
            # 規範輸入編碼
            text_a = tokenization.convert_to_unicode(line[0])
            label = tokenization.convert_to_unicode(line[1])
            # 這裏有一些特殊的任務,通常任務直接用上面的就行,下面的label操做能夠註釋掉
            # 這裏由於y會有多個標籤,這裏按單標籤來作
            label = label.strip().split()[0]

            # 這裏不作匹配任務,text_b爲None
            examples.append(
                InputExample(guid=guid, text_a=text_a, label=label)
            )
        return examples

    def get_dev_examples(self, data_dir):
        lines_x, lines_y = self._read_txt_(data_dir, 'val_x.txt', 'val_y.txt')
        examples = []
        for (i, line) in enumerate(zip(lines_x, lines_y)):
            guid = 'train-%d' % i
            # 規範輸入編碼
            text_a = tokenization.convert_to_unicode(line[0])
            label = tokenization.convert_to_unicode(line[1])
            label = label.strip().split()[0]

            # 這裏不作匹配任務,text_b爲None
            examples.append(
                InputExample(guid=guid, text_a=text_a, label=label)
            )
        return examples

    def get_test_examples(self, data_dir):
        lines_x, lines_y = self._read_txt_(data_dir, 'test_x.txt', 'test_y.txt')
        examples = []
        for (i, line) in enumerate(zip(lines_x, lines_y)):
            guid = 'train-%d' % i
            # 規範輸入編碼
            text_a = tokenization.convert_to_unicode(line[0])
            label = tokenization.convert_to_unicode(line[1])
            label = label.strip().split()[0]

            # 這裏不作匹配任務,text_b爲None
            examples.append(
                InputExample(guid=guid, text_a=text_a, label=label)
            )
        return examples

    def get_labels(self):
        # 我事先統計了全部出現的y值,放在了vocab_y.txt裏
        # 由於這裏沒有原生的接口,這裏暫時這麼作了,只要保證能讀到全部的類別就好了
        with tf.gfile.Open('data/statutes_small/vocab_y.txt', 'r') as f:
            vocab_y = [x.strip() for x in f.readlines()]

        return vocab_y

  

寫好了以後須要更新一下processors列表,在main函數中,最下面一條就是我新加的。

 

執行訓練微調

python run_classifier.py --data_dir=data/statutes_small/ --task_name=cail2018 --vocab_file=chinese_wwm_ext_L-12_H-768_A-12/vocab.txt --bert_config_file=chinese_wwm_ext_L-12_H-768_A-12/bert_config.json --output_dir=output/ --do_train=true --do_eval=true --init_checkpoint=chinese_wwm_ext_L-12_H-768_A-12/bert_model.ckpt --max_seq_length=200 --train_batch_size=16 --learning_rate=5e-5 --num_train_epoch=3

相信我,寫在一行,這個會有不少小問題,在centos服務器上若是不能按上返回上一條命令,將會很痛苦。。具體參數含義就和參數名是一致的,不須要解釋。

另外,能夠稍稍修改一些東西來動態輸入訓練集上的loss,由於BERT源碼封裝的過高了,因此只能按照這篇文章:http://www.javashuo.com/article/p-oiwzmecc-hq.html裏面講的方法,每100個step輸出一次train loss(就是100個batch),這樣作雖然意義不大,可是能夠看在你的數據集上模型是否是在收斂,方便調整學習率。

在測試集上進行測試

默認test_batch_size = 8

python run_classifier.py --data_dir=data/statutes_small/ --task_name=cail2018 --vocab_file=chinese_wwm_ext_L-12_H-768_A-12/vocab.txt --bert_config_file=chinese_wwm_ext_L-12_H-768_A-12/bert_config.json --output_dir=output/ --do_predict=true --max_seq_length=200

須要注意的是,調用測試接口會在out路徑中生成一個test_results.tsv,這是一個以’\t’爲分隔符的文件,記錄了每一條輸入測試樣例,輸出的每個維度的值(維度數就是類別數目),須要手動作一點操做來獲得最終分類結果,以及計算指標等等。

# 計算測試結果
# 由於原生的predict生成一個test_results.tsv文件,給出了每個sample的每個維度的值
# 卻並無給出具體的類別預測以及指標,這裏再對這個「中間結果手動轉化一下」


def cal_accuracy(rst_file_dir, y_test_dir):
    rst_contents = pd.read_csv(rst_file_dir, sep='\t', header=None)
    # value_list: ndarray
    value_list = rst_contents.values
    pred = value_list.argmax(axis=1)
    labels = []

    # 這一步是獲取y標籤到id,id到標籤的對應dict,每一個人獲取的方式應該不一致
    y2id, id2y = get_y_to_id(vocab_y_dir='../data/statutes_small/vocab_y.txt')
    with open(y_test_dir, 'r', encoding='utf-8') as f:
        line = f.readline()
        while line:
            # 這裏由於y有多個標籤,我要取第一個標籤,因此要單獨作操做
            label = line.strip().split()[0]
            labels.append(y2id[label])
            line = f.readline()
    labels = np.asarray(labels)

    # 預測,pred,真實標籤,labels
    accuracy = metrics.accuracy_score(y_true=labels, y_pred=pred)
    # 這裏只舉例了accuracy,其餘的指標也相似計算
    print(accuracy)


def get_y_to_id(vocab_y_dir):
    # 這裏把全部的y標籤值存在了文件中
    y_vocab = open(vocab_y_dir, 'r', encoding='utf-8').read().splitlines()
    y2idx = {token: idx for idx, token in enumerate(y_vocab)}
    idx2y = {idx: token for idx, token in enumerate(y_vocab)}
    return y2idx, idx2y

  

這部分代碼在classifier/cal_test_matrix.py中。

個人代碼地址:

點擊這裏

 

參考:

https://github.com/google-research/bert

https://www.cnblogs.com/jiangxinyang/p/10241243.html

https://www.jiqizhixin.com/articles/2019-06-21-01

https://arxiv.org/abs/1906.08101

相關文章
相關標籤/搜索