基於BERT的超長文本分類模型

基於BERT的超長文本分類模型

 

0.Abstract

本文實現了一個基於BERT+LSTM超長文本分類的模型, 評估方法使用準確率和F1 Score.
項目代碼github地址: https://github.com/neesetifa/bert_classification
git

1.任務介紹

用BERT作文本分類是一個比較常見的項目.
可是衆所周知BERT對於文本輸入長度有限制. 對於超長文本的處理, 最簡單暴力無腦高效的辦法是直接截斷, 就取開頭這部分送入BERT. 可是也請別看不起這種作法, 每每最簡單,最Naive的方法效果反而比一頓操做猛如虎 複雜模型來得好.
github

這裏多提一句爲何. 一般長文本的文章結構都比較明確, 文章前面一兩段基本都是對於後面的概述. 因此等於做者已經幫你提取了文章大意, 因此直接取前面一部分理論上來講是有意義的.
固然也有最新研究代表取文章中間部分效果也很不錯. 在此不展開.
網絡

本文實現的是一種基於HIERARCHICAL(級聯)思想的作法, 把文本切成多片處理. 該方法來自於這篇論文 <Hierarchical Transformers for Long Document Classification>.
文中提到這麼作還能下降self-attention計算的時間複雜度.
假設原句子長爲n, 每一個分段的長度是k. 咱們知道最原始的BERT計算時間複雜度是O(n2), 做者認爲,這麼作能夠把時間複雜度下降到O(nk). 由於咱們把n分數據分割成k小份, 那麼咱們一共要作n/k次, 每次咱們的時間複雜度是k2, 即O(n/k * k2) = O(nk)

app

數據集函數

此次咱們測試該模型在兩種語言上的效果. 分別是中文數據集和英語數據集.
中文數據集依舊是咱們的老朋友ChineseNLPCorps提供的不一樣類別商品的評論.
中文數據集傳送門
英語數據集來源於Kaggle比賽, 用戶對於不一樣金融產品的評論.
英語數據集傳送門
因爲兩種數據集訓練預測上沒有什麼本質區別, 下文會用英語數據集來演示.




學習

評估方法測試

本項目使用的評估方法是準確率和F1 Score. 很是常見的分類問題評價標準.優化

測試集google

此項目中直接取了數據集裏一小部分做爲測試集.spa

2.數據初步處理

數據集裏有55W條數據,18個features.
在這裏插入圖片描述
咱們須要的部分是product(即商品類別)以及consumer complaint narrative.
在這裏插入圖片描述
觀察數據集,咱們發現用戶評論是有NaN值的. 並且本次實驗目的是作超長文本分類. 咱們選取非NaN值,而且是長度大於250的評論.



在這裏插入圖片描述
篩選完後咱們保留大約17k條左右數據
在這裏插入圖片描述

3.Baseline模型

咱們先來看一下什麼都不作, 直接用BERT進行finetune能達到什麼樣的效果. 咱們以此做爲實驗的baseline.
本次預訓練模型使用google官方的BERT-base-cased英語預訓練模型(固然用uncased應該也不要緊, 我沒有測試)
fine-tune部分很簡單, 直接提取[CLS] token後過線性層, 是比較常規的套路. 損失函數使用cross entropy loss.
文本送入的最大長度定爲250. 即前文裏提到的"直接截取文本前面部分". 這次實驗裏咱們嘗試比較HIERARCHICAL方法能比直接截取提升多少.
在這裏插入圖片描述
如圖, 準確率達到了88%. 訓練數據不過10k的數量級, 對於深度學習來講是很是少的. 這裏不得不感嘆下BERT做爲預訓練模型在小樣本數據上的實力很是強勁.




4. 數據進一步處理

接下來咱們進入提升部分. 首先對數據進一步處理.

分割文本

HIERARCHICAL思想本質是對數據進行有重疊(overlap)的分割. 這樣分割後的每句句子之間仍然保留了必定的關聯信息.

衆所周知,BERT輸入的最大長度限制爲512, 其中還須要包括[CLS]和[SEP]. 那麼實際可用的長度僅爲510. 可是別忘了, 每一個單詞tokenizer以後也有可能被分紅好幾部分. 因此實際可輸入的句子長度遠不足510.
本次實驗裏咱們設置分割的長度爲200, overlap長度爲50. 若是實際上線生產確有大量超過500長度的文本, 只需將分割和overlap長度設置更長便可.

