虛擬對抗訓練(VAT):一種用於監督學習和半監督學習的正則化方法

正則化

  虛擬對抗訓練是一種正則化方法,正則化在深度學習中是防止過擬合的一種方法。一般訓練樣本是有限的,而對於深度學習來講,搭設的深度網絡是能夠最大限度地擬合訓練樣本的分佈的,從而致使模型與訓練樣本分佈過度接近,還把訓練樣本中的一些噪聲也擬合進去了,甚至於最極端的,訓練出來的模型只能判斷訓練樣本,而測試樣本變成了隨機判斷。因此爲了讓模型泛化地更好,正則化是頗有必要的。html

  最多見的正則化是直接對模型的參數的大小進行限制,好比將參數(整合爲向量$\theta$)的$L_2$範數:算法

$\displaystyle J(\theta)=\frac{1}{n}\sum\limits_i^n\theta_i^2$網絡

  做爲正則項加入損失函數中,獲得總的損失函數:app

$\displaystyle L(\theta)=\frac{1}{N}\sum\limits_{i=1}^NL(y_i, x_i,\theta)  + \lambda J(\theta)$函數

  從而約束參數不會很大而過於複雜,使模型符合奧卡姆剃刀原理:全部合適的模型中應該選擇最簡單的那個。學習

  然而,這種正則化僅僅符合了奧卡姆剃刀而已,並且它的定義是很模糊的。由於你不知道什麼模型纔是「簡單」的,並且僅僅用範數來限制也不必定就會產生「簡單」的模型,甚至於,「簡單」的模型也未必就是泛化能力強的模型。 測試

對抗訓練

  相較於範數類型的正則項,論文中引用了另外一篇論文,這篇論文從另外一個角度來看待正則化,基於這樣一個假設$A$:對於輸入樣本的微小變更,模型對它的預測輸出也應該不會有很大的改變。這個對於連續函數來講是理所固然的(排除一些梯度特別大的連續函數),可是對於一些神經網絡模型來講,它們內部層與層之間的交互是有閾值的,超過這個閾值才能把信息傳到下一層,致使函數不連續,從而輸入的微小改變就會對後面的輸出產生巨大的影響(論文中指出,僅僅使用$L_p$範數作正則項就容易產生這樣的問題)。它的正則項定義以下:優化

 $\displaystyle J(\theta) = \frac{1}{N}\sum\limits_{i=1}^NL_{adv}(x_i,\theta)$編碼

${\rm where}\,L_{adv}(x_i,\theta) = D[q(y|x_i),p(y|x_i+r_{adv_i},\theta)]$url

${\rm where}\,r_{adv_i}= \mathop{\arg\max}\limits_{r;||r||_2\leq\epsilon}  D[q(y|x_i),p(y|x_i+r,\theta)]$

  這個公式假設模型是生成模型,由於判別模型能夠轉化爲生成模型,因此不另外添加公式了。其中,$D[q,p]$表示分佈$q$和$p$的差別,用交叉熵、相對熵(KL散度)等表達;$q(y|x_i)$表示訓練樣本$x_i$的標籤真實分佈;$p(y|x_i,\theta)$表示模型參數爲$\theta$時對$x_i$的標籤預測分佈;$r_{adv_i}$表示能使$x_i$預測誤差最大化的擾動向量(範數很小)。

  所以,這個正則項的定義就是:在每個訓練樣本點的周圍(固定範圍$\epsilon$),找一個預測分佈和這個樣本點標籤的真實分佈相差最大的樣本點($x_i+r_{adv_i}$),而後優化模型參數$\theta$來減少這個誤差。在每一次迭代優化$\theta$減少損失函數$L(\theta)$以前,都要先計算一次$r_{adv_i}$,即獲取當前$\theta$下使每一個$x_i$誤差最大的擾動向量,進而獲取當前擾動的最大誤差做爲正則項。如此看來好像是在對抗損失函數的減少,所以叫對抗訓練,而 $r_{adv_i}$則叫對抗方向

  由於實際上樣本點的真實連續分佈並不能得到,因此使用離散的機率來做爲分佈,論文中使用one hot vector $h(y=y_{real})$來表達。這個向量是一串0-1編碼,真實標籤對應的向量元素爲1,其它向量元素都爲0,好比標籤有:貓、狗、汽車,則$h(y = 狗)=[0,1,0]$,使用one hot vector的好處之一就是切斷了不一樣標籤之間在連續數值上的聯繫。

  因而咱們很容易能想到,對抗方向應該在$L_{adv}(x_i,\theta) $對$x_i$求梯度時能取到近似(由於在梯度方向函數變化率最大),即:

