GAN的原理入門

開發者自述:我是這樣學習 GAN 的

from:https://www.leiphone.com/news/201707/1JEkcUZI1leAFq5L.html  

Generative Adversarial Network,就是你們耳熟能詳的 GAN,由 Ian Goodfellow 首先提出,在這兩年更是深度學習中最熱門的東西,彷彿什麼東西都能由 GAN 作出來。我最近剛入門 GAN,看了些資料,作一些筆記。html

1.Generation

什麼是生成(generation)?就是模型經過學習一些數據,而後生成相似的數據。讓機器看一些動物圖片,而後本身來產生動物的圖片,這就是生成。網絡

之前就有不少能夠用來生成的技術了,好比 auto-encoder(自編碼器),結構以下圖:iphone

開發者自述:我是這樣學習 GAN 的

你訓練一個 encoder,把 input 轉換成 code,而後訓練一個 decoder,把 code 轉換成一個 image,而後計算獲得的 image 和 input 之間的 MSE(mean square error),訓練完這個 model 以後,取出後半部分 NN Decoder,輸入一個隨機的 code,就能 generate 一個 image。機器學習

可是 auto-encoder 生成 image 的效果,固然看着很彆扭啦,一眼就能看出真假。因此後來還提出了好比VAE這樣的生成模型,我對此也不是很瞭解,在這就不細說。函數

上述的這些生成模型,其實有一個很是嚴重的弊端。好比 VAE,它生成的 image 是但願和 input 越類似越好,可是 model 是如何來衡量這個類似呢?model 會計算一個 loss,採用的大可能是 MSE,即每個像素上的均方差。loss 小真的表示類似嘛?學習

開發者自述:我是這樣學習 GAN 的

好比這兩張圖,第一張,咱們認爲是好的生成圖片,第二張是差的生成圖片,可是對於上述的 model 來講,這兩張圖片計算出來的 loss 是同樣大的,因此會認爲是同樣好的圖片。編碼

這就是上述生成模型的弊端,用來衡量生成圖片好壞的標準並不能很好的完成想要實現的目的。因而就有了下面要講的 GAN。3d

2.GAN

大名鼎鼎的 GAN 是如何生成圖片的呢?首先你們都知道 GAN 有兩個網絡,一個是 generator,一個是 discriminator,從二人零和博弈中受啓發,經過兩個網絡互相對抗來達到最好的生成效果。流程以下:code

開發者自述:我是這樣學習 GAN 的

主要流程相似上面這個圖。首先,有一個一代的 generator,它能生成一些不好的圖片,而後有一個一代的 discriminator,它能準確的把生成的圖片,和真實的圖片分類,簡而言之,這個 discriminator 就是一個二分類器,對生成的圖片輸出 0,對真實的圖片輸出 1。視頻

接着,開始訓練出二代的 generator,它能生成稍好一點的圖片,可以讓一代的 discriminator 認爲這些生成的圖片是真實的圖片。而後會訓練出一個二代的 discriminator,它能準確的識別出真實的圖片,和二代 generator 生成的圖片。以此類推,會有三代,四代。。。n 代的 generator 和 discriminator,最後 discriminator 沒法分辨生成的圖片和真實圖片,這個網絡就擬合了。

這就是 GAN,運行過程就是這麼的簡單。這就結束了嘛?顯然沒有,下面還要介紹一下 GAN 的原理。

3.原理

首先咱們知道真實圖片集的分佈 Pdata(x),x 是一個真實圖片,能夠想象成一個向量,這個向量集合的分佈就是 Pdata。咱們須要生成一些也在這個分佈內的圖片,若是直接就是這個分佈的話,怕是作不到的。

咱們如今有的 generator 生成的分佈能夠假設爲 PG(x;θ),這是一個由 θ 控制的分佈,θ 是這個分佈的參數(若是是高斯混合模型,那麼 θ 就是每一個高斯分佈的平均值和方差)

