衆所周知,GANs 的訓練尤爲困難,筆者自從跳入了 GANs 這個領域(坑),就一直在跟如何訓練 GANs 作「對抗訓練」,受啓發於 ganhacks,並結合本身的經驗記錄總結了一些經常使用的訓練 GANs 的方法,以備後用。安全
什麼是 GANs?網絡
GANs(Generative Adversarial Networks)能夠說是一種強大的「萬能」數據分佈擬合器,主要由一個生成器(generator)和判別器(discriminator)組成。生成器主要從一個低維度的數據分佈中不斷擬合真實的高維數據分佈,而判別器主要是爲了區分數據是來源於真實數據仍是生成器生成的數據,他們之間相互對抗,不斷學習,最終達到Nash均衡,即任何一方的改進都不會致使整體的收益增長,這個時候判別器再也沒法區分是生成器生成的數據仍是真實數據。機器學習
GANs 最初由 Ian Goodfellow [1] 於 2014 年提出,目前已經在圖像、語音、文字等方面獲得普遍研究和應用,特別是在圖像生成方面,可謂是遍地開花,例如圖像風格遷移(style transfer)、圖像修復(image inpainting)、超分辨率(super resolution)等。函數
GANs 出了什麼問題?性能
GANs 一般被定義爲一個 minimax 的過程:學習
其中 P_r 是真實數據分佈,P_z 是隨機噪聲分佈。乍一看這個目標函數,感受有點相互矛盾,其實這就是 GANs 的精髓所在—— 對抗訓練。優化
在原始的 GANs 中,判別器要不斷的提升判別是非的能力,即儘量的將真實樣本分類爲正例,將生成樣本分類爲負例,因此判別器須要優化以下損失函數:設計
做爲對抗訓練,生成器須要不斷將生成數據分佈拉到真實數據分佈,Ian Goodfellow 首先提出了以下式的生成器損失函數:3d
因爲在訓練初期階段,生成器的能力比較弱,判別器這時候也比較弱,但仍然能夠足夠精準的區分生成樣本和真實樣本,這樣 D(x) 就很是接近1,致使 log(1-D(x)) 達到飽和,後續網絡就很難再調整過來。爲了解決訓練初期階段飽和問題,做者提出了另一個損失函數,即:code
以上面這個兩個生成器目標函數爲例,簡單地分析一下GAN模型存在的幾個問題:
Ian Goodfellow 論文裏面已經給出,固定 G 的參數,咱們獲得最優的 D^*:
也就是說,只有當 P_r=P_g 時候,不論是真實樣本和生成樣本,判別器給出的機率都是 0.5,這個時候就沒法區分樣本究竟是來自於真實樣本仍是來自於生成樣本,這是最理想的狀況。
1. 對於第一種目標函數
在最優判別器下 D^* 下,咱們給損失函數加上一個與 G 無關的項,(3) 式變成:
注意,該式子其實就是判別器的損失函數的相反數。
把最優判別器 D^* 帶入,能夠獲得:
到這裏,咱們就能夠看清楚咱們到底在優化什麼東西了,在最優判別器的狀況下,其實咱們在優化兩個分佈的 JS 散度。固然在訓練過程當中,判別器一開始不是最優的,可是隨着訓練的進行,咱們優化的目標也逐漸接近JS散度,而問題偏偏就出如今這個 JS 散度上面。一個直觀的解釋就是隻要兩個分佈之間的沒有重疊或者重疊部分能夠忽略不計,那麼大機率上咱們優化的目標就變成了一個常數 -2log2,這種狀況經過判別器傳遞給生成器的梯度就是零,也就是說,生成器不可能從判別器那裏學到任何有用的東西,這也就致使了沒法繼續學習。
Arjovsky [2] 以其精湛的數學技巧提供一個更嚴謹的一個數學推導(手動截圖原論文了)。
在 Theorm2.4 成立的狀況下:
拋開上面這些文縐縐的數學表述,其實上面講的核心內容就是當兩個分佈的支撐集是沒有交集的或者說是支撐集是低維的流形空間,隨着訓練的進行,判別器不斷接近最優判別器,會致使生成器的梯度到處都是爲0。
2. 對於第二種目標函數
一樣在最優判別器下,優化 (4) 式等價優化以下
仔細盯着上面式子幾秒鐘,不難發現咱們優化的目標是相互悖論的,由於 KL 散度和 JS 散度的符號相反,優化 KL 是把兩個分佈拉近,可是優化 -JS 是把兩個分佈推遠,這「一推一拉」就會致使梯度更新很是不穩定。此外,咱們知道 KL 不是對稱的,對於生成器沒法生成真實樣本的狀況,KL 對 loss 的貢獻很是大,而對於生成器生成的樣本多樣性不足的時候,KL 對 loss 的貢獻很是小。
而 JS 是對稱的,不會改變 KL 的這種不公平的行爲。這就解釋了咱們常常在訓練階段常常看見兩種狀況,一個是訓練 loss 抖動很是大,訓練不穩定;另一個是即便達到了穩定訓練,生成器也大機率上只生成一些安全保險的樣本,這樣就會致使模型缺少多樣性。
此外,在有監督的機器學習裏面,常常會出現一些過擬合的狀況,然而 GANs 也不例外。當生成器訓練得愈來愈好時候,生成的數據越接近於有限樣本集合裏面的數據。特別是當訓練集裏面包含有錯誤數據時候,判別器會過擬合到這些錯誤的數據,對於那些未見的數據,判別器就不能很好的指導生成器去生成可信的數據。這樣就會致使 GANs 的泛化能力比較差。
綜上所述,原始的 GANs 在訓練穩定性、模式多樣性以及模型泛化性能方面存在着或多或少的問題,後續學術上的工做大多也是基於此進行改進(填坑)。
訓練 GAN 的經常使用策略
上一節都是基於一些簡單的數學或者經驗的分析,可是根本緣由目前沒有一個很好的理論來解釋;儘管理論上的缺陷,咱們仍然能夠從一些經驗中發現一些實用的 tricks,讓你的 GANs 再也不難訓。這裏列舉的一些 tricks 可能跟 ganhacks 裏面的有些重複,更多的是補充,可是爲了完整起見,部分也添加在這裏。
1. model choice
若是你不知道選擇什麼樣的模型,那就選擇 DCGAN[3] 或者 ResNet[4] 做爲 base model。
2. input layer
假如你的輸入是一張圖片,將圖片數值歸一化到 [-1, 1];假如你的輸入是一個隨機噪聲的向量,最好是從 N(0, 1) 的正態分佈裏面採樣,不要從 U(0,1) 的均勻分佈裏採樣。
3. output layer
使用輸出通道爲 3 的卷積做爲最後一層,能夠採用 1x1 或者 3x3 的 filters,有的論文也使用 9x9 的 filters。(注:ganhacks 推薦使用 tanh)
4. transposed convolution layer
在作 decode 的時候,儘可能使用 upsample+conv2d 組合代替 transposed_conv2d,能夠減小 checkerboard 的產生 [5];
在作超分辨率等任務上,能夠採用 pixelshuffle [6]。在 tensorflow 裏,能夠用 tf.depth_to_sapce 來實現 pixelshuffle 操做。
5. convolution layer
因爲筆者常常作圖像修復方向相關的工做,推薦使用 gated-conv2d [7]。
6. normalization
雖然在 resnet 裏的標配是 BN,在分類任務上表現很好,可是圖像生成方面,推薦使用其餘 normlization 方法,例如 parameterized 方法有 instance normalization [8]、layer normalization [9] 等,non-parameterized 方法推薦使用 pixel normalization [10]。假如你有選擇困難症,那就選擇大雜燴的 normalization 方法——switchable normalization [11]。
7. discriminator
想要生成更高清的圖像,推薦 multi-stage discriminator [10]。簡單的作法就是對於輸入圖片,把它下采樣(maxpooling)到不一樣 scale 的大小,輸入三個不一樣參數但結構相同的 discriminator。
8. minibatch discriminator
因爲判別器是單獨處理每張圖片,沒有一個機制能告訴 discriminator 每張圖片之間要儘量的不類似,這樣就會致使判別器會將全部圖片都 push 到一個看起來真實的點,缺少多樣性。minibatch discriminator [22] 就是這樣這個機制,顯式地告訴 discriminator 每張圖片應該要不類似。在 tensorflow 中,一種實現 minibatch discriminator 方式以下:
上面是經過一個可學習的網絡來顯示度量每一個樣本之間的類似度,PGGAN 裏提出了一個更廉價的不須要學習的版本,即經過統計每一個樣本特徵每一個像素點的標準差,而後取他們的平均,把這個平均值複製到與當前 feature map 同樣空間大小單通道,做爲一個額外的 feature maps 拼接到原來的 feature maps 裏,一個簡單的 tensorflow 實現以下:
9. GAN loss
除了第二節提到的原始 GANs 中提出的兩種 loss,還能夠選擇 wgan loss [12]、hinge loss、lsgan loss [13]等。wgan loss 使用 Wasserstein 距離(推土機距離)來度量兩個分佈之間的差別,lsgan 採用相似最小二乘法的思路設計損失函數,最後演變成用皮爾森卡方散度代替了原始 GAN 中的 JS 散度,hinge loss 是遷移了 SVM 裏面的思想,在 SAGAN [14] 和 BigGAN [15] 等都是採用該損失函數。
ps: 我本身常用沒有 relu 的 hinge loss 版本。
10. other loss
一般狀況下,GAN loss 配合上面幾種 loss,效果會更好。
11. gradient penalty
Gradient penalty 首次在 wgan-gp 裏面提出來的,記爲 1-gp,目的是爲了讓 discriminator 知足 1-lipchitchz 連續,後續 Mescheder, Lars M. et al [19] 又提出了只針對正樣本或者負樣本進行梯度懲罰,記爲 0-gp-sample。Thanh-Tung, Hoang et al [20] 提出了 0-gp,具備更好的訓練穩定性。三者的對好比下:
12. Spectral normalization [21]
譜歸一化是另一個讓判別器知足 1-lipchitchz 連續的利器,建議在判別器和生成器裏同時使用。
ps: 在我的實踐中,它比梯度懲罰更有效。
13. one-size label smoothing [22]
平滑正樣本的 label,例如 label 1 變成 0.9-1.1 之間的隨機數,保持負樣本 label 仍然爲 0。我的經驗代表這個 trick 可以有效緩解訓練不穩定的現象,可是不能根本解決問題,假如模型不夠好的話,隨着訓練的進行,後期 loss 會飛。
14. add supervised labels
15. instance noise (decay over time)
在原始 GAN 中,咱們其實在優化兩個分佈的 JS 散度,前面的推理代表在兩個分佈的支撐集沒有交集或者支撐集是低維的流形空間,他們之間的 JS 散度大機率上是 0;而加入 instance noise 就是強行讓兩個分佈的支撐集之間產生交集,這樣 JS 散度就不會爲 0。新的 JS 散度變爲:
16. TTUR [23]
在優化 G 的時候,咱們默認是假定咱們的 D 的判別能力是比當前的 G 的生成能力要好的,這樣 D 才能指導 G 朝更好的方向學習。一般的作法是先更新 D 的參數一次或者屢次,而後再更新 G 的參數,TTUR 提出了一個更簡單的更新策略,即分別爲 D 和 G 設置不一樣的學習率,讓 D 收斂速度更快。
17. training strategy
PGGAN 是一個漸進式的訓練技巧,由於要生成高清(eg, 1024x1024)的圖片,直接從一個隨機噪聲生成這麼高維度的數據是比較難的;既然無法一蹴而就,那就按部就班,首先從簡單的低緯度的開始生成,例如 4x4,而後 16x16,直至咱們所須要的圖片大小。在 PGGAN 裏,首次實現了高清圖片的生成,而且能夠作到以假亂真,可見其威力。此外,因爲咱們大部分的操做都是在比較低的維度上進行的,訓練速度也不比其餘模型遜色多少。
coarse-to-refine 能夠說是 PGGAN 的一個特例,它的作法就是先用一個簡單的模型,加上一個 l1 loss,訓練一個模糊的效果,而後再把這個模糊的照片送到後面的 refine 模型裏,輔助對抗 loss 等其餘 loss,訓練一個更加清晰的效果。這個在圖片生成裏面普遍應用。
18. Exponential Moving Average [24]
EMA主要是對歷史的參數進行一個指數平滑,能夠有效減小訓練的抖動。強烈推薦!!!
總結
訓練 GAN 是一個精(折)細(磨)的活,一不當心你的 GAN 可能就是一部驚悚大片。筆者結合本身的經驗以及看過的一些文獻資料,列出了經常使用的 tricks,在此拋磚引玉,因爲筆者能力和視野有限,有些不正確之處或者沒補全的 tricks,還望斧正。