$\displaystyle r_{adv_i}\approx\epsilon\frac{g_i}{||g_i||_2},\,{\rm where}\,g_i=\nabla_{x}D[h(y=y_{x_i}),p(y|x,\theta)]|_{x=x_i}$ 

  由於須要訓練樣本的真實標籤分佈,因此對抗訓練只適用於監督學習

  論文指出,使用對抗方向來進行擾動的表現是比隨機擾動要好的。隨機擾動就是在$x_i$周圍$\epsilon$內隨機找一個較小的擾動$r_{rad_i}$代替$r_{adv_i}$。儘管隨機擾動的目標也是假設A,可是最終的訓練結果是比對抗擾動差不少的。

虛擬對抗訓練

  虛擬對抗訓練(VAT Visual adversarial training)是基於對抗訓練改進的正則化算法。它主要對對抗訓練進行了兩個地方的改進:

局部平滑度

  在$L_{adv}(x_i,\theta)$定義中的標籤真實分佈$q(y|x_i)$被換成了當前迭代下的標籤預測分佈$p(y|x_i,\hat{\theta})$($\hat{\theta}$表示當前梯度降低下的$\theta$的具體值,而$\theta$則是在損失函數中用來求梯度進行梯度降低的自變量)。另外還給$L_{adv}(x_i,\theta)$換了個名字——LDS(Local distributional smoothness 局部分佈平滑度),定義以下:

${\rm LDS}(x_i,\theta) = D[p(y|x_i,\hat{\theta}),p(y|x_i+r_{vadv_i},\theta)]$

$\,{\rm where}\,r_{vadv_i}=\mathop{\arg\max}\limits_{r;||r||_2\leq\epsilon}  D[p(y|x_i,\hat{\theta}),p(y|x_i+r,\hat{\theta})]$

  咱們可能會疑惑,爲何計算$r_{vadv}$用$\hat{\theta}$,而不用$\theta$,明顯用$\theta$更精確。論文中也沒有給出明確的說明,可能它忘了說明這一點。不過這個細節也的確不容易察覺,在後面我會說一下個人理解。

  能夠發現,${\rm LDS}(x_i,\theta)$並不須要$x_i$的標籤真實分佈,因此即便$x_i$是沒有真實標記的樣本點,一樣能夠加入訓練,所以VAT不但適用於監督學習,還適用於半監督學習。如下是使用VAT的簡化的損失函數($\mathcal{D_l,D_{ul}}$分別爲有標記樣本和無標記樣本集):

$\displaystyle L(\theta)=\sum\limits_{(x,y)\in\mathcal{D_l}}L(y, x,\theta) +\lambda \frac{1}{N_l+N_{ul}}\sum\limits_{x\in\mathcal{D_l,D_{ul}}}{\rm LDS}(x,\theta)$

快速計算rvadv

  對於計算$r_{vadv}$,論文並不直接使用關於$x_i$的梯度。由於顯然$D[p(y|x_i,\hat{\theta}),p(y|x_i+r,\hat{\theta})]$在$r=0$時,兩個分佈徹底相同,熵爲0,若是可導,那麼$x_i$就在極小值點上,從而梯度爲0。因而論文換了一個思考角度,要求$D(r,x_i,\hat{\theta})$(簡化寫法)最大化,不必定只能從梯度的角度考慮。將它關於$r$在0處進行泰勒展開後,由於一階導數(梯度)爲0,發現有以下近似:

$\displaystyle D(r,x_i,\hat{\theta})\approx\frac{1}{2}r^THr+O(r^2)$

  其中$O(r^2)$是$r^2$的高階無窮小,$H=\nabla\nabla_rD(r,x_i,\hat{\theta})|_{r=0}$是Hessian矩陣。由Hessian矩陣的定義可知,該矩陣是實對稱矩陣,必定有對應維數個相互線性無關的特徵向量。由特徵值和特徵向量的定義得,對於範數大小固定的$r$,當$r$是最大特徵值對應的特徵向量時,能取得$r^THr$最大,又由於$r$的範數很小,後面的高階無窮小能夠忽略不計,相應地,$D(r,x_i,\hat{\theta})$也取得最大。因此:

$r_{vadv}\approx\mathop{\arg\max}\limits_{r;||r||_2\leq\epsilon}r^THr=\epsilon\overline{u}$

  其中$\overline{u}$表示$H$的最大特徵值對應的單位特徵向量。可是,計算高維的Hessian矩陣是很困難的,更不用說再計算它的特徵值和特徵向量了。因此,論文使用冪法(冪迭代法,具體算法看此連接)來計算矩陣最大特徵值對應的特徵向量。即隨機取一個同維度的向量$d$(假設用特徵向量表達$d$時,$u$的係數不爲0),進行如下迭代:

