Deformer: 雙塔模型與匹配模型的結合 - 知乎

做者:曹慶慶(Stony Brook University 在讀 PhD,關注Efficient NLP,QA方向,詳見awk.ai)

背景

BERT、XLNet、RoBERTa等基於Transformer[^transfomer]的預訓練模型推出後,天然語言理解任務都得到了大幅提高。問答任務(Question Answering,QA)[^qa-note]也一樣取得了很大的進步。git

用BERT類模型來作問答或閱讀理解任務,一般須要將問題和問題相關文檔拼接一塊兒做爲輸入文本,而後用自注意力機制對輸入文本進行多層交互編碼,以後用線性分類器判別文檔中可能的答案序列。以下圖:github

雖然這種片斷拼接的輸入方式可讓自注意力機制對所有的token進行交互,獲得的文檔表示是問題相關的(反之亦然),但相關文檔每每很長,token數量通常可達問題文本的10~20倍[^length],這樣就形成了大量的計算。架構

在實際場景下,考慮到設備的運算速度和內存大小,每每會對模型進行壓縮,好比經過蒸餾(distillation)小模型、剪枝(pruning)、量化(quantization)和低軼近似/權重共享等方法。性能

但模型壓縮仍是會帶來必定的精度損失。所以咱們思考,是否是能夠參考雙塔模型的結構,提早進行一些計算,從而提高模型的推理速度?測試

若是這種思路可行,會有幾個很大的優點:編碼

  1. 它不須要大幅修改原來的模型架構
  2. 也不須要從新預訓練,能夠繼續使用標準Transformer初始化+目標數據集fine-tune的精調方式
  3. 還能夠疊加模型壓縮技術

通過不斷地嘗試,咱們提出了《Deformer:Decomposing Pre-trained Transformers for Faster Question Answering》[1],在小幅修改模型架構且不更換預訓練模型的狀況下提高推理速度。下面將爲你們介紹咱們的思考歷程。url

論文連接:https://awk.ai/assets/deformer.pdf spa

代碼連接:https://github.com/StonyBrookNLP/deformer.net

模型結構

在開篇的介紹中,咱們指出了QA任務的計算瓶頸主要在於自注意力機制須要交互編碼的token太多了。所以咱們猜測,是否能讓文檔和問題在編碼階段儘量地獨立?設計

這樣的話,就能夠提早將最難計算的文檔編碼算好,只須要實時編碼較短的問題文本,從而加速整個QA過程。

部分研究代表,Transformer 的低層(lower layers)編碼主要關注一些局部的語言表層特徵(詞形、語法等等),到高層(upper layers)纔開始逐漸編碼與下游任務相關的全局語義信息。所以咱們猜測,至少在模型的某些部分,「文檔編碼可以不依賴於問題」的假設是成立的。 具體來講能夠在 Transformer 開始的低層分別對問題和文檔各自編碼,而後再在高層部分拼接問題和文檔的表徵進行交互編碼,如圖所示:

爲了驗證上述猜測,咱們設計了一個實驗,測量文檔在和不一樣問題交互時編碼的變化程度。下圖爲各層輸出的文檔向量和它們中心點cosine距離的方差:

能夠看到,對於BERT-Based的QA模型,若是編碼的文檔不變而問題變化,模型的低層表徵每每變化不大。這意味着並不是全部Transformer編碼層都須要對整個輸入文本的所有token序列進行自注意力交互。

所以,咱們提出Transformer模型的一種變形計算方式(稱做 DeFormer):在前k層對文檔編碼離線計算獲得第 k 層表徵,問題的第k層表徵經過實時計算,而後拼接問題和文檔的表徵輸入到後面k+1n層。下面這幅圖示意了DeFormer的計算過程:

值得一提的是,這種方式在有些QA任務(好比SQuAD)上有較大的精度損失,因此咱們添加了兩個蒸餾損失項,目的是最小化DeFormer的高層表徵和分類層logits與原始BERT模型的差別,這樣能控制精度損失在1個點左右。

實驗

這裏簡要描述下四組關鍵的實驗結果:

(1)在三個QA任務上,BERT和XLNet採用DeFormer分解後,取得了2.7-3.5倍的加速,節省內存65.8-72.0%,效果損失只有0.6-1.8%。BERT-base(n=12)在SQuAD上,設置k=9能加快推理3.2倍,節省內存70%。

(2)實測了原模型和DeFormer在三種不一樣硬件上的推理延遲。DeFormer均達到3倍以上的加速。

(3)消融實驗證實,添加的兩個蒸餾損失項能起到彌補精度損失的效果。

(4)測試DeFormer分解的層數(對應折線圖橫軸)對推理加速比和性能損失的影響。這個實驗在SQuAD上進行,且沒有使用蒸餾trick。

總結

這篇文章提主要提出了一種變形的計算方式DeFormer,使問題和文檔編碼在低層獨立編碼再在高層交互,從而使得能夠離線計算文檔編碼來加速QA推理和節省內存。

創新之處在於它對原始模型並無太大修改。部署簡單,且效果顯著。 實驗結果代表基於BERT和XLNet的DeFormer均能取得很好的表現。筆者推測對其餘的Transformer模型應該也一樣有效,而且其餘模型壓縮方法和技術應該也能夠疊加使用到DeFormer上來進一步加速模型推理。

[^qa-note]: 嚴格來講是機器閱讀理解,即給出問題從相關文章中提取答案,通常 QA 系統還包括檢索階段來找到問題相關的文檔 [^transfomer]: 論文方面能夠參考邱老師組的文獻綜述:Pre-trained Models for Natural Language Processing: A Survey,實例代碼能夠參見 huggingface 的 transformer 庫 [^length]: 好比 SQuAD 問題平均 10 個 token,但文檔平均有 116 個 token

參考資料

[1] Deformer:Decomposing Pre-trained Transformers for Faster Question Answering: https://awk.ai/assets/deformer.pdf

相關文章
相關標籤/搜索