論文筆記:蒸餾網絡(Distilling the Knowledge in Neural Network)

Distilling the Knowledge in Neural Network
Geoffrey Hinton, Oriol Vinyals, Jeff Dean
preprint arXiv:1503.02531, 2015
NIPS 2014 Deep Learning Workshop算法

簡單總結

主要工做(What)網絡

  1. 「蒸餾」(distillation):把大網絡的知識壓縮成小網絡的一種方法
  2. 「專用模型」(specialist models):對於一個大網絡,能夠訓練多個專用網絡來提高大網絡的模型表現

具體作法(How)機器學習

  1. 蒸餾:先訓練好一個大網絡,在最後的softmax層使用合適的溫度參數T,最後訓練獲得的機率稱爲「軟目標」。以這個軟目標和真實標籤做爲目標,去訓練一個比較小的網絡,訓練的時候也使用在大模型中肯定的溫度參數T
  2. 專用模型:對於一個已經訓練好的大網絡,能夠訓練一系列的專用模型,每一個專用模型只訓練一部分專用的類以及一個「不屬於這些專用類的其它類」,好比專用模型1訓練的類包括「顯示器」,「鼠標」,「鍵盤」,...,「其它」;專用模型2訓練的類包括「玻璃杯」,「保溫杯」,「塑料杯」,「其它「。最後以專用模型和大網絡的預測輸出做爲目標,訓練一個最終的網絡來擬合這個目標。

意義(Why)函數

  1. 蒸餾把大網絡壓成小網絡,這樣就能夠先在訓練階段花費大精力訓練一個大網絡,而後在部署階段以較小的計算代價來產生一個較小的網絡,同時保持必定的網絡預測表現。
  2. 對於一個已經訓練好的大網絡,若是要去作集成的話計算開銷是很大的,能夠在這個基礎上訓練一系列專用模型,由於這些模型一般比較小,因此訓練會快不少,並且有了這些專用模型的輸出能夠獲得一個軟目標,實驗證實使用軟目標訓練能夠減少過擬合。最後根據這個大網絡和一系列專用模型的輸出做爲目標,訓練一個最終的網絡,能夠獲得不錯的表現,並且不須要對大網絡作大量的集成計算

Abstract

提升機器學習算法表現的一個簡單方法就是,訓練不一樣模型而後對預測結果取平均。
可是要訓練多個模型會帶來太高的計算複雜度和部署難度。
能夠將集成的知識壓縮在單一的模型中。
論文使用這種方法在MNIST上作實驗,發現取得了不錯的效果。
論文還介紹了一種新型的集成,包括一個或多個完整模型和專用模型,可以學習區分完整模型容易混淆的細粒度的類別。學習

1 Introduction

昆蟲有幼蟲期和成蟲期,幼蟲期主要行爲是吸取營養,成蟲期主要行爲是生長繁殖。
相似地,大規模機器學習應用能夠分爲訓練階段和部署階段,訓練階段不要求實時操做,容許訓練一個複雜緩慢的模型,這個模型能夠是分別訓練多個模型的集成,也能夠是單獨的一個很大的帶有強正則好比dropout的模型。
一旦模型訓練好,能夠用不一樣的訓練,這裏稱爲「蒸餾」,去把知識轉移到更適合部署的小模型上。測試

複雜模型學習區分大量的類,一般的訓練目標是最大化正確答案的平均log機率,這麼作有一個反作用就是訓練模型同時也會給全部的錯誤答案分配機率,即便這個機率很小,而有一些機率會比其它的大不少。錯誤答案的相對機率體現了複雜模型的泛化能力。舉個例子,寶馬的圖像被錯認爲垃圾箱的機率很低,可是這被個錯認爲垃圾桶的機率相比於被錯認爲胡蘿蔔的機率來講,是很大的。(能夠認爲模型不止學到了訓練集中的寶馬圖像特徵,還學到了一些別的特徵,好比和垃圾桶共有的一些特徵,這樣就可能捕捉到在新的測試集上的寶馬出現這些的特徵,這就是泛化能力的體現)google

將複雜模型轉爲小模型須要保留模型的泛化能力,一個方法就是用複雜模型產生的分類機率做爲「軟目標」來訓練小模型。
當軟目標的熵值較高時,相對於硬目標,每一個訓練樣本提供更多的信息,訓練樣本之間會有更小的梯度方差。
因此小模型常常能夠被訓練在小數據集上,並且可使用更高的學習率。ci

像MNIST這種分類任務,複雜模型能夠產生很好的表現,大部分信息分佈在小几率的軟目標中。
爲了規避這個問題,Caruana和他的合做者們使用softmax輸出前的units值,而不是softmax後的機率,最小化複雜模型和簡單模型的units的平方偏差來訓練小模型。
而更通用的方法,蒸餾法,先提升softmax的溫度參數直到模型能產生合適的軟目標。而後在訓練小模型匹配軟目標的時候使用相同的溫度T。部署

被用於訓練小模型的轉移訓練集能夠包括未打標籤的數據(能夠沒有原始的實際標籤,由於能夠經過複雜模型獲取一個軟目標做爲標籤),或者使用原始的數據集,使用原始數據集能夠獲得更好的表現。get

2 Distillation

