[GAN學習系列] 初識GAN

本文大約 3800 字,閱讀大約須要 8 分鐘git

要說最近幾年在深度學習領域最火的莫過於生成對抗網絡,即 Generative Adversarial Networks(GANs)了。它是 Ian Goodfellow 在 2014 年發表的,也是這四年來出現的各類 GAN 的變種的開山鼻祖了,下圖表示這四年來有關 GAN 的論文的每月發表數量,能夠看出在 2014 年提出後到 2016 年相關的論文是比較少的,可是從 2016 年,或者是 2017 年到今年這兩年的時間,相關的論文是真的呈現井噴式增加。github

那麼,GAN 到底是什麼呢,它爲什麼會成爲這幾年這麼火的一個研究領域呢?微信

GAN,即生成對抗網絡,是一個生成模型,也是半監督和無監督學習模型,它能夠在不須要大量標註數據的狀況下學習深度表徵。最大的特色就是提出了一種讓兩個深度網絡對抗訓練的方法。網絡

目前機器學習按照數據集是否有標籤能夠分爲三種,監督學習、半監督學習和無監督學習,發展最成熟,效果最好的目前仍是監督學習的方法,可是在數據集數量要求更多更大的狀況下,獲取標籤的成本也更加昂貴了,所以愈來愈多的研究人員都但願可以在無監督學習方面有更好的發展,而 GAN 的出現,一來它是不太須要不少標註數據,甚至能夠不須要標籤,二來它能夠作到不少事情,目前對它的應用包括圖像合成、圖像編輯、風格遷移、圖像超分辨率以及圖像轉換等。app

好比字體的轉換,在 zi2zi 這個項目中,給出了對中文文字的字體的變換,效果以下圖所示,GAN 能夠學習到不一樣字體,而後將其進行變換。機器學習

zi2zi_examples

除了字體的學習,還有對圖片的轉換, pix2pix 就能夠作到,其結果以下圖所示,分割圖變成真實照片,從黑白圖變成彩色圖,從線條畫變成富含紋理、陰影和光澤的圖等等,這些都是這個 pix2pixGAN 實現的結果。ide

pix2pix_examples

CycleGAN 則能夠作到風格遷移,其實現結果以下圖所示,真實照片變成印象畫,普通的馬和斑馬的互換,季節的變換等。函數

cycleGAN_examples

上述是 GAN 的一些應用例子,接下來會簡單介紹 GAN 的原理以及其優缺點,固然也還有爲啥等它提出兩年後纔開始有愈來愈多的 GAN 相關的論文發表。學習

1. 基本原理

GAN 的思想其實很是簡單,就是生成器網絡和判別器網絡的彼此博弈。測試

GAN 主要就是兩個網絡組成,生成器網絡(Generator)和判別器網絡(Discriminator),經過這兩個網絡的互相博弈,讓生成器網絡最終可以學習到輸入數據的分佈,這也就是 GAN 想達到的目的--學習輸入數據的分佈。其基本結構以下圖所示,從下圖能夠更好理解G 和 D 的功能,分別爲:

  • D 是判別器,負責對輸入的真實數據和由 G 生成的假數據進行判斷,其輸出是 0 和 1,即它本質上是一個二值分類器,目標就是對輸入爲真實數據輸出是 1,對假數據的輸入,輸出是 0;
  • G 是生成器,它接收的是一個隨機噪聲,並生成圖像。

在訓練的過程當中,G 的目標是儘量生成足夠真實的數據去迷惑 D,而 D 就是要將 G 生成的圖片都辨別出來,這樣二者就是互相博弈,最終是要達到一個平衡,也就是納什均衡。

2. 優勢

(如下優勢和缺點主要來自 Ian Goodfellow 在 Quora 上的回答,以及知乎上的回答)

  • GAN 模型只用到了反向傳播,而不須要馬爾科夫鏈
  • 訓練時不須要對隱變量作推斷
  • 理論上,只要是可微分函數均可以用於構建 D 和 G ,由於可以與深度神經網絡結合作深度生成式模型
  • G 的參數更新不是直接來自數據樣本,而是使用來自 D 的反向傳播
  • 相比其餘生成模型(VAE、玻爾茲曼機),能夠生成更好的生成樣本
  • GAN 是一種半監督學習模型,對訓練集不須要太多有標籤的數據;
  • 沒有必要遵循任何種類的因子分解去設計模型,全部的生成器和鑑別器均可以正常工做

