Distilling the Knowledge in Neural Network
Geoffrey Hinton, Oriol Vinyals, Jeff Dean
preprint arXiv:1503.02531, 2015
NIPS 2014 Deep Learning Workshop算法
主要工做(What)網絡
具體作法(How)機器學習
意義(Why)函數
提升機器學習算法表現的一個簡單方法就是,訓練不一樣模型而後對預測結果取平均。
可是要訓練多個模型會帶來太高的計算複雜度和部署難度。
能夠將集成的知識壓縮在單一的模型中。
論文使用這種方法在MNIST上作實驗,發現取得了不錯的效果。
論文還介紹了一種新型的集成,包括一個或多個完整模型和專用模型,可以學習區分完整模型容易混淆的細粒度的類別。學習
昆蟲有幼蟲期和成蟲期,幼蟲期主要行爲是吸取營養,成蟲期主要行爲是生長繁殖。
相似地,大規模機器學習應用能夠分爲訓練階段和部署階段,訓練階段不要求實時操做,容許訓練一個複雜緩慢的模型,這個模型能夠是分別訓練多個模型的集成,也能夠是單獨的一個很大的帶有強正則好比dropout的模型。
一旦模型訓練好,能夠用不一樣的訓練,這裏稱爲「蒸餾」,去把知識轉移到更適合部署的小模型上。測試
複雜模型學習區分大量的類,一般的訓練目標是最大化正確答案的平均log機率,這麼作有一個反作用就是訓練模型同時也會給全部的錯誤答案分配機率,即便這個機率很小,而有一些機率會比其它的大不少。錯誤答案的相對機率體現了複雜模型的泛化能力。舉個例子,寶馬的圖像被錯認爲垃圾箱的機率很低,可是這被個錯認爲垃圾桶的機率相比於被錯認爲胡蘿蔔的機率來講,是很大的。(能夠認爲模型不止學到了訓練集中的寶馬圖像特徵,還學到了一些別的特徵,好比和垃圾桶共有的一些特徵,這樣就可能捕捉到在新的測試集上的寶馬出現這些的特徵,這就是泛化能力的體現)google
將複雜模型轉爲小模型須要保留模型的泛化能力,一個方法就是用複雜模型產生的分類機率做爲「軟目標」來訓練小模型。
當軟目標的熵值較高時,相對於硬目標,每一個訓練樣本提供更多的信息,訓練樣本之間會有更小的梯度方差。
因此小模型常常能夠被訓練在小數據集上,並且可使用更高的學習率。ci
像MNIST這種分類任務,複雜模型能夠產生很好的表現,大部分信息分佈在小几率的軟目標中。
爲了規避這個問題,Caruana和他的合做者們使用softmax輸出前的units值,而不是softmax後的機率,最小化複雜模型和簡單模型的units的平方偏差來訓練小模型。
而更通用的方法,蒸餾法,先提升softmax的溫度參數直到模型能產生合適的軟目標。而後在訓練小模型匹配軟目標的時候使用相同的溫度T。部署
被用於訓練小模型的轉移訓練集能夠包括未打標籤的數據(能夠沒有原始的實際標籤,由於能夠經過複雜模型獲取一個軟目標做爲標籤),或者使用原始的數據集,使用原始數據集能夠獲得更好的表現。get
softmax公式: $ q_{i} = \frac{exp(z_{i}/T)}{\sum_{j}^{ }exp(z_{j}/T)} $
其中溫度參數T一般設置爲1,T越大能夠獲得更「軟」的機率分佈。
(T越大,不一樣激活值的機率差別越小,全部激活值的機率趨於相同;T越小,不一樣激活值的機率差別越大)
(在蒸餾訓練的時候使用較大的T的緣由是,較小的T對於那些遠小於平均激活值的單元會給予更少的關注,而這些單元是有用的,使用較高的T可以捕捉這些信息)
最簡單的蒸餾形式就是,訓練小模型的時候,以複雜模型獲得的「軟目標」爲目標,採用複雜模型中的較高的T,訓練完以後把T改成1。
當部分或所有轉移訓練集的正確標籤已知時,蒸餾獲得的模型會更優。一個方法就是使用正確標籤來修改軟目標。
可是咱們發現一個更好的方法,簡單對兩個不一樣的目標函數進行權重平均,第一個目標函數是和複雜模型的軟目標作一個交叉熵,使用的複雜模型的溫度T;第二個目標函數是和正確標籤的交叉熵,溫度設置爲1。咱們發現第二個目標函數被分配一個低權重時一般會取得最好的結果。
net | layers | units of each layer | activation | regularization | test errors |
---|---|---|---|---|---|
single net1 | 2 | 1600 | relu | dropout | 67 |
single net2 | 2 | 800 | relu | no | 146 |
(防止表格黏在一塊兒)
net | large net | small net | temperature | test errors |
---|---|---|---|---|
distilled net | single net1 | single net2 | 20 | 74 |
(第一個表格中是兩個單獨的網絡,一個大網絡和一個小網絡。)
(第二個表格是使用了蒸餾的方法,先訓練大網絡,而後根據大網絡的「軟目標」結果和溫度T來訓練小網絡。)
(能夠看到,經過蒸餾的方法將大網絡中的知識壓縮到小網絡中,取得了不錯的效果。)
system | Test Frame Accuracy | Word Error Rate on dev set |
---|---|---|
baseline | 58.9% | 10.9% |
10XEnsemble | 61.1% | 10.7% |
Distilled model | 60.8% | 10.7% |
其中basline的配置爲
10XEnsemble是對baseline訓練10次(隨機初始化爲不一樣參數)而後取平均
蒸餾模型的配置爲
能夠看到,相對於10次集成後的模型表現提高,蒸餾保留了超過80%的效果提高
訓練一個大的集成模型能夠利用並行計算來訓練,訓練完成後把大模型蒸餾成小模型,可是另外一個問題就是,訓練自己就要花費大量的時間,這一節介紹的是如何學習專用模型集合,集合中的每一個模型集中於不一樣的容易混淆的子類集合,這樣能夠減少計算需求。專用模型的主要問題是容易集中於區分細粒度特徵而致使過擬合,可使用軟目標來防止過擬合。
JFT是一個谷歌的內部數據集,有1億的圖像,15000個標籤。google用一個深度卷積神經網絡,訓練了將近6個月。
咱們須要更快的方法來提高baseline模型。
將一個複雜模型分爲兩部分,一部分是一個用於訓練全部數據的通用模型,另外一部分是不少個專用模型,每一個專用模型訓練的數據集是一個容易混淆的子類集合。這些專用模型的softmax結合全部不關心的類爲一類來使模型更小。
爲了減小過擬合,共享學習到的低水平特徵,每一個專用模型用通用模型的權重進行初始化。另外,專用模型的訓練樣本一半來自專用子類集合,另外一半從剩餘訓練集中隨機抽取。
專用模型的子類分組集中於容易混淆的那些類別,雖然計算出了混淆矩陣來尋找聚類,可是可使用一種更簡單的辦法,不須要使用真實標籤來構建聚類。對通用模型的預測結果計算協方差,根據協方差把常常一塊兒預測的類做爲其中一個專用模型的要預測的類別。幾個簡單的例子以下。
JFT 1: Tea party; Easter; Bridal shower; Baby shower; Easter Bunny; ...
JFT 2: Bridge; Cable-stayed bridge; Suspension bridge; Viaduct; Chimney; ...
JFT 3: Toyota Corolla E100; Opel Signum; Opel Astra; Mazda Familia; ...
system | Conditional Test Accuracy | Test Accuracy |
---|---|---|
baseline | 43.1% | 25.0% |
61 specialist models | 45.9% | 26.1% |
對於前面提到過的,對於大量數據訓練好的語音baseline模型,用更少的數據去擬合這個模型的時候,使用軟目標能夠達到更好的效果,減少過擬合。實驗結果以下。
system & training set | Train Frame Accuracy | Test Frame Accuracy |
---|---|---|
baseline(100% training set) | 63.4% | 58.9% |
baseline(3% training set) | 67.3% | 44.5% |
soft targets(3% training set) | 65.4% | 57.0% |