softmax公式: $ q_{i} = \frac{exp(z_{i}/T)}{\sum_{j}^{ }exp(z_{j}/T)} $
其中溫度參數T一般設置爲1,T越大能夠獲得更「軟」的機率分佈。
T越大,不一樣激活值的機率差別越小,全部激活值的機率趨於相同;T越小,不一樣激活值的機率差別越大
在蒸餾訓練的時候使用較大的T的緣由是,較小的T對於那些遠小於平均激活值的單元會給予更少的關注,而這些單元是有用的,使用較高的T可以捕捉這些信息

最簡單的蒸餾形式就是,訓練小模型的時候,以複雜模型獲得的「軟目標」爲目標,採用複雜模型中的較高的T,訓練完以後把T改成1。

當部分或所有轉移訓練集的正確標籤已知時,蒸餾獲得的模型會更優。一個方法就是使用正確標籤來修改軟目標。
可是咱們發現一個更好的方法,簡單對兩個不一樣的目標函數進行權重平均,第一個目標函數是和複雜模型的軟目標作一個交叉熵,使用的複雜模型的溫度T;第二個目標函數是和正確標籤的交叉熵,溫度設置爲1。咱們發現第二個目標函數被分配一個低權重時一般會取得最好的結果。

3 Preliminary experiments on MNIST

net layers units of each layer activation regularization test errors
single net1 2 1600 relu dropout 67
single net2 2 800 relu no 146

(防止表格黏在一塊兒)

net large net small net temperature test errors
distilled net single net1 single net2 20 74

第一個表格中是兩個單獨的網絡,一個大網絡和一個小網絡。
第二個表格是使用了蒸餾的方法,先訓練大網絡,而後根據大網絡的「軟目標」結果和溫度T來訓練小網絡。
能夠看到,經過蒸餾的方法將大網絡中的知識壓縮到小網絡中,取得了不錯的效果。

4 Experiments on speech recognition

system Test Frame Accuracy Word Error Rate on dev set
baseline 58.9% 10.9%
10XEnsemble 61.1% 10.7%
Distilled model 60.8% 10.7%

其中basline的配置爲

  • 8 層,每層2560個relu單元
  • softmax層的單元數爲14000
  • 訓練樣本大小約爲 700M,2000個小時的語音文本數據

10XEnsemble是對baseline訓練10次(隨機初始化爲不一樣參數)而後取平均

蒸餾模型的配置爲

  • 使用的候選溫度爲{1, 2, 5, 10}, 其中T爲2時表現最好
  • hard target 的目標函數給予0.5的相對權重

能夠看到,相對於10次集成後的模型表現提高,蒸餾保留了超過80%的效果提高

5 Training ensembles of specialists on very big datasets

訓練一個大的集成模型能夠利用並行計算來訓練,訓練完成後把大模型蒸餾成小模型,可是另外一個問題就是,訓練自己就要花費大量的時間,這一節介紹的是如何學習專用模型集合,集合中的每一個模型集中於不一樣的容易混淆的子類集合,這樣能夠減少計算需求。專用模型的主要問題是容易集中於區分細粒度特徵而致使過擬合,可使用軟目標來防止過擬合。

5.1 JFT數據集

JFT是一個谷歌的內部數據集,有1億的圖像,15000個標籤。google用一個深度卷積神經網絡,訓練了將近6個月。
咱們須要更快的方法來提高baseline模型。

5.2 專用模型

將一個複雜模型分爲兩部分,一部分是一個用於訓練全部數據的通用模型,另外一部分是不少個專用模型,每一個專用模型訓練的數據集是一個容易混淆的子類集合。這些專用模型的softmax結合全部不關心的類爲一類來使模型更小。

爲了減小過擬合,共享學習到的低水平特徵,每一個專用模型用通用模型的權重進行初始化。另外,專用模型的訓練樣本一半來自專用子類集合,另外一半從剩餘訓練集中隨機抽取。

5.3 將子類分配到專用模型

專用模型的子類分組集中於容易混淆的那些類別,雖然計算出了混淆矩陣來尋找聚類,可是可使用一種更簡單的辦法,不須要使用真實標籤來構建聚類。對通用模型的預測結果計算協方差,根據協方差把常常一塊兒預測的類做爲其中一個專用模型的要預測的類別。幾個簡單的例子以下。

JFT 1: Tea party; Easter; Bridal shower; Baby shower; Easter Bunny; ...
JFT 2: Bridge; Cable-stayed bridge; Suspension bridge; Viaduct; Chimney; ...
JFT 3: Toyota Corolla E100; Opel Signum; Opel Astra; Mazda Familia; ...

5.4 實驗表現

system Conditional Test Accuracy Test Accuracy
baseline 43.1% 25.0%
61 specialist models 45.9% 26.1%

6 Soft Targets as Regularizers

對於前面提到過的,對於大量數據訓練好的語音baseline模型,用更少的數據去擬合這個模型的時候,使用軟目標能夠達到更好的效果,減少過擬合。實驗結果以下。

system & training set Train Frame Accuracy Test Frame Accuracy
baseline(100% training set) 63.4% 58.9%
baseline(3% training set) 67.3% 44.5%
soft targets(3% training set) 65.4% 57.0%
相關文章
相關標籤/搜索