3. 缺點

  • 可解釋性差,生成模型的分佈 Pg(G)沒有顯式的表達
  • 比較難訓練, D 與 G 之間須要很好的同步,例如 D 更新 k 次而 G 更新一次
  • 訓練 GAN 須要達到納什均衡,有時候能夠用梯度降低法作到,有時候作不到.咱們尚未找到很好的達到納什均衡的方法,因此訓練 GAN 相比 VAE 或者 PixelRNN 是不穩定的,但我認爲在實踐中它仍是比訓練玻爾茲曼機穩定的多.
  • 它很難去學習生成離散的數據,就像文本
  • 相比玻爾茲曼機,GANs 很難根據一個像素值去猜想另一個像素值,GANs 天生就是作一件事的,那就是一次產生全部像素,你能夠用 BiGAN 來修正這個特性,它能讓你像使用玻爾茲曼機同樣去使用 Gibbs 採樣來猜想缺失值
  • 訓練不穩定,G 和 D 很難收斂;
  • 訓練還會遭遇梯度消失、模式崩潰的問題
  • 缺少比較有效的直接可觀的評估模型生成效果的方法

3.1 爲何訓練會出現梯度消失和模式奔潰

GAN 的本質就是 G 和 D 互相博弈並最終達到一個納什平衡點,但這只是一個理想的狀況,正常狀況是容易出現一方強大另外一方弱小,而且一旦這個關係造成,而沒有及時找到方法平衡,那麼就會出現問題了。而梯度消失和模式奔潰其實就是這種狀況下的兩個結果,分別對應 D 和 G 是強大的一方的結果。

首先對於梯度消失的狀況是D 越好,G 的梯度消失越嚴重,由於 G 的梯度更新來自 D,而在訓練初始階段,G 的輸入是隨機生成的噪聲,確定不會生成很好的圖片,D 會很容易就判斷出來真假樣本,也就是 D 的訓練幾乎沒有損失,也就沒有有效的梯度信息回傳給 G 讓 G 去優化本身。這樣的現象叫作 gradient vanishing,梯度消失問題。

其次,對於模式奔潰(mode collapse)問題,主要就是 G 比較強,致使 D 不能很好區分出真實圖片和 G 生成的假圖片,而若是此時 G 其實還不能徹底生成足夠真實的圖片的時候,但 D 卻分辨不出來,而且給出了正確的評價,那麼 G 就會認爲這張圖片是正確的,接下來就繼續這麼輸出這張或者這些圖片,而後 D 仍是給出正確的評價,因而二者就是這麼相互欺騙,這樣 G 其實就只會輸出固定的一些圖片,致使的結果除了生成圖片不夠真實,還有就是多樣性不足的問題。

更詳細的解釋能夠參考 使人拍案叫絕的Wasserstein GAN,這篇文章更詳細解釋了原始 GAN 的問題,主要就是出如今 loss 函數上。

3.2 爲何GAN不適合處理文本數據

  1. 文本數據相比較圖片數據來講是離散的,由於對於文原本說,一般須要將一個詞映射爲一個高維的向量,最終預測的輸出是一個one-hot向量,假設 softmax 的輸出是(0.2, 0.3, 0.1,0.2,0.15,0.05),那麼變爲 onehot是(0,1,0,0,0,0),若是softmax輸出是(0.2, 0.25, 0.2, 0.1,0.15,0.1 ),one-hot 仍然是(0, 1, 0, 0, 0, 0),因此對於生成器來講,G 輸出了不一樣的結果, 可是 D 給出了一樣的判別結果,並不能將梯度更新信息很好的傳遞到 G 中去,因此 D 最終輸出的判別沒有意義。
  2. GAN 的損失函數是 JS 散度,JS 散度不適合衡量不想交分佈之間的距離。(WGAN 雖然使用 wassertein 距離代替了 JS 散度,可是在生成文本上能力仍是有限,GAN 在生成文本上的應用有 seq-GAN,和強化學習結合的產物)

3.3 爲何GAN中的優化器不經常使用SGD

  1. SGD 容易震盪,容易使 GAN 的訓練更加不穩定,
  2. GAN 的目的是在高維非凸的參數空間中找到納什均衡點,GAN 的納什均衡點是一個鞍點,可是 SGD 只會找到局部極小值,由於 SGD 解決的是一個尋找最小值的問題,但 GAN 是一個博弈問題。

