基於BERT的超長文本分類模型
0.Abstract
本文實現了一個基於BERT+LSTM超長文本分類的模型, 評估方法使用準確率和F1 Score.
項目代碼github地址: https://github.com/neesetifa/bert_classification
git
1.任務介紹
用BERT作文本分類是一個比較常見的項目.
可是衆所周知BERT對於文本輸入長度有限制. 對於超長文本的處理, 最簡單暴力無腦高效的辦法是直接截斷, 就取開頭這部分送入BERT. 可是也請別看不起這種作法, 每每最簡單,最Naive的方法效果反而比一頓操做猛如虎 複雜模型來得好.
github
這裏多提一句爲何. 一般長文本的文章結構都比較明確, 文章前面一兩段基本都是對於後面的概述. 因此等於做者已經幫你提取了文章大意, 因此直接取前面一部分理論上來講是有意義的.
固然也有最新研究代表取文章中間部分效果也很不錯. 在此不展開.
網絡
本文實現的是一種基於HIERARCHICAL(級聯)思想的作法, 把文本切成多片處理. 該方法來自於這篇論文 <Hierarchical Transformers for Long Document Classification>.
文中提到這麼作還能下降self-attention計算的時間複雜度.
假設原句子長爲n, 每一個分段的長度是k. 咱們知道最原始的BERT計算時間複雜度是O(n2), 做者認爲,這麼作能夠把時間複雜度下降到O(nk). 由於咱們把n分數據分割成k小份, 那麼咱們一共要作n/k次, 每次咱們的時間複雜度是k2, 即O(n/k * k2) = O(nk)
app
此次咱們測試該模型在兩種語言上的效果. 分別是中文數據集和英語數據集.
中文數據集依舊是咱們的老朋友ChineseNLPCorps提供的不一樣類別商品的評論.
中文數據集傳送門
英語數據集來源於Kaggle比賽, 用戶對於不一樣金融產品的評論.
英語數據集傳送門
因爲兩種數據集訓練預測上沒有什麼本質區別, 下文會用英語數據集來演示.
學習
本項目使用的評估方法是準確率和F1 Score. 很是常見的分類問題評價標準.優化
此項目中直接取了數據集裏一小部分做爲測試集.spa
2.數據初步處理
數據集裏有55W條數據,18個features.
咱們須要的部分是product(即商品類別)以及consumer complaint narrative.
觀察數據集,咱們發現用戶評論是有NaN值的. 並且本次實驗目的是作超長文本分類. 咱們選取非NaN值,而且是長度大於250的評論.
篩選完後咱們保留大約17k條左右數據
3.Baseline模型
咱們先來看一下什麼都不作, 直接用BERT進行finetune能達到什麼樣的效果. 咱們以此做爲實驗的baseline.
本次預訓練模型使用google官方的BERT-base-cased英語預訓練模型(固然用uncased應該也不要緊, 我沒有測試)
fine-tune部分很簡單, 直接提取[CLS] token後過線性層, 是比較常規的套路. 損失函數使用cross entropy loss.
文本送入的最大長度定爲250. 即前文裏提到的"直接截取文本前面部分". 這次實驗裏咱們嘗試比較HIERARCHICAL方法能比直接截取提升多少.
如圖, 準確率達到了88%. 訓練數據不過10k的數量級, 對於深度學習來講是很是少的. 這裏不得不感嘆下BERT做爲預訓練模型在小樣本數據上的實力很是強勁.
4. 數據進一步處理
接下來咱們進入提升部分. 首先對數據進一步處理.
HIERARCHICAL思想本質是對數據進行有重疊(overlap)的分割. 這樣分割後的每句句子之間仍然保留了必定的關聯信息.
衆所周知,BERT輸入的最大長度限制爲512, 其中還須要包括[CLS]和[SEP]. 那麼實際可用的長度僅爲510. 可是別忘了, 每一個單詞tokenizer以後也有可能被分紅好幾部分. 因此實際可輸入的句子長度遠不足510.
本次實驗裏咱們設置分割的長度爲200, overlap長度爲50. 若是實際上線生產確有大量超過500長度的文本, 只需將分割和overlap長度設置更長便可.
def get_split_text(text, split_len=250, overlap_len=50): split_text=[] for w in range(len(text)//split_len): if w == 0: #第一次,直接分割長度放進去 text_piece = text[:split_len] else: # 不然, 按照(分割長度-overlap)日後走 window = split_len - overlap_len text_piece = [w * window: w * window + split_len] split_text.append(text_piece) return split_text
分割完後長這樣
隨後咱們將這些分割的句子分離成單獨的一條數據. 併爲他們加上label.
對比原文本能夠發現, index 1~ index4來源於同一句句子. 它被分割成了4份而且每份都擁有原文本的label.
4.最終模型
最終模型由兩個部分構成, 第一部分是和baseline裏如出一轍的, fine-tune後的BERT. 第二部分是由LSTM+FC層組成的混合模型.
即實際上, BERT只是用來提取出句子的表示, 而真正在作分類的是LSTM + FC部分(更準確來講是FC部分, 由於LSTM模型部分仍然在作進一步的特徵提取工做)
這裏稍微提一句,這樣作法我我的認爲相似於廣告推薦系統裏GBDT+LR的組合. 採用一個稍微複雜的模型去作特徵提取, 而後用一個相對簡單的模型去預測.
首先,咱們把分割好後的文本送入BERT進行訓練. 這邊我跑了5個epoch, 顯卡仍然是Tesla K80, 每一個epoch大約須要23分鐘左右.
接着, 咱們提取出這些文本的句子表示.
方便起見, 咱們這裏仍然用[CLS] token做爲句子表示. 固然也能夠用sequence_output(在我上一個項目FAQ問答的最後結論中, 使用sequence_output的確能比pooled_output效果更好一點)
咱們得到的是這樣一組數據:
句子1_a的embedding, label
句子1_b的embedding, label
句子1_c的embedding, label
句子2_a的embedding, label
句子2_b的embedding, label
句子3_a的embedding, label
…
隨後咱們把這些embedding拼回起來, 變成了
[句子1_a的embedding,句子1_b的embedding, 句子1_c的embedding], label
[句子2_a的embedding, 句子2_b的embedding], label
[句子3_a的embedding, 句子3_b的embedding], label
這部分數據將做爲LSTM部分的輸入.
這一步,咱們將上一步獲得的embedding直接送入LSTM網絡訓練.
回想一下, 咱們平時用LSTM作, 是否是把句子過了embedding層以後再送入LSTM的? 這裏咱們直接跳過embedding層, 由於咱們的數據自己就是embedding
因爲分割後的embedding都不會太長, 咱們直接使用LSTM最後一個time step的輸出(固然這裏也有個嘗試點, 若是提取出LSTM每一個time step的輸出效果是否是會更好?)
LSTM以後會過一個激活函數, 接一個FC層, FC層和label用cross entropy loss進行優化.
因爲合併後的數據量比較小, 我跑了10個epoch, 每次都很快.
(左邊loss, 右邊accuracy)
最終效果竟然提升到了94%!! 說實話這個提高量遠高於論文. 可能和數據自己好也有關係. 可是咱們能夠認爲, 比起直接截取文本開頭一段, 採用HIERARCHICAL方式不只克服了BERT長度限制的缺點, 也極大提高了對於超長文本的分類效果.
下面是在中文數據集上模型的baseline效果和提高後的效果.
(待跑)
因此我認爲, 採用HIERARCHICAL方法, 提高/解決了BERT兩方面的缺點:
1.下降了BERT裏self-attention部分計算的時間複雜度. 就如開頭所說, 時間複雜度從O(n2)下降到O(nk). 這個狀況尤爲適用於長度在500之內長度的文本.
2.克服了BERT對於輸入文本長度有限的缺點. 對於tokenize以後長度超過510的文本, 也能夠用此方式對準確率進行再提高, 其實際效果優於直接截斷文本.
5. 進一步拓展: BERT + Transformer
原論文裏還提到了使用Transformer代替LSTM做爲預測部分. 這一節咱們用Transformer來試一下. 咱們先來分析一下使用Transformer結構後的時間複雜度. 顯然它的時間複雜度和LSTM不同(LSTM複雜度咱們能夠認爲是線性的, 即O(n/k)~O(n).) 首先在BERT部分, 時間複雜度不變, 依舊爲爲O(n/k * k2) = O(nk). 進入到Transformer後,每一個sequence長度爲n/k, 因此時間複雜度爲O(n/k * n/k)=O(n2/k2). 那麼整體時間複雜度爲 O(nk) + O(n2/k2) ~ O(n2/k2). 相比於LSTM的O(nk), 這個O(n2/k2)複雜度是有至關的上升的. 可是咱們考慮到 n/k << n, 即n/k的量級遠小於n, 因此仍是在可接受的範圍. (本小節未完…)