def get_split_text(text, split_len=250, overlap_len=50):
	split_text=[]
	for w in range(len(text)//split_len):
  		if w == 0:   #第一次,直接分割長度放進去
    		text_piece = text[:split_len]
  		else:      # 不然, 按照(分割長度-overlap)日後走
  			window = split_len - overlap_len
    		text_piece = [w * window: w * window + split_len]
		split_text.append(text_piece)
	return split_text

分割完後長這樣
在這裏插入圖片描述
隨後咱們將這些分割的句子分離成單獨的一條數據. 併爲他們加上label.
在這裏插入圖片描述
對比原文本能夠發現, index 1~ index4來源於同一句句子. 它被分割成了4份而且每份都擁有原文本的label.



4.最終模型

最終模型由兩個部分構成, 第一部分是和baseline裏如出一轍的, fine-tune後的BERT. 第二部分是由LSTM+FC層組成的混合模型.
即實際上, BERT只是用來提取出句子的表示, 而真正在作分類的是LSTM + FC部分(更準確來講是FC部分, 由於LSTM模型部分仍然在作進一步的特徵提取工做)
這裏稍微提一句,這樣作法我我的認爲相似於廣告推薦系統裏GBDT+LR的組合. 採用一個稍微複雜的模型去作特徵提取, 而後用一個相對簡單的模型去預測.

第一部分: BERT

首先,咱們把分割好後的文本送入BERT進行訓練. 這邊我跑了5個epoch, 顯卡仍然是Tesla K80, 每一個epoch大約須要23分鐘左右.
在這裏插入圖片描述
接着, 咱們提取出這些文本的句子表示.
方便起見, 咱們這裏仍然用[CLS] token做爲句子表示. 固然也能夠用sequence_output(在我上一個項目FAQ問答的最後結論中, 使用sequence_output的確能比pooled_output效果更好一點)
咱們得到的是這樣一組數據:



句子1_a的embedding, label
句子1_b的embedding, label
句子1_c的embedding, label
句子2_a的embedding, label
句子2_b的embedding, label
句子3_a的embedding, label






隨後咱們把這些embedding拼回起來, 變成了

[句子1_a的embedding,句子1_b的embedding, 句子1_c的embedding], label
[句子2_a的embedding, 句子2_b的embedding], label
[句子3_a的embedding, 句子3_b的embedding], label

這部分數據將做爲LSTM部分的輸入.

第二部分: LSTM + FC

這一步,咱們將上一步獲得的embedding直接送入LSTM網絡訓練.

回想一下, 咱們平時用LSTM作, 是否是把句子過了embedding層以後再送入LSTM的? 這裏咱們直接跳過embedding層, 由於咱們的數據自己就是embedding

因爲分割後的embedding都不會太長, 咱們直接使用LSTM最後一個time step的輸出(固然這裏也有個嘗試點, 若是提取出LSTM每一個time step的輸出效果是否是會更好?)
LSTM以後會過一個激活函數, 接一個FC層, FC層和label用cross entropy loss進行優化.
因爲合併後的數據量比較小, 我跑了10個epoch, 每次都很快.
在這裏插入圖片描述


最終效果和一些小節

(左邊loss, 右邊accuracy)
在這裏插入圖片描述
最終效果竟然提升到了94%!! 說實話這個提高量遠高於論文. 可能和數據自己好也有關係. 可是咱們能夠認爲, 比起直接截取文本開頭一段, 採用HIERARCHICAL方式不只克服了BERT長度限制的缺點, 也極大提高了對於超長文本的分類效果.

下面是在中文數據集上模型的baseline效果和提高後的效果.
(待跑)

因此我認爲, 採用HIERARCHICAL方法, 提高/解決了BERT兩方面的缺點:
1.下降了BERT裏self-attention部分計算的時間複雜度. 就如開頭所說, 時間複雜度從O(n2)下降到O(nk). 這個狀況尤爲適用於長度在500之內長度的文本.
2.克服了BERT對於輸入文本長度有限的缺點. 對於tokenize以後長度超過510的文本, 也能夠用此方式對準確率進行再提高, 其實際效果優於直接截斷文本.

5. 進一步拓展: BERT + Transformer

原論文裏還提到了使用Transformer代替LSTM做爲預測部分. 這一節咱們用Transformer來試一下. 咱們先來分析一下使用Transformer結構後的時間複雜度. 顯然它的時間複雜度和LSTM不同(LSTM複雜度咱們能夠認爲是線性的, 即O(n/k)~O(n).) 首先在BERT部分, 時間複雜度不變, 依舊爲爲O(n/k * k2) = O(nk). 進入到Transformer後,每一個sequence長度爲n/k, 因此時間複雜度爲O(n/k * n/k)=O(n2/k2). 那麼整體時間複雜度爲 O(nk) + O(n2/k2) ~ O(n2/k2). 相比於LSTM的O(nk), 這個O(n2/k2)複雜度是有至關的上升的. 可是咱們考慮到 n/k << n, 即n/k的量級遠小於n, 因此仍是在可接受的範圍. (本小節未完…)

相關文章
相關標籤/搜索