微調預訓練模型的新姿式——自集成和自蒸餾 - 知乎

論文:Improving BERT Fine-Tuning via Self-Ensemble and Self-Distillation函數

連接:https://arxiv.org/abs/2002.10345性能

做者:Yige Xu, Xipeng Qiu, Ligao Zhou, Xuanjing Huang spa

本文提出了一種自集成和自蒸餾的fine-tuning方法,在不引入外部資源和不顯著增長訓練時間的前提下,能夠進一步加強fine-tuning的效果。blog

Fine-tune完就扔?No!

自蒸餾的前世此生——what、why、how?ip

一、什麼是自蒸餾?

知識蒸餾(Knowledge Distillation)指的是將預訓練好的教師模型(Teacher Model)的知識經過蒸餾的方式遷移到學生模型(Student Model)。自蒸餾(Self-Distillation)則指的是本身蒸餾到本身,Teacher Model就是Student Model的集成版本,稱爲自集成(Self-Ensemble)。集成模型在是刷榜利器,所以咱們但願在訓練過程當中不一樣time step的模型也能夠集成。爲了避免增長訓練開銷,咱們選擇一種參數平均的方式來進行自集成。資源

與同期的工做FastBERT的自蒸餾(高層蒸餾到底層)不一樣,本文的自蒸餾指的是過去time step蒸餾到當前time step。在Fine-tune過程當中,目標函數除了有來自標籤的監督信號之外,還有來自過去time step的監督信號。本文的自蒸餾是爲了進一步提升準確率而不是模型壓縮。get

二、爲何要自蒸餾?

I. 在通常的Fine-tune流程當中,咱們一般只關注某一個epoch結束以後的模型參數,而不關心在Fine-tune過程當中某個time step的參數。那麼Fine-tune的中間過程是否有什麼值得咱們挖掘的信息呢?io

II. 在通常的訓練過程中,咱們一般將數據集劃分紅一個個mini-batch,依次經過模型進行訓練。若是某一個mini-batch的數據質量不過關,可能會將模型參數帶歪,所以是否能夠尋找一種方式來減緩「帶歪」的趨勢呢?ast

III.好的teacher能夠教出更好的學生,而好的學生能夠進一步集成爲更好的教師,經過迭代能夠進行自我加強。class

三、如何進行自蒸餾?

在本文中,咱們提出了兩種自蒸餾的方式:Self-Distillation-Averaged(SDA)和Self-Distillation-Voted(SDV)。在SDA中,咱們首先計算出過去K個time step參數的平均值做爲Teacher Model。在SDV中,咱們將過去K個time step的參數視爲K個Teacher Model。

SDA的目標函數計算方式以下:

\mathcal{L}_{\theta}(x, y)=\mathrm{CE}\Big(\mathrm{BERT}(x,\theta),y \Big)\nonumber+\lambda\mathrm{MSE}\Big(\mathrm{BERT}(x,\theta),\mathrm{BERT}(x,\bar{\theta})\Big)

其中 \bar{\theta} = \frac{1}{K}\sum_{k=1}^{K} \theta_{t-k}

SDV的目標函數計算方式以下: \mathcal{L}_{\theta}(x, y)=\mathrm{CE}\Big(\mathrm{BERT}(x,\theta),y \Big)\nonumber+\lambda\mathrm{MSE}\Big(\mathrm{BERT}(x,\theta),\frac{1}{K}\sum_{k=1}^{K} \mathrm{BERT}(x,\theta_{t-k})\Big)

四、經過自蒸餾咱們能夠獲得什麼?

更穩定的訓練過程

咱們在SNLI數據集當中隨機抽取了1500條訓練數據組成一個迷你訓練集。不改變模型參數初始化,只改變數據訓練順序。經過在這個迷你訓練集上的實驗,咱們發現SDA和SDV加持下的訓練更爲穩定,準確率的均值更高、方差更低。

更高的準確率

在SDA和SDV加持下,能夠有效提高在下游任務Fine-tune BERT的性能。

相關文章
相關標籤/搜索