TinyBert的原理講解

TinyBERT 是華爲不久前提出的一種蒸餾 BERT 的方法,模型大小不到 BERT 的 1/7,但速度能提升 9 倍。本文梳理了 TinyBERT 的模型結構,探索了其在不一樣業務上的表現,證實了 TinyBERT 對複雜的語義匹配任務來講是一種行之有效的壓縮手段。網絡

1、簡介

在 NLP 領域,BERT 的強大毫無疑問,但因爲模型過於龐大,單個樣本計算一次的開銷動輒上百毫秒,很難應用到實際生產中。TinyBERT 是華爲、華科聯合提出的一種爲基於 transformer 的模型專門設計的知識蒸餾方法,模型大小不到 BERT 的 1/7,但速度提升了 9 倍,並且性能沒有出現明顯降低。目前,該論文已經提交機器學習頂會 ICLR 2020。本文復現了 TinyBERT 的結果,證實了 Tiny BERT 在速度提升的同時,對複雜的語義匹配任務,性能沒有顯著降低。機器學習

目前主流的幾種蒸餾方法大概分紅利用 transformer 結構蒸餾、利用其它簡單的結構好比 BiLSTM 等蒸餾。因爲 BiLSTM 等結構簡單,且通常是用 BERT 最後一層的輸出結果進行蒸餾,不能學到 transformer 中間層的信息,對於複雜的語義匹配任務,效果有點不盡人意。函數

基於 transformer 結構的蒸餾方法目前比較出名的有微軟的 BERT-PKD (Patient Knowledge Distillation for BERT),huggingface 的 DistilBERT,以及本篇文章講的 TinyBERT。他們的基本思路都是減小 transformer encoding 的層數和 hidden size 大小,實現細節上各有不一樣,主要差別體如今 loss 的設計上。性能

2、模型實現細節

TinyBERT 的結構以下圖:學習

整個 TinyBERT 的 loss 設計分爲三部分:編碼

1. Embedding-layer Distillation

其中:spa

分別表明 student 網絡的 embedding 和 teacher 網絡的 embedding. 其中 l 表明 sequence length, d0 表明 student embedding 維度, d 表明 teacher embedding 維度。因爲 student 網絡的 embedding 層一般較 teacher 會變小以得到更小的模型和加速,因此 We 是一個 d 0×d 維的可訓練的線性變換矩陣,把 student 的 embedding 投影到 teacher embedding 所在的空間。最後再算 MSE,獲得 embedding loss.設計

2. Transformer-layer Distillation

TinyBERT 的 transformer 蒸餾採用隔 k 層蒸餾的方式。舉個例子,teacher BERT 一共有 12 層,如果設置 student BERT 爲 4 層,就是每隔 3 層計算一個 transformer loss. 映射函數爲 g(m) = 3 * m, m 爲 student encoder 層數。具體對應爲 student 第 1 層 transformer 對應 teacher 第 3 層,第 2 層對應第 6 層,第 3 層對應第 9 層,第 4 層對應第 12 層。每一層的 transformer loss 又分爲兩部分組成,attention based distillation 和 hidden states based distillation.3d

2.1 Attention based losscode

其中,

h 表明 attention 的頭數,l 表明輸入長度,

表明 student 網絡第 i 個 attention 頭的 attention score 矩陣,

表明 teacher 網絡第 i 個 attention 頭的 attention score 矩陣。這個 loss 是受到斯坦福和 Facebook 聯合發表的論文,What Does BERT Look At? An Analysis of BERT’s Attention 的啓發。這篇論文研究了 attention 權重到底學到了什麼,實驗發現與語義還有語法相關的詞好比第一個動詞賓語,第一個介詞賓語,以及[CLS], [SEP], 逗號等 token,有很高的注意力權重。爲了確保這部分信息能被 student 網絡學到,TinyBERT 在 loss 設計中加上了 student 和 teacher 的 attention matrix 的 MSE。這樣語言知識能夠很好的從 teacher BERT 轉移到 student BERT.

2.2 hidden states based distillation

其中,

分別是 student transformer 和 teacher transformer 的隱層輸出。和 embedding loss 同理,

投影到 Ht 所在的空間。

3. Prediction-Layer Distillation

其中 t 是 temperature value,暫時設爲 1.除了模仿中間層的行爲外,這一層用來模擬 teacher 網絡在 predict 層的表現。具體來講,這一層計算了 teacher 輸出的機率分佈和 student 輸出的機率分佈的 softmax 交叉熵。這一層的實現和具體任務相關,咱們的兩個實驗分別採起了 BERT 原生的 masked language model loss + next sentence loss 和單任務的 classification softmax cross-entropy.

另外,值得一提的是 prediction loss 有不少變化。在 TinyBERT 中,這個 loss 是 teacher BERT 預測的機率和 student BERT 預測機率的 softmax 交叉熵,在 BERT-PKD 模型中,這個 loss 是 teacher BERT 和 student BERT 的交叉熵和 student BERT 和 hard target( one-hot)的交叉熵的加權平均。咱們在業務中有試過直接用 hard target loss,效果比使用 teacher student softmax 交叉熵降低 5-6 個點。由於 softmax 比 one-hot 編碼了更多機率分佈的信息。而且實驗中,softmax cross-entropy loss 容易發生不收斂的狀況,把 softmax 交叉熵改爲 MSE, 收斂效果變好,但泛化效果變差。這是由於使用 softmax cross-entropy 須要學到整個機率分佈,更難收斂,由於擬合了 teacher BERT 的機率分佈,有更強的泛化性。MSE 對極值敏感,收斂的更快,但泛化效果不如前者。

因此總結一下,loss 的計算公式爲:

其中,

相關文章
相關標籤/搜索