點擊上方「AI公園」,關注公衆號,選擇加「星標「或「置頂」git
做者:Sayak Paulweb
編譯:ronghuaiyang
微信
從各個層次給你們講解模型的知識蒸餾的相關內容,並經過實際的代碼給你們進行演示。網絡
公衆號後臺回覆「模型蒸餾」,下載已打包好的代碼。app
本報告討論了很是厲害模型優化技術 —— 知識蒸餾,並給你們過了一遍相關的TensorFlow的代碼。
編輯器
「模型集成是一個至關有保證的方法,能夠得到2%的準確性。「 —— Andrej Karpathy分佈式
我絕對贊成!然而,部署重量級模型的集成在許多狀況下並不老是可行的。有時,你的單個模型可能太大(例如GPT-3),以致於一般不可能將其部署到資源受限的環境中。這就是爲何咱們一直在研究一些模型優化方法 ——量化和剪枝。在這個報告中,咱們將討論一個很是厲害的模型優化技術 —— 知識蒸餾。函數
Softmax告訴了咱們什麼?
當處理一個分類問題時,使用softmax做爲神經網絡的最後一個激活單元是很是典型的用法。這是爲何呢?由於softmax函數接受一組logit爲輸入並輸出離散類別上的機率分佈。好比,手寫數字識別中,神經網絡可能有較高的置信度認爲圖像爲1。不過,也有輕微的可能性認爲圖像爲7。若是咱們只處理像[1,0]這樣的獨熱編碼標籤(其中1和0分別是圖像爲1和7的機率),那麼這些信息就沒法得到。性能
人類已經很好地利用了這種相對關係。更多的例子包括,長得很像貓的狗,棕紅色的,貓同樣的老虎等等。正如Hinton等人所認爲的學習
一輛寶馬被誤認爲是一輛垃圾車的可能性很小,但被誤認爲是一個胡蘿蔔的可能性仍然要高不少倍。
這些知識能夠幫助咱們在各類狀況下進行極好的歸納。這個思考過程幫助咱們更深刻地瞭解咱們的模型對輸入數據的想法。它應該與咱們考慮輸入數據的方式一致。
因此,如今該作什麼?一個迫在眉睫的問題可能會忽然出如今咱們的腦海中 —— 咱們在神經網絡中使用這些知識的最佳方式是什麼?讓咱們在下一節中找出答案。
使用Softmax的信息來教學 —— 知識蒸餾
softmax信息比獨熱編碼標籤更有用。在這個階段,咱們能夠獲得:
-
訓練數據 -
訓練好的神經網絡在測試數據上表現良好
咱們如今感興趣的是使用咱們訓練過的網絡產生的輸出機率。
考慮教人去認識MNIST數據集的英文數字。你的學生可能會問 —— 那個看起來像7嗎?若是是這樣的話,這絕對是個好消息,由於你的學生,確定知道1和7是什麼樣子。做爲一名教師,你可以把你的數字知識傳授給你的學生。這種想法也有可能擴展到神經網絡。
知識蒸餾的高層機制
因此,這是一個高層次的方法:
-
訓練一個在數據集上表現良好神經網絡。這個網絡就是「教師」模型。 -
使用教師模型在相同的數據集上訓練一個學生模型。這裏的問題是,學生模型的大小應該比老師的小得多。
本工做流程簡要闡述了知識蒸餾的思想。
爲何要小?這不是咱們想要的嗎?將一個輕量級模型部署到生產環境中,從而達到足夠的性能。
用圖像分類的例子來學習
對於一個圖像分類的例子,咱們能夠擴展前面的高層思想:
-
訓練一個在圖像數據集上表現良好的教師模型。在這裏,交叉熵損失將根據數據集中的真實標籤計算。 -
在相同的數據集上訓練一個較小的學生模型,可是使用來自教師模型(softmax輸出)的預測做爲ground-truth標籤。這些softmax輸出稱爲軟標籤。稍後會有更詳細的介紹。
咱們爲何要用軟標籤來訓練學生模型?
請記住,在容量方面,咱們的學生模型比教師模型要小。所以,若是你的數據集足夠複雜,那麼較小的student模型可能不太適合捕捉訓練目標所需的隱藏表示。咱們在軟標籤上訓練學生模型來彌補這一點,它提供了比獨熱編碼標籤更有意義的信息。在某種意義上,咱們經過暴露一些訓練數據集來訓練學生模型來模仿教師模型的輸出。
但願這能讓大家對知識蒸餾有一個直觀的理解。在下一節中,咱們將更詳細地瞭解學生模型的訓練機制。
知識蒸餾中的損失函數
爲了訓練學生模型,咱們仍然可使用教師模型的軟標籤以及學生模型的預測來計算常規交叉熵損失。學生模型頗有可能對許多輸入數據點都有信心,而且它會預測出像下面這樣的機率分佈:
擴展Softmax
這些弱機率的問題是,它們沒有捕捉到學生模型有效學習所需的信息。例如,若是機率分佈像[0.99, 0.01]
,幾乎不可能傳遞圖像具備數字7的特徵的知識。
Hinton等人解決這個問題的方法是,在將原始logits傳遞給softmax以前,將教師模型的原始logits按必定的溫度進行縮放。這樣,就會在可用的類標籤中獲得更普遍的分佈。而後用一樣的溫度用於訓練學生模型。
咱們能夠把學生模型的修正損失函數寫成這個方程的形式:
其中,pi是教師模型獲得軟機率分佈,si的表達式爲:
def get_kd_loss(student_logits, teacher_logits,
true_labels, temperature,
alpha, beta):
teacher_probs = tf.nn.softmax(teacher_logits / temperature)
kd_loss = tf.keras.losses.categorical_crossentropy(
teacher_probs, student_logits / temperature,
from_logits=True)
return kd_loss
使用擴展Softmax來合併硬標籤
Hinton等人還探索了在真實標籤(一般是獨熱編碼)和學生模型的預測之間使用傳統交叉熵損失的想法。當訓練數據集很小,而且軟標籤沒有足夠的信號供學生模型採集時,這一點尤爲有用。
當它與擴展的softmax相結合時,這種方法的工做效果明顯更好,而總體損失函數成爲二者之間的加權平均。
def get_kd_loss(student_logits, teacher_logits,
true_labels, temperature,
alpha, beta):
teacher_probs = tf.nn.softmax(teacher_logits / temperature)
kd_loss = tf.keras.losses.categorical_crossentropy(
teacher_probs, student_logits / temperature,
from_logits=True)
ce_loss = tf.keras.losses.sparse_categorical_crossentropy(
true_labels, student_logits, from_logits=True)
total_loss = (alpha * kd_loss) + (beta * ce_loss)
return total_loss / (alpha + beta)
建議β的權重小於α。
在原始Logits上進行操做
Caruana等人操做原始logits,而不是softmax值。這個工做流程以下:
-
這部分保持相同 —— 訓練一個教師模型。這裏交叉熵損失將根據數據集中的真實標籤計算。 -
如今,爲了訓練學生模型,訓練目標變成分別最小化來自教師和學生模型的原始對數之間的平均平方偏差。
mse = tf.keras.losses.MeanSquaredError()
def mse_kd_loss(teacher_logits, student_logits):
return mse(teacher_logits, student_logits)
使用這個損失函數的一個潛在缺點是它是無界的。原始logits能夠捕獲噪聲,而一個小模型可能沒法很好的擬合。這就是爲何爲了使這個損失函數很好地適合蒸餾狀態,學生模型須要更大一點。
Tang等人探索了在兩個損失之間插值的想法:擴展softmax和MSE損失。數學上,它看起來是這樣的:
根據經驗,他們發現當α = 0時,(在NLP任務上)能夠得到最佳的性能。
若是你在這一點上感到有點不知怎麼辦,不要擔憂。但願經過代碼,事情會變得清楚。
一些訓練方法
在本節中,我將向你提供一些在使用知識蒸餾時能夠考慮的訓練方法。
使用數據加強
他們在NLP數據集上展現了這個想法,但這也適用於其餘領域。爲了更好地指導學生模型訓練,使用數據加強會有幫助,特別是當你處理的數據較少的時候。由於咱們一般保持學生模型比教師模型小得多,因此咱們但願學生模型可以得到更多不一樣的數據,從而更好地捕捉領域知識。
使用標記的和未標記的數據訓練學生模型
在像Noisy Student Training和SimCLRV2這樣的文章中,做者在訓練學生模型時使用了額外的未標記數據。所以,你將使用你的teacher模型來生成未標記數據集上的ground-truth分佈。這在很大程度上有助於提升模型的可泛化性。這種方法只有在你所處理的數據集中有未標記數據可用時纔可行。有時,狀況可能並不是如此(例如,醫療保健)。Xie等人探索了數據平衡和數據過濾等技術,以緩解在訓練學生模型時合併未標記數據可能出現的問題。
在訓練教師模型時不要使用標籤平滑
標籤平滑是一種技術,用來放鬆由模型產生的高可信度預測。它有助於減小過擬合,但不建議在訓練教師模型時使用標籤平滑,由於不管如何,它的logits是按必定的溫度縮放的。所以,通常不推薦在知識蒸餾的狀況下使用標籤平滑。
使用更高的溫度值
Hinton等人建議使用更高的溫度值來soften教師模型預測的分佈,這樣軟標籤能夠爲學生模型提供更多的信息。這在處理小型數據集時特別有用。對於更大的數據集,信息能夠經過訓練樣本的數量來得到。
實驗結果
讓咱們先回顧一下實驗設置。我在實驗中使用了Flowers數據集。除非另外指定,我使用如下配置:
-
我使用MobileNetV2做爲基本模型進行微調,學習速度設置爲 1e-5
,Adam做爲優化器。 -
咱們將τ設置爲5。 -
α = 0.9,β = 0.1。 -
對於學生模型,使用下面這個簡單的結構:
Layer (type) Output Shape Param #
=================================================================
conv2d (Conv2D) (None, 222, 222, 64) 1792
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 55, 55, 64) 0
_________________________________________________________________
conv2d_1 (Conv2D) (None, 53, 53, 128) 73856
_________________________________________________________________
global_average_pooling2d_3 ( (None, 128) 0
_________________________________________________________________
dense_3 (Dense) (None, 512) 66048
_________________________________________________________________
dense_4 (Dense) (None, 5) 2565
=================================================================
-
在訓練學生模型時,我使用Adam做爲優化器,學習速度爲 1e-2
。 -
在使用數據加強訓練student模型的過程當中,我使用了與上面提到的相同的默認超參數的加權平均損失。
學生模型基線
爲了使性能比較公平,咱們還從頭開始訓練淺的CNN並觀察它的性能。注意,在本例中,我使用Adam做爲優化器,學習速率爲1e-3
。
訓練循環
在看到結果以前,我想說明一下訓練循環,以及如何在經典的model.fit()
調用中包裝它。這就是訓練循環的樣子:
def train_step(self, data):
images, labels = data
teacher_logits = self.trained_teacher(images)
with tf.GradientTape() as tape:
student_logits = self.student(images)
loss = get_kd_loss(teacher_logits, student_logits)
gradients = tape.gradient(loss, self.student.trainable_variables)
self.optimizer.apply_gradients(zip(gradients, self.student.trainable_variables))
train_loss.update_state(loss)
train_acc.update_state(labels, tf.nn.softmax(student_logits))
t_loss, t_acc = train_loss.result(), train_acc.result()
train_loss.reset_states(), train_acc.reset_states()
return {"loss": t_loss, "accuracy": t_acc}
若是你已經熟悉瞭如何在TensorFlow 2中定製一個訓練循環,那麼train_step()函數應該是一個容易閱讀的函數。注意get_kd_loss()
函數。這能夠是咱們以前討論過的任何損失函數。咱們在這裏使用的是一個訓練過的教師模型,這個模型咱們在前面進行了微調。經過這個訓練循環,咱們能夠建立一個能夠經過.fit()
調用進行訓練完整模型。
首先,建立一個擴展tf.keras.Model
的類。
class Student(tf.keras.Model):
def __init__(self, trained_teacher, student):
super(Student, self).__init__()
self.trained_teacher = trained_teacher
self.student = student
當你擴展tf.keras.Model
類的時候,能夠將自定義的訓練邏輯放到train_step()
函數中(由類提供)。因此,從總體上看,Student類應該是這樣的:
class Student(tf.keras.Model):
def __init__(self, trained_teacher, student):
super(Student, self).__init__()
self.trained_teacher = trained_teacher
self.student = student
def train_step(self, data):
images, labels = data
teacher_logits = self.trained_teacher(images)
with tf.GradientTape() as tape:
student_logits = self.student(images)
loss = get_kd_loss(teacher_logits, student_logits)
gradients = tape.gradient(loss, self.student.trainable_variables)
self.optimizer.apply_gradients(zip(gradients, self.student.trainable_variables))
train_loss.update_state(loss)
train_acc.update_state(labels, tf.nn.softmax(student_logits))
t_loss, t_acc = train_loss.result(), train_acc.result()
train_loss.reset_states(), train_acc.reset_states()
return {"train_loss": t_loss, "train_accuracy": t_acc}
你甚至能夠編寫一個test_step
來自定義模型的評估行爲。咱們的模型如今能夠用如下方式訓練:
student = Student(teacher_model, get_student_model())
optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)
student.compile(optimizer)
student.fit(train_ds,
validation_data=validation_ds,
epochs=10)
這種方法的一個潛在優點是能夠很容易地合併其餘功能,好比分佈式訓練、自定義回調、混合精度等等。
使用訓練學生模型
用這個損失函數訓練咱們的淺層學生模型,咱們獲得~74%的驗證精度。咱們看到,在epochs 8以後,損失開始增長。這代表,增強正則化可能會有所幫助。另外,請注意,超參數調優過程在這裏有重大影響。在個人實驗中,我沒有作嚴格的超參數調優。爲了更快地進行實驗,我縮短了訓練時間。
使用訓練學生模型
如今讓咱們看看在蒸餾訓練目標中加入ground truth標籤是否有幫助。在β = 0.1和α = 0.1的狀況下,咱們獲得了大約71%的驗證準確性。再次代表,更強的正則化和更長的訓練時間會有所幫助。
使用訓練學生模型
使用了MSE的損失,咱們能夠看到驗證精度大幅降低到~56%。一樣的損失也出現了相似的狀況,這代表須要進行正則化。
請注意,這個損失函數是無界的,咱們的淺學生模型可能沒法處理隨之而來的噪音。讓咱們嘗試一個更深刻的學生模型。
在訓練學生模型的時候使用數據加強
如前所述,學生模式比教師模式的容量更小。在處理較少的數據時,數據加強能夠幫助訓練學生模型。咱們驗證一下。
數據增長的好處是很是明顯的:
-
咱們有一個更好的損失曲線。 -
驗證精度提升到84%。
溫度(τ)的影響
在這個實驗中,咱們研究溫度對學生模型的影響。在這個設置中,我使用了相同的淺層CNN。
從上面的結果能夠看出,當τ爲1時,訓練損失和訓練精度均優於其它方法。對於驗證損失,咱們能夠看到相似的行爲,可是在全部不一樣的溫度下,驗證的準確性彷佛幾乎是相同的。
最後,我想研究下微調基線模是否對學生模型有顯著影響。
基線模型調優的效果
在此次實驗中,我選擇了 EfficientNet B0做爲基礎模型。讓咱們先來看看我用它獲得的微調結果。注意,如前所述,全部其餘超參數都保持其默認值。
咱們在微調步驟中沒有看到任何顯著的改進。我想再次強調,我沒有進行嚴格的超參數調優實驗。基於我從EfficientNet B0獲得的邊際改進,我決定在之後的某個時間點進行進一步的實驗。
第一行對應的是用加權平均損失訓練的默認student model,其餘行分別對應EfficientNet B0和MobileNetV2。注意,我沒有包括在訓練student模型時經過使用數據加強而獲得的結果。
知識蒸餾的一個好處是,它與其餘模型優化技術(如量化和修剪)無縫集成。因此,做爲一個有趣的實驗,我鼓勵大家本身嘗試一下。
總結
知識蒸餾是一種很是有前途的技術,特別適合於用於部署的目的。它的一個優勢是,它能夠與量化和剪枝很是無縫地結合在一塊兒,從而在不影響精度的前提下進一步減少生產模型的尺寸。
英文原文:https://wandb.ai/authors/knowledge-distillation/reports/Distilling-Knowledge-in-Neural-Networks--VmlldzoyMjkxODk
請長按或掃描二維碼關注本公衆號
喜歡的話,請給我個好看吧!
本文分享自微信公衆號 - AI公園(AI_Paradise)。
若有侵權,請聯繫 support@oschina.cn 刪除。
本文參與「OSC源創計劃」,歡迎正在閱讀的你也加入,一塊兒分享。