$d=\overline{Hd}$

  迭代到後期,$d$會無限接近於$\overline{u}$。而後,論文又用所謂的有限差分法,來避免計算 Hessian矩陣。有限差分法就是用所謂的差商代替微商來近似計算導數,差商就是用比較小的因變量除以對應的自變量,微商就是用因變量的極限(無限小)除以對應自變量的極限。因而,0處的「二階導數」$H$乘上一個較小的自變量$\xi d$,就能夠近似0到$\xi d$處的一階導數(梯度)的變化量:

$\xi Hd\approx\nabla_rD(r,x_i,\hat{\theta})|_{r=\xi d}-\nabla_rD(r,x_i,\hat{\theta})|_{r=0}$

  因爲$r=0$處的梯度爲0:

$\displaystyle Hd\approx\frac{\nabla_rD(r,x_i,\hat{\theta})|_{r=\xi d}}{\xi}$

   因此迭代式變爲:

$d=\overline{\nabla_rD(r,x_i,\hat{\theta})|_{r=\xi d}}$

  論文中實驗,迭代一次就能獲取很好的近似$u$的效果。即:

$\displaystyle r_{vadv}\approx\epsilon\frac{g}{||g||_2}$

${\rm where}\,g=\nabla_rD[p(y|x_i,\hat{\theta}),p(y|x_i+r,\hat{\theta})]|_{r=\xi d}$

  我以爲迭代一次的緣由應該是:相較迭代獲取精度更高的虛擬對抗方向,計算力省下來用於梯度降低,更快地收斂整個模型更好。或者梯度降低前期迭代近似$r_{vadv}$次數少一些,後期再逐漸增長迭代次數增長收尾時的精度。

  說一下我對爲何要用$\hat{\theta}$,而不用$\theta$的理解。由於須要計算$r=\xi d$處的梯度並進行迭代,若是使用不能當具體數值參與計算的參數$\theta$,就只能把整個迭代寫成一次性計算的算式形式了,並且不能動態改變迭代的次數。而且隨着迭代次數增多,參數$\theta$的數量會指數式上升。固然,若是和上面同樣只迭代一次,我以爲是可使用$\theta$的。不過論文第6頁左上角好像說明了這點,當時沒看懂,說的應該就是這個意思:

額外正則項

  另外,在實驗中,論文除了LDS正則項外,還添加了條件熵做爲額外的正則項。定義以下:

$\displaystyle\mathcal{H}(Y|X)=-\frac{1}{N_l+N_{ul}}\sum\limits_{x\in \mathcal{D_l,D_{ul}}}\sum\limits_{y}p(y|x,\theta)\log p(y|x,\theta)$

  表示除了類似輸入應該有類似輸出外(減少LDS),輸出標籤的機率分佈還應該越集中越好(減少$\mathcal{H}(Y|X)$)。由於在$X$條件下$Y$的混亂度(熵)表明了輸出機率分佈的不集中度的平均值,因此優化條件熵越小,輸出機率分佈越集中、越肯定。而預測地越明確越好天然是咱們想要的。

VAT效果

  下圖展現了使用VAT進行半監督訓練的過程:

  圖中方形圖標是有標籤訓練樣本,圓形圖標是無標籤訓練樣本。分紅上下兩部分,分別展現了在訓練以前、訓練更新(梯度降低)10次、100次、1000次時,模型對無標籤訓練樣本的預測狀況$({\rm I})$,和無標籤訓練樣本的LDS$({\rm II})$。樣本的輸入爲二維,分別用橫縱座標表示。模型預測輸出爲一維,從綠到灰,再到紫,用連續的顏色過渡來表示預測標籤爲某個類別的機率(紫色機率爲1,綠色機率爲0,灰色爲0.5),如$({\rm I})$所示。$({\rm II})$用灰色到紫色表示無標籤樣本的LDS大小,越紫說明該樣本點在當前模型下的LDS越大,說明對這個樣本點進行小擾動會使當前模型的預測出現大誤差。

  $({\rm I})$能夠看出,隨着不斷的更新,無標籤樣本的預測從有標籤樣本「傳染」出去(由於遵循相近的樣本預測相同的理念),直到停在無標籤樣本稀疏的地方(由於沒有樣本再進行減少LDS的「傳染」,而稀疏的地方也正好就是兩個類別的分界線),最終造成了兩個鑲嵌着的半圓環。這個「傳染」的效果是我以前沒想到的,我覺得減少LDS的效果僅僅侷限在有標籤樣本的周圍。可是加了大量的無標籤樣本後,這些樣本對模型進行了整體的「把控」,而少許的有標籤樣本則對這個整體進行了「固定」,兩者聯動,使得VAT半監督學習的學習效果很好。

  $({\rm II})$顯示LDS隨着模型的更新,愈來愈小,最後LDS較大大的樣本點都分佈在兩個標籤的分界線處。

論文信息

  Virtual Adversarial Training: A Regularization Method for Supervised and Semi-Supervised Learning

相關文章
相關標籤/搜索