對於鞍點,來自百度百科的解釋是:

鞍點(Saddle point)在微分方程中,沿着某一方向是穩定的,另外一條方向是不穩定的奇點,叫作鞍點。在泛函中,既不是極大值點也不是極小值點的臨界點,叫作鞍點。在矩陣中,一個數在所在行中是最大值,在所在列中是最小值,則被稱爲鞍點。在物理上要普遍一些,指在一個方向是極大值,另外一個方向是極小值的點。

鞍點和局部極小值點、局部極大值點的區別以下圖所示:

局部極小值點和鞍點的對比

4. 訓練的技巧

訓練的技巧主要來自Tips and tricks to make GANs work

1. 對輸入進行規範化
  • 將輸入規範化到 -1 和 1 之間
  • G 的輸出層採用Tanh激活函數
2. 採用修正的損失函數

在原始 GAN 論文中,損失函數 G 是要 min (log(1-D)), 但實際使用的時候是採用 max(logD),做者給出的緣由是前者會致使梯度消失問題。

但實際上,即使是做者提出的這種實際應用的損失函數也是存在問題,即模式奔潰的問題,在接下來提出的 GAN 相關的論文中,就有很多論文是針對這個問題進行改進的,如 WGAN 模型就提出一種新的損失函數。

3. 從球體上採樣噪聲
  • 不要採用均勻分佈來採樣
  • 從高斯分佈中採樣獲得隨機噪聲
  • 當進行插值操做的時候,從大圓進行該操做,而不要直接從點 A 到 點 B 直線操做,以下圖所示

4. BatchNorm
  • 採用 mini-batch BatchNorm,要保證每一個 mini-batch 都是一樣的真實圖片或者是生成圖片
  • 不採用 BatchNorm 的時候,能夠採用 instance normalization(對每一個樣本的規範化操做)
  • 可使用虛擬批量歸一化(virtural batch normalization):開始訓練以前預約義一個 batch R,對每個新的 batch X,都使用 R+X 的級聯來計算歸一化參數
5. 避免稀疏的梯度:Relus、MaxPool
  • 稀疏梯度會影響 GAN 的穩定性
  • 在 G 和 D 中採用 LeakyReLU 代替 Relu 激活函數
  • 對於下采樣操做,能夠採用平均池化(Average Pooling) 和 Conv2d+stride 的替代方案
  • 對於上採樣操做,可使用 PixelShuffle(arxiv.org/abs/1609.05…), ConvTranspose2d + stride
6. 標籤的使用
  • 標籤平滑。也就是若是有兩個目標標籤,假設真實圖片標籤是 1,生成圖片標籤是 0,那麼對每一個輸入例子,若是是真實圖片,採用 0.7 到 1.2 之間的一個隨機數字來做爲標籤,而不是 1;通常是採用單邊標籤平滑
  • 在訓練 D 的時候,偶爾翻轉標籤
  • 有標籤數據就儘可能使用標籤
7. 使用 Adam 優化器
8. 儘早追蹤失敗的緣由
  • D 的 loss 變成 0,那麼這就是訓練失敗了
  • 檢查規範的梯度:若是超過 100,那出問題了
  • 若是訓練正常,那麼 D loss 有低方差而且隨着時間下降
  • 若是 g loss 穩定降低,那麼它是用糟糕的生成樣本欺騙了 D
9. 不要經過統計學來平衡 loss
10. 給輸入添加噪聲
11. 對於 Conditional GANs 的離散變量
  • 使用一個 Embedding 層
  • 對輸入圖片添加一個額外的通道
  • 保持 embedding 低維並經過上採樣操做來匹配圖像的通道大小
12 在 G 的訓練和測試階段使用 Dropouts
  • 以 dropout 的形式提供噪聲(50%的機率)
  • 訓練和測試階段,在 G 的幾層使用
  • arxiv.org/pdf/1611.07…

參考文章:

注:配圖來自網絡和參考文章


以上就是本文的主要內容和總結,能夠留言給出你對本文的建議和見解。

同時也歡迎關注個人微信公衆號--機器學習與計算機視覺或者掃描下方的二維碼,和我分享你的建議和見解,指正文章中可能存在的錯誤,你們一塊兒交流,學習和進步!

相關文章
相關標籤/搜索