假設咱們在真實分佈中取出一些數據,{x1, x2, ... , xm},咱們想要計算一個似然 PG(xi; θ)。

對於這些數據,在生成模型中的似然就是

開發者自述:我是這樣學習 GAN 的

咱們想要最大化這個似然,等價於讓 generator 生成那些真實圖片的機率最大。這就變成了一個最大似然估計的問題了,咱們須要找到一個 θ* 來最大化這個似然。

開發者自述:我是這樣學習 GAN 的

尋找一個 θ* 來最大化這個似然,等價於最大化 log 似然。由於此時這 m 個數據,是從真實分佈中取的,因此也就約等於,真實分佈中的全部 x 在 P分佈中的 log 似然的指望。

真實分佈中的全部 x 的指望,等價於求機率積分,因此能夠轉化成積分運算,由於減號後面的項和 θ 無關,因此添上以後仍是等價的。而後提出共有的項,括號內的反轉,max 變 min,就能夠轉化爲 KL divergence 的形式了,KL divergence 描述的是兩個機率分佈之間的差別。

因此最大化似然,讓 generator 最大機率的生成真實圖片,也就是要找一個 θ 讓 P更接近於 Pdata。

那如何來找這個最合理的 θ 呢?咱們能夠假設 PG(x; θ) 是一個神經網絡。

首先隨機一個向量 z,經過 G(z)=x 這個網絡,生成圖片 x,那麼咱們如何比較兩個分佈是否類似呢?只要咱們取一組 sample z,這組 z 符合一個分佈,那麼經過網絡就能夠生成另外一個分佈 PG,而後來比較與真實分佈 Pdata。

你們都知道,神經網絡只要有非線性激活函數,就能夠去擬合任意的函數,那麼分佈也是同樣,因此能夠用一直正態分佈,或者高斯分佈,取樣去訓練一個神經網絡,學習到一個很複雜的分佈。

開發者自述:我是這樣學習 GAN 的

如何來找到更接近的分佈,這就是 GAN 的貢獻了。先給出 GAN 的公式:

開發者自述:我是這樣學習 GAN 的

這個式子的好處在於,固定 G,max  V(G,D) 就表示 PG 和 Pdata 之間的差別,而後要找一個最好的 G,讓這個最大值最小,也就是兩個分佈之間的差別最小。

開發者自述:我是這樣學習 GAN 的

表面上看這個的意思是,D 要讓這個式子儘量的大,也就是對於 x 是真實分佈中,D(x) 要接近與 1,對於 x 來自於生成的分佈,D(x) 要接近於 0,而後 G 要讓式子儘量的小,讓來自於生成分佈中的 x,D(x) 儘量的接近 1。

如今咱們先固定 G,來求解最優的 D:

開發者自述:我是這樣學習 GAN 的

開發者自述:我是這樣學習 GAN 的

對於一個給定的 x,獲得最優的 D 如上圖,範圍在 (0,1) 內,把最優的 D 帶入

開發者自述:我是這樣學習 GAN 的

能夠獲得:

開發者自述:我是這樣學習 GAN 的

開發者自述:我是這樣學習 GAN 的

JS divergence 是 KL divergence 的對稱平滑版本,表示了兩個分佈之間的差別,這個推導就代表了上面所說的,固定 G。

開發者自述:我是這樣學習 GAN 的

表示兩個分佈之間的差別,最小值是 -2log2,最大值爲 0。

如今咱們須要找個 G,來最小化

開發者自述:我是這樣學習 GAN 的

觀察上式,當 PG(x)=Pdata(x) 時,G 是最優的。

4.訓練

有了上面推導的基礎以後,咱們就能夠開始訓練 GAN 了。結合咱們開頭說的,兩個網絡交替訓練,咱們能夠在起初有一個 G0 和 D0,先訓練 D找到 :

開發者自述:我是這樣學習 GAN 的

