【DL】模型蒸餾Distillation

過去一直follow着transformer系列模型的進展,從BERT到GPT2再到XLNet。然而隨着模型體積增大,線上性能也愈來愈差,因此決定開一條新線,開始follow模型壓縮之模型蒸餾的故事線。函數

Hinton在NIPS2014提出了知識蒸餾(Knowledge Distillation)的概念,旨在把一個大模型或者多個模型ensemble學到的知識遷移到另外一個輕量級單模型上,方便部署。簡單的說就是用新的小模型去學習大模型的預測結果,改變一下目標函數。聽起來是不難,但在實踐中小模型真的能擬合那麼好嗎?因此仍是要多看看別人家的實驗,掌握一些trick。性能

0. 名詞解釋

  • teacher - 原始模型或模型ensemble
  • student - 新模型
  • transfer set - 用來遷移teacher知識、訓練student的數據集合
  • soft target - teacher輸出的預測結果(通常是softmax以後的機率)
  • hard target - 樣本本來的標籤
  • temperature - 蒸餾目標函數中的超參數

1. 基本思想

1.1 爲何蒸餾能夠work

好模型的目標不是擬合訓練數據,而是學習如何泛化到新的數據。因此蒸餾的目標是讓student學習到teacher的泛化能力,理論上獲得的結果會比單純擬合訓練數據的student要好。另外,對於分類任務,若是soft targets的熵比hard targets高,那顯然student會學習到更多的信息。學習

1.2 蒸餾時的softmax

[公式]

比以前的softmax多了一個參數T(temperature),T越大產生的機率分佈越平滑。orm

有兩種蒸餾的目標函數:cdn

  1. 只使用soft targets:在蒸餾時teacher使用新的softmax產生soft targets;student使用新的softmax在transfer set上學習,和teacher使用相同的T。
  2. 同時使用sotf和hard targets:student的目標函數是hard target和soft target目標函數的加權平均,使用hard target時T=1,soft target時T和teacher的同樣。Hinton的經驗是給hard target的權重小一點。另外要注意的是,由於在求梯度(導數)時新的目標函數會致使梯度是之前的 [公式] ,因此要再乘上 [公式] ,否則T變了的話hard target不減少(T=1),但soft target會變。

2. 蒸餾經驗

(我去旅遊了)部署

相關文章
相關標籤/搜索