而後固定 D0 開始訓練 G0, 訓練的過程均可以使用 gradient descent,以此類推,訓練 D1,G1,D2,G2,...

可是這裏有個問題就是,你可能在 D0* 的位置取到了:

開發者自述:我是這樣學習 GAN 的

而後更新 G0 爲 G1,可能

開發者自述:我是這樣學習 GAN 的

了,可是並不保證會出現一個新的點 D1* 使得

開發者自述:我是這樣學習 GAN 的

這樣更新 G 就沒達到它原來應該要的效果,以下圖所示:

開發者自述:我是這樣學習 GAN 的

避免上述狀況的方法就是更新 G 的時候,不要更新 G 太多。

知道了網絡的訓練順序,咱們還須要設定兩個 loss function,一個是 D 的 loss,一個是 G 的 loss。下面是整個 GAN 的訓練具體步驟:

開發者自述:我是這樣學習 GAN 的

上述步驟在機器學習和深度學習中也是很是常見,易於理解。

5.存在的問題

可是上面 G 的 loss function 仍是有一點小問題,下圖是兩個函數的圖像:

開發者自述:我是這樣學習 GAN 的

log(1-D(x)) 是咱們計算時 G 的 loss function,可是咱們發現,在 D(x) 接近於 0 的時候,這個函數十分平滑,梯度很是的小。這就會致使,在訓練的初期,G 想要騙過 D,變化十分的緩慢,而上面的函數,趨勢和下面的是同樣的,都是遞減的。可是它的優點是在 D(x) 接近 0 的時候,梯度很大,有利於訓練,在 D(x) 愈來愈大以後,梯度減少,這也很符合實際,在初期應該訓練速度更快,到後期速度減慢。

因此咱們把 G 的 loss function 修改成

開發者自述:我是這樣學習 GAN 的

這樣能夠提升訓練的速度。

還有一個問題,在其餘 paper 中提出,就是通過實驗發現,通過許屢次訓練,loss 一直都是平的,也就是

開發者自述:我是這樣學習 GAN 的

JS divergence 一直都是 log2,P和 Pdata 徹底沒有交集,可是實際上兩個分佈是有交集的,形成這個的緣由是由於,咱們沒法真正計算指望和積分,只能使用 sample 的方法,若是訓練的過擬合了,D 仍是可以徹底把兩部分的點分開,以下圖:

開發者自述:我是這樣學習 GAN 的

對於這個問題,咱們是否應該讓 D 變得弱一點,減弱它的分類能力,可是從理論上講,爲了讓它可以有效的區分真假圖片,咱們又但願它可以 powerful,因此這裏就產生了矛盾。

還有可能的緣由是,雖然兩個分佈都是高維的,可是兩個分佈都十分的窄,可能交集至關小,這樣也會致使 JS divergence 算出來 =log2,約等於沒有交集。

解決的一些方法,有添加噪聲,讓兩個分佈變得更寬,可能能夠增大它們的交集,這樣 JS divergence 就能夠計算,可是隨着時間變化,噪聲須要逐漸變小。

還有一個問題叫 Mode Collapse,以下圖:

開發者自述:我是這樣學習 GAN 的

這個圖的意思是,data 的分佈是一個雙峯的,可是學習到的生成分佈卻只有單峯,咱們能夠看到模型學到的數據,可是殊不知道它沒有學到的分佈。

形成這個狀況的緣由是,KL divergence 裏的兩個分佈寫反了

開發者自述:我是這樣學習 GAN 的

這個圖很清楚的顯示了,若是是第一個 KL divergence 的寫法,爲了防止出現無窮大,因此有 Pdata 出現的地方都必需要有 PG 覆蓋,就不會出現 Mode Collapse。

6.參考

這是對 GAN 入門學習作的一些筆記和理解,後來太懶了,不想打公式了,主要是參考了李宏毅老師的視頻:

http://t.cn/RKXQOV0

相關文章
相關標籤/搜索