文章目錄算法
1.2.2 重說GAN原理機器學習
1.2.3 小結函數
生成對抗網絡(Generative Adversarial Networks, GANs)是由Ian Goodfellow等人在2014年的Generative Adversarial Networks一文中提出。Facebook的人工智能主管Yann Lecun對其的評價是:「機器學習在過去10年中最有趣的想法」。**GANs的潛力巨大,由於它們能夠學習模仿任何數據分佈。**也就是說,GANs通過學習後,能夠創造出相似於咱們真實世界的一些東西,好比:圖像、音樂、散文等等。從某種意義來講,它們是「機器人藝術家」,有些確實可以讓人印象深入。
基於GAN,能夠將人臉粘貼到視頻中的目標人物上
在講GANGAN以前,先講一個小趣事,你知道GANGAN是怎麼被髮明的嗎?據IanIan GoodfellowGoodfellow本身說: 以前他一直在研究生成模型,多是一時興起,有一天他在酒吧喝酒時,在酒吧裏跟朋友討論起生成模型。而後IanIan GoodfellowGoodfellow想到GANGAN的思想,跟朋友說你應該這麼作這麼作這麼作,我打賭必定會有用。可是朋友不信,因而他直接從酒吧回去開始作實驗,一夜就寫出了GANGAN論文~
這個故事告訴咱們,喝酒,不只能打醉拳,也能寫出頂級論文…
故事講完,開幹(GANGAN)吧:GANGAN包含有兩個模型,一個是生成模型(generativegenerative modelmodel),一個是判別模型(discriminativediscriminative modelmodel)。生成模型的任務是生成看起來天然真實的、和原始數據類似的數據。判別模型的任務是判斷給定的實例看起來是天然真實的仍是認爲僞造的(真實實例來源於數據集,僞造實例來源於生成模型)。
這能夠看作一種零和遊戲。論文采用類比的手法通俗理解:生成模型像「一個造假團伙,試圖生產和使用假幣」,而判別模型像「檢測假幣的警察」。生成器(generatorgenerator)試圖欺騙判別器(discriminatordiscriminator),判別器則努力不被生成器欺騙。模型通過交替優化訓練,兩種模型都能獲得提高,但最終咱們要獲得的是效果提高到很高很好的生成模型(造假團伙),這個生成模型(造假團伙)所生成的產品能達到真假難分的地步,這個過程就如上圖的對抗過程。
隨着學術界和工業界都開始接收並歡迎GANGAN的到來,GANGAN的崛起不可避免:
首先,GANGAN最厲害的地方是它的學習性質是無監督的。GANGAN也不須要標記數據,這使GANGAN功能強大,由於數據標記的工做很是枯燥。
其次,GANGAN的潛在用例使它成爲交談的中心。它能夠生成高質量的圖像,圖片加強,從文本生成圖像,將圖像從一個域轉換爲另外一個域,隨年齡增加改變臉部外觀等等。這個名單較長並還在快速增加。
第三,圍繞GANGAN不斷的研究是如此使人着迷,以致於它吸引了其餘(圖像以外)全部行業的注意力。
要全面理解生成對抗網絡,首先要理解的概念是監督式學習和非監督式學習。監督式學習是指基於大量帶有標籤的訓練集與測試集的機器學習過程,好比監督式圖片分類器須要一系列圖片和對應的標籤(「貓」,「狗」,…),而非監督式學習則不須要這麼多額外的工做,它們能夠本身從錯誤中進行學習,並下降將來出錯的機率。監督式學習的缺點就是須要大量標籤樣本,這很是耗時耗力。非監督式學習雖然沒有這個問題,但準確率每每更低。天然而然地但願可以經過提高非監督式學習的性能,從而減小對監督式學習的依賴。GANGAN能夠說是對於非監督式學習的一種提高。
第二個須要理解的概念是「生成模型」, 以下圖所示生成圖片模型的概念示意圖。這類模型可以經過輸入的樣本產生可能的輸出。舉個例子,一個生成模型能夠經過視頻的某一幀預測出下一幀的輸出。另外一個例子是搜索引擎,在你輸入的同時,搜索引擎已經在推斷你可能搜索的內容了。
基於上面的兩個概念就能夠設計生成對抗網絡GANGAN了。相比於傳統的神經網絡模型,GANGAN是一種全新的非監督式的架構(以下圖所示)。GANGAN包括了兩套獨立的網絡,二者之間做爲互相對抗的目標。第一套網絡是咱們須要訓練的分類器(下圖中的D),用來分辨是不是真實數據仍是虛假數據;第二套網絡是生成器(下圖中的G),生成相似於真實樣本的隨機樣本,並將其做爲假樣本。
D做爲一個圖片分類器,對於一系列圖片區分不一樣的動物。生成器G的目標是繪製出很是接近的僞造圖片來欺騙D,作法是選取訓練數據潛在空間中的元素進行組合,並加入隨機噪音,例如在這裏能夠選取一個貓的圖片,而後給貓加上第三隻眼睛,以此做爲假數據。
在訓練過程當中,D會接收真數據和G產生的假數據,它的任務是判斷圖片是屬於真數據的仍是假數據的。對於最後輸出的結果,能夠同時對兩方的參數進行調優。若是D判斷正確,那就須要調整G的參數從而使得生成的假數據更爲逼真;若是D判斷錯誤,則需調節D的參數,避免下次相似判斷出錯。訓練會一直持續到二者進入到一個均衡和諧的狀態。
訓練後的產物是一個質量較高的自動生成器和一個判斷能力較強強的分類器。前者能夠用於機器創做(自動畫出「貓」「狗」),然後者則能夠用來機器分類(自動判斷「貓」「狗」)。
小結:GANGAN算法流程簡述
算法流程簡述
初始化generator和discriminator。
每一次的迭代過程當中:
固定generator, 只更新discriminator的參數。從你準備的數據集中隨機選擇一些,再從generator的output中選擇一些,如今等於discriminator有兩種input。接下來, discriminator的學習目標是, 若是輸入是來自於真實數據集,則給高分;若是是generator產生的數據,則給低分,能夠把它當作一個迴歸問題。
接下來,固定住discriminator的參數, 更新generator。將一個向量輸入generator, 獲得一個output, 將output扔進discriminator, 而後會獲得一個分數,這一階段discriminator的參數已經固定住了,generator須要調整本身的參數使得這個output的分數越大越好。
按這個過程聽起來好像有兩個網絡,而實際過程當中,generator和discriminator是同一個網絡,只不過網絡中間的某一層hidden-layer的輸出是一個圖片(或者語音,取決於你的數據集)。在訓練的時候也是固定一部分hidden-layer,調其他的hidden-layer。
如下這一段真的是太枯燥了,純屬爲了內容完整性,不喜跳過… 一點不影響全文理解,哈哈哈~
------------------------------------------------------------高能開始分割線------------------------------------------------------------------------------
考慮一下,GAN到底生成的是什麼呢?好比說,假如咱們想要生成一些人臉圖,實際上,咱們是想找到一個分佈,從這個分部內sample出來的圖片,像是人臉,而不屬於這個distribution的分佈,生成的就不是人臉。而GAN要作的就是找到這個distribution。
在GAN出生以前,咱們怎麼作這個事情呢?
以前用的是Maximum Likelihood Estimation,最大似然估計來作生成的,咱們先從機率分佈及參數估計提及,經過介紹極大似然估計、KL散度、JS散度,再詳細介紹GAN生成對抗網絡的數學原理。
不管是黑白圖片或彩色圖片, 都是使用 0 ~ 255 的數值表示像素. 將全部的像素值除以 255 咱們就能夠將一張圖片轉化爲 0 ~ 1 的機率分佈, 並且這種轉化是可逆的, 乘以 255 就能夠還原.
從某種意義上來說, GAN 圖片生成任務就是生成機率分佈. 所以, 咱們有必要結合機率分佈來理解 GAN 生成對抗網絡的原理.
回顧機率分佈及參數估計
先來看一個例子:
假設一個抽獎盒子裏有45個球,其編號是1-9共9個數字。每一個編號的球擁有的數量是:
編號123456789
數量246897531
佔比0.0440.0880.1330.1780.2000.1560.1110.0660.022
佔比是指用每一個編號的數量除以全部編號的數量總和,在數理統計中,在不引發誤會的狀況下,這裏的佔比也能夠被稱爲機率/頻率。
使用向量qq表示上述的機率分佈:
\begin{aligned}q &=(1,4,6,8,9,7,5,3,1)/45 \\&=(0.044,0.088,0.133,0.178,0.200,0.156,0.111,0.066,0.022) \end{aligned}q=(1,4,6,8,9,7,5,3,1)/45=(0.044,0.088,0.133,0.178,0.200,0.156,0.111,0.066,0.022)
將上述分佈使用圖像繪製以下:
如今咱們但願構建一個函數p=p(x;\theta)p=p(x;θ),以xx爲編號做爲輸入數據,輸出編號xx的機率。\thetaθ是參與構建這個函數的參數,一經選定就再也不變化。
假設上述機率分佈服從二次拋物線函數:
\begin{aligned}p &=p(x;\theta) \\ &=\theta_1(x+\theta_2)^2+\theta_3 \\ \end{aligned}p=p(x;θ)=θ1(x+θ2)2+θ3
x=(1,2,3,4,5,6,7,8,9)x=(1,2,3,4,5,6,7,8,9)
使用L_2L2偏差做爲評價擬合效果的損失函數,總偏差值爲error(標量e)error(標量e):
e=\Sigma_{i=1}^9(p_i-q_i)^2e=Σi=19(pi−qi)2
咱們但願求得一個\theta^*θ∗,使得ee的值越小越好,數學上表達爲:
\theta^*=\mathop {argmin}_{\theta}(e)θ∗=argminθ(e)
argmin是argument minimum的縮寫。
如何求\theta^*θ∗不是本文的重點,這是生成對抗網絡的任務。爲了幫助理解,取其中一個可能的數值做爲示例:
\theta^*=(\theta_1,\theta_2,\theta_3)=(-0.01,-5.0,0.2)θ∗=(θ1,θ2,θ3)=(−0.01,−5.0,0.2)
p=p(x;\theta)=-0.01(x-5.0)^2+0.2p=p(x;θ)=−0.01(x−5.0)2+0.2
繪製函數圖像以下:
在生成對抗網絡中,本例的估計函數p(x:\theta)p(x:θ)至關於生成模型(generator),損失函數至關於鑑別模型(discriminator)。
從最大似然估計講起
最大似然估計的理念是:假如說咱們的數據集分佈是P_{data}(x)Pdata(x),咱們定義一個分佈P_G(x;\theta)PG(x;θ),咱們想要找到一組參數\thetaθ,使得P_G(x;\theta)PG(x;θ)越接近P_{data}(x)越好Pdata(x)越好。例如,若P_G(x;\theta)PG(x;θ)是一個高斯混合模型,那麼\thetaθ就是均值和方差。
具體怎麼操做呢
從P_{data}(x)Pdata(x)中採樣出{x^1,x^2,x^3,\dots,x^m}x1,x2,x3,…,xm;
對每個採樣出來的xx,咱們均可以計算出它的似然函數,也就是能夠獲得一組參數\thetaθ,進而就能知道P_G(x;\theta)PG(x;θ)長什麼樣,而後就能夠進一步計算出這個分佈裏面採樣出的某一個xx的概率;
把在某個分佈能夠產生x_ixi的參數似然函數乘起來,能夠獲得總的似然函數:
L=\prod_{i=1}^mP_G(x^i;\theta)L=∏i=1mPG(xi;θ)
咱們要找到一組\theta^*θ∗,能夠最大化LL。
在上面的例子中,咱們很幸運的知道了全部可能的機率分佈,並讓求解最優化的機率分佈估計函數p(x;\theta)p(x;θ)成爲可能。
若是上例的抽獎盒子(樣本)中的45個球是從更大的抽獎池(整體)中選出來的,而咱們不知道抽獎池中全部球的數量及其編號。那麼,咱們如何根據現有的45個球來估計抽獎池的機率分佈呢?固然,咱們能夠直接用上例求得的樣本估計函數來表明抽獎池的機率分佈,但接下來會介紹一種更爲經常使用的估計方法,即本節開篇提到的最大似然估計。
假設p(x)=p(x;\theta)p(x)=p(x;θ)是整體的機率分佈函數,則編號x=(x_1,x_2,x_3,\cdots,x_n)x=(x1,x2,x3,⋯,xn)出現的機率爲:
p=p(x_1),p(x_2),p(x_3),\cdots,p(x_n)p=p(x1),p(x2),p(x3),⋯,p(xn)
在本例中,n=9n=9,即共9個編號。
設d=(d_1,d_2,d_3,\cdots,d_m)d=(d1,d2,d3,⋯,dm)是全部抽樣的編號,在本例中,m=45m=45,即樣本中共有45個抽樣。假設全部的樣本和抽樣都是獨立的,則樣本出現的機率爲:
\rho=p(d_1)\times p(d_2)\times p(d_3)\times \cdots \times p(d_m)=\prod_{i=1}^m(p(d_i))ρ=p(d1)×p(d2)×p(d3)×⋯×p(dm)=∏i=1m(p(di))
p(x)=p(x;\theta)p(x)=p(x;θ)的函數結構是人爲按經驗選取的,好比線性函數,多元二次函數,更復雜的非線性函數等,一經選取則再也不改變。如今咱們須要求解一個參數集\theta^*θ∗,使得\rhoρ的值越大越好。即
\theta^*=\mathop {argmax}_\theta(\rho)=\mathop {argmax}_\theta\prod_{i=1}^mp(d_i;\theta)θ∗=argmaxθ(ρ)=argmaxθ∏i=1mp(di;θ)
argmax是argument maximum的縮寫。
通俗來說,由於樣本是實際已發生的事實,在函數結構已肯定的狀況下,咱們須要儘可能優化參數,使得樣本的理論估計機率越大越好。
這裏有一個前提,就是認爲選定的函數結構應當可以有效評估樣本分佈。反之,若是使用線性函數去擬合正態機率分佈(normal distribution),則不管如何選擇參數都沒法獲得滿意的效果。
連乘運算不方便,將之改成求和運算。因爲loglog對數函數的單調性,上面的式子等價於:
\theta^*=\mathop {argmax}_\theta log\prod_{i=1}^mp(d_i;\theta)=\mathop {argmax}_\theta\Sigma_{i=1}^mlog \, p(d_i;\theta)θ∗=argmaxθlog∏i=1mp(di;θ)=argmaxθΣi=1mlogp(di;θ)
設樣本分佈爲q(x)q(x),對於給定樣本,這個分佈是已知的,能夠經過統計抽樣的計算得出,將上式轉化成指望公式:
\theta^*=\mathop {argmax}_\theta\Sigma_{i=1}^mlog\,p(d_i;\theta)=\mathop {argmax}_\theta\Sigma_{i=1}^n q(x_i)log\,p(x_i;\theta)θ∗=argmaxθΣi=1mlogp(di;θ)=argmaxθΣi=1nq(xi)logp(xi;θ)
注意上式中的兩個求和符號,mm變成了nn。在大多數狀況下,編號數量會比抽樣數量少,轉爲指望公式能夠顯著減小計算量。
在一些教程中,上式寫法爲:
\theta^*=\mathop {argmax}_\theta E_{x-q(x)}log\,p(x;\theta)=\mathop {argmax}_\theta \int q(x)log\,p(x;\theta)dxθ∗=argmaxθEx−q(x)logp(x;θ)=argmaxθ∫q(x)logp(x;θ)dx
E_{x-q(x)}Ex−q(x)表示按q(x)q(x)的分佈對xx求指望。由於積分表達式比較簡潔,書寫方便,下文開始將主要使用積分表達式。
以上就是最大似然估計(Maximum Likelihood Estimation)的理論和推導過程。和上例的參數估計方法相比,最大似然估計由於無需設計損失函數,下降了模型的複雜度,擴大了適用範圍。
本例中的估計函數p(x;\theta)p(x;θ)至關於生成對抗網絡的生成模型,樣本分佈q(x)q(x)至關於訓練數據。
另外一種解釋-KL散度
結合上例,在樣本已知的狀況下,q(x)q(x)是一個已知且肯定的分佈。則\int q(x)log\,q(x)dx∫q(x)logq(x)dx是一個常數項,不影響\theta^*θ∗求解的結果,則可添加項
\begin{aligned}\theta^* &=\mathop {argmax}_\theta(\int q(x)log\,p(x;\theta)dx-\int q(x)log\,q(x)dx) \\ &=\mathop {argmax}_\theta \int q(x)(log\,p(x;\theta)-log\,q(x))dx \\ &=\mathop {argmax}_\theta \int q(x)log\frac{p(x;\theta)}{q(x)}dx \end{aligned}θ∗=argmaxθ(∫q(x)logp(x;θ)dx−∫q(x)logq(x)dx)=argmaxθ∫q(x)(logp(x;θ)−logq(x))dx=argmaxθ∫q(x)logq(x)p(x;θ)dx
也能夠寫成這樣:
\begin{aligned}\theta^* &=\mathop {argmin}_\theta(-\int q(x)log\,p(x;\theta)dx+\int q(x)log\,q(x)dx) \\ &=\mathop {argmin}_\theta \int q(x)log\frac{q(x)}{p(x;\theta)}dx \end{aligned}θ∗=argminθ(−∫q(x)logp(x;θ)dx+∫q(x)logq(x)dx)=argminθ∫q(x)logp(x;θ)q(x)dx
KLKL散度(Kullback-Leibler divergence)是一種衡量兩個機率分佈的匹配程度的指標,兩個分佈差別越大,KL散度越大。它還有不少的名字,好比:relative entropy, relative information。
其定義以下:
D_{KL}(q\mid\mid p)=\int q(x)log\frac{q(x)}{p(x)}dxDKL(q∣∣p)=∫q(x)logp(x)q(x)dx
當p(x)=q(x)p(x)=q(x)時,取最小值D_{KL}(q\mid\mid p)=0DKL(q∣∣p)=0。
咱們能夠將上面的公式簡化爲:
\theta^*=\mathop {argmin}_\theta D_{KL}(q\mid\mid p(x;\theta))θ∗=argminθDKL(q∣∣p(x;θ))
KL散度的補充-JS散度
KLKL散度是非對稱的,即D_{KL}(q\mid\mid p)DKL(q∣∣p)不必定等於D_{KL}(p\mid\mid q)DKL(p∣∣q)。爲了解決這個問題,須要引入JSJS散度。JSJS散度(Jensen-Shannon divergence)的定義以下:
m=\frac{1}{2}(p+q)m=21(p+q)
D_{JS}=\frac{1}{2}D_{KL}(p\mid\mid m)+\frac{1}{2}D_{KL}(q\mid\mid m)DJS=21DKL(p∣∣m)+21DKL(q∣∣m)
JSJS的值域是對稱的,有界的,範圍是[0,1][0,1]。
若是p,\,qp,q徹底相同,則JS=0JS=0,若是徹底不相同,則JS=1JS=1。
注意,KLKL散度和JSJS散度做爲差別度量的時候,有一個問題:
若是兩個分配pp,qq離得很遠,徹底沒有重疊的時候,那麼KLKL散度值是沒有意義的,而JSJS散度值是一個常數。這在學習算法中是比較致命的,由於這意味着在這一點的梯度爲0,梯度消失了。
參考上述例子,對JSJS進行反推:
\begin{aligned}D_{JS}(q\mid\mid p) &=\frac{1}{2} D_{KL}(q\mid\mid m)+\frac{1}{2}D_{KL}(p\mid\mid m) \\ &=\frac{1}{2}\int q(x)log\frac{q(x)}{\frac{q(x)+p(x;\theta)}{2}}dx+\frac{1}{2}\int p(x;\theta)log\frac{p(x;\theta)}{\frac{p(x;\theta)+q(x)}{2}}dx \\ &=\frac{1}{2}\int q(x)log\frac{2q(x)}{q(x)+p(x;\theta)}dx+\frac{1}{2}\int p(x;\theta)log\frac{2p(x;\theta)}{p(x;\theta)+q(x)}dx \end{aligned}DJS(q∣∣p)=21DKL(q∣∣m)+21DKL(p∣∣m)=21∫q(x)log2q(x)+p(x;θ)q(x)dx+21∫p(x;θ)log2p(x;θ)+q(x)p(x;θ)dx=21∫q(x)logq(x)+p(x;θ)2q(x)dx+21∫p(x;θ)logp(x;θ)+q(x)2p(x;θ)dx
因爲:
\begin{aligned}\int q(x)log\frac{2q(x)}{q(x)+p(x;\theta)}dx &=\int q(x)(log\frac{q(x)}{q(x)+p(x;\theta)}+log2)dx \\ &=\int q(x)log\frac{q(x)}{q(x)+p(x;\theta)}dx+\int q(x)log2dx \\ &=\int q(x)log\frac{q(x)}{q(x)+p(x;\theta)}dx+log2 \end{aligned}∫q(x)logq(x)+p(x;θ)2q(x)dx=∫q(x)(logq(x)+p(x;θ)q(x)+log2)dx=∫q(x)logq(x)+p(x;θ)q(x)dx+∫q(x)log2dx=∫q(x)logq(x)+p(x;θ)q(x)dx+log2
同理可得:
\begin{aligned}D_{JS}(q\mid\mid p) &=\frac{1}{2}\int q(x)log\frac{q(x)}{q(x)+p(x;\theta)}dx+\frac{1}{2}\int p(x;\theta)log\frac{p(x;\theta)}{p(x;\theta)+q(x)}dx +log2 \end{aligned}DJS(q∣∣p)=21∫q(x)logq(x)+p(x;θ)q(x)dx+21∫p(x;θ)logp(x;θ)+q(x)p(x;θ)dx+log2
令:
d(x;\theta)=\frac{q(x)}{q(x)+p(x;\theta)}d(x;θ)=q(x)+p(x;θ)q(x)
則:
1-d(x;\theta)=\frac{p(x;\theta)}{q(x)+p(x;\theta)}1−d(x;θ)=q(x)+p(x;θ)p(x;θ)
即:
D_{JS}(q\mid\mid p)=\frac{1}{2}\int q(x)log\,d(x;\theta)dx+\frac{1}{2}\int p(x;\theta)log(1-d(x;\theta))dx+log2DJS(q∣∣p)=21∫q(x)logd(x;θ)dx+21∫p(x;θ)log(1−d(x;θ))dx+log2
令:
V(x;\theta)=\int q(x)log\,d(x;\theta)dx+\int p(x;\theta)log(1-d(x;\theta))dxV(x;θ)=∫q(x)logd(x;θ)dx+∫p(x;θ)log(1−d(x;θ))dx
則:
D_{JS}(q\mid\mid p)=\frac{1}{2}V(x;\theta)+log2DJS(q∣∣p)=21V(x;θ)+log2
即:
\theta^*=\mathop {argmin}_\theta D_{JS}(q\mid\mid p)=\mathop{argmin}_\theta V(x;\theta)θ∗=argminθDJS(q∣∣p)=argminθV(x;θ)
此時,\theta^*θ∗是令p(x;\theta)p(x;θ)和q(x)q(x)差別最小的參數,一樣亦可經過V(x;\theta)V(x;θ)求差別最大的參數。
JS散度參數求解的兩步走迭代方法
從上面的討論知道,咱們須要一個參數\theta^*θ∗,使得
\theta^*=\mathop {argmin}_\theta D_{JS}(q\mid\mid p)=\mathop {argmin}_\theta V(x;theta)θ∗=argminθDJS(q∣∣p)=argminθV(x;theta)
然而,由於涉及多重嵌套和積分,使用起來並不方便。
首先,咱們假設p(x;\theta)=p_g(x)p(x;θ)=pg(x)爲已知條件,同時令D=d(x;\theta)D=d(x;θ),考慮這個式子:
W(x;\theta)=q(x)log\,d(x;\theta)dx+p(x;\theta)log(1-d(x;\theta))W(x;θ)=q(x)logd(x;θ)dx+p(x;θ)log(1−d(x;θ))
W(x;D)=q(x)log\,D+p_g(x)log(1-D)W(x;D)=q(x)logD+pg(x)log(1−D)
V(x;\theta)=V(x;D)=\int W(x;D)dxV(x;θ)=V(x;D)=∫W(x;D)dx
在xx已知的狀況下,咱們關注DD。
W'=\frac{dW}{dD}=q(x)\frac{1}{D}-p_g(x)\frac{1}{1-D}W′=dDdW=q(x)D1−pg(x)1−D1
W''=\frac{dW'}{dD}=-q(x)\frac{1}{D^2}-p_g(x)\frac{1}{(1-D)^2}W′′=dDdW′=−q(x)D21−pg(x)(1−D)21
由於W''\lt0W′′<0,當W'=0W′=0時,WW取得極大值:
W'=q(x)\frac{1}{D}-p_g(x)\frac{1}{1-D}=0W′=q(x)D1−pg(x)1−D1=0
D=\frac{q(x)}{q(x)+p_g(x)}D=q(x)+pg(x)q(x)
由於:
D\lt\frac{q(x)}{q(x)+p_g(x)},\, W'\gt0D<q(x)+pg(x)q(x),W′>0
D\gt\frac{q(x)}{q(x)+p_g(x)},\,W'\lt0D>q(x)+pg(x)q(x),W′<0
這代表,當DD的函數按W'=0W′=0取值時,WW在xx的每一個取樣點均得到最大值,積分後的面積得到最大值,即:
D^*=\frac{q(x)}{q(x)+p_g(x)}=\mathop {argmax}_D\int W(x;D)dx=\mathop {argmax}_DV(x;D)D∗=q(x)+pg(x)q(x)=argmaxD∫W(x;D)dx=argmaxDV(x;D)
\mathop {max}_DV(x;D)=\int q(x)log\,D^*(x)dx+\int p_g(x)log(1-D^*(x))dxmaxDV(x;D)=∫q(x)logD∗(x)dx+∫pg(x)log(1−D∗(x))dx
在獲得V(x;D)V(x;D)的最大值表達式後,咱們固定D^*D∗,接着對p(x;\theta)=p_g(x)p(x;θ)=pg(x)將這個最大值按最小方向優化:
V(x;\theta;D^*)=\int q(x)log\,D^*(x)dx+\int p(x;\theta)log(1-D^*(x))dxV(x;θ;D∗)=∫q(x)logD∗(x)dx+∫p(x;θ)log(1−D∗(x))dx
\theta^*=\mathop {argmin}_\theta V(x;\theta^*;D^*)θ∗=argminθV(x;θ∗;D∗)
所以,經過兩步走的方法,通過屢次前後迭代求解D^*D∗和\theta^*θ∗,咱們能夠逐漸獲得一個趨近於q(x)q(x)的p(x;\theta^*)p(x;θ∗)。
生成對抗網絡
從上述的討論可知,咱們能夠獲得一個和q(x)q(x)很是接近的分佈函數p(x;\theta)p(x;θ)。這個分佈函數的構建是爲了尋找已知樣本數據的內在規律。
而後咱們每每並不關心這個分佈函數,咱們但願無中生有的構建一批數據x'x′,使得p(x';\theta)p(x′;θ)趨近於q(x)q(x)。
咱們設計一個輸出x'x′的生成器x'=G(z;\beta)x′=G(z;β),從隨機機率分佈中接收zz做爲輸入,x'x′的機率分佈爲p_g(x')pg(x′)。
第一步,咱們固定p_g(x')pg(x′),求D^*D∗:
V(x,x';D)=\int q(x)log\,D(x)dx+\int p_g(x')log(1-D(x'))dxV(x,x′;D)=∫q(x)logD(x)dx+∫pg(x′)log(1−D(x′))dx
D^*=\mathop {argmax}_DV(x;D)D∗=argmaxDV(x;D)
第二步,咱們固定D*D∗,求p_g(x';\theta^*)pg(x′;θ∗):
V(x,x',D^*;\theta)=\int q(x)log\,D^*(x)dx+\int p_g(x';\theta)log(1-D^*(x'))dxV(x,x′,D∗;θ)=∫q(x)logD∗(x)dx+∫pg(x′;θ)log(1−D∗(x′))dx
\theta^*=\mathop {argmin}_\theta V(x,D^*;\theta^*)θ∗=argminθV(x,D∗;θ∗)
而後進行屢次循環迭代,使得p_g(x';\theta^*)pg(x′;θ∗)趨近於q(x)q(x)。
仔細觀察能夠發現,這裏求解過程和上例的是同樣,只是輸入的數據不一致。
在實際任務中,咱們並不關心p_g(x';\theta)pg(x′;θ),僅關注生成器x'=G(z;\beta)x′=G(z;β)的優化。
所以,咱們把算法改編以下:
第一步,咱們固定x'=G(z;\beta)x′=G(z;β),求D^*D∗:
V(x,z;D)=\int q(x)log\,D(x)dx+\int q(z)log(1-D(G(z)))dzV(x,z;D)=∫q(x)logD(x)dx+∫q(z)log(1−D(G(z)))dz
D^*=\mathop {argmax}_DV(x,z;D)D∗=argmaxDV(x,z;D)
第二步,咱們固定D*D∗,求G(z;\beta^*)G(z;β∗):
V(x,z,D^*;\beta)=\int q(x)log\,D^*(x)dx+\int q(z)log(1-D^*(G(z;\beta)))dzV(x,z,D∗;β)=∫q(x)logD∗(x)dx+∫q(z)log(1−D∗(G(z;β)))dz
\beta^*=\mathop {argmin}_\beta V(x,z,D^*;\beta^*)β∗=argminβV(x,z,D∗;β∗)
注意,本例的兩個算法都沒有給出嚴格的收斂證實。
因爲求解形式和上例的JSJS散度的參數求解算法很是的一致,咱們能夠期待這種算法可以起做用。爲簡單起見,記爲:
V(G,D)=\int q(x)log\,D(x)dx+\int q(z)log(1-D(G(z)))dzV(G,D)=∫q(x)logD(x)dx+∫q(z)log(1−D(G(z)))dz
G^*=\mathop {argmin}_G(max_D V(G,D))G∗=argminG(maxDV(G,D))
這就是GANGAN生成對抗網絡相關文獻中常見的求解表達方式。
在 Ian J. Goodfellow 的論文 Generative Adversarial Networks 中, 做者先給出了V(G,D)V(G,D)的表達式, 而後再經過JSJS散度的理論來證實其收斂性. 有興趣的讀者能夠參考閱讀。
本文認爲, 若是先介紹JSJS散度, 再進行反推, 能夠更容易的理解GANGAN概念, 理解GANGAN爲何要用這麼複雜的損失函數.
生成對抗網絡的工程實踐
在工程實踐中,咱們遇到的通常是離散的數據,咱們可使用隨機採樣的方法逼近指望值。
首先咱們從前置的隨機分佈p_z(z)pz(z)中取出mm個隨機數z=(z_1,z_2,z_3,\cdots,z_m)z=(z1,z2,z3,⋯,zm), 其次咱們在從真實數據分佈p(x)p(x)中取出mm個真實樣本p=(x_1,x_2,x_3,\cdots,x_m)p=(x1,x2,x3,⋯,xm)。
因爲咱們的數據是隨機選取的,機率越大就越有機會被選中。抽取的樣本就隱含了自身的指望。所以咱們可使用平均數代替上式中的指望,公式改寫以下:
\begin{aligned}V(G,D) &=\int q(x)log\,D(x)dx+\int q(z)log(1-D(G(z)))dz \\ &=\frac{1}{m}\Sigma_{i=1}^m log\,D(x_i)+\frac{1}{m}\Sigma_{i=1}^m log(1-D(G(z_i))) \end{aligned}V(G,D)=∫q(x)logD(x)dx+∫q(z)log(1−D(G(z)))dz=m1Σi=1mlogD(xi)+m1Σi=1mlog(1−D(G(zi)))
咱們能夠直接用上式訓練鑑別器D(x)D(x)。
在訓練生成器時,由於前半部分和zz無關,咱們能夠只使用後半部分。
最後,咱們用一張圖來結束(總結這一部分),從數學的角度看GANs的訓練過程:
------------------------------------------------------------高能結束分割線------------------------------------------------------------------------------
真的,前面的數學原理實在是太枯燥了,編寫的過程當中屢次想放棄,可是正值疫情期間,個人狀態是這樣的:
你說不得找點事幹是否是,因而乎,。。。,就有了上面那一段,無論怎麼樣,忘了剛纔這一段吧,讓咱們從新開始~
大白話GANs原理
知乎上有一個很好的解釋:
假設一個城市治安混亂,很快,這個城市裏就會出現無數的小偷。在這些小偷中,有的多是盜竊高手,有的可能毫無技術可言。假如這個城市開始整飭其治安,忽然開展一場打擊犯罪的「運動」,警察們開始恢復城市中的巡邏,很快,一批「學藝不精」的小偷就被捉住了。之因此捉住的是那些沒有技術含量的小偷,是由於警察們的技術也不行了,在捉住一批低端小偷後,城市的治安水平變得怎樣倒還很差說,但很明顯,城市裏小偷們的平均水平已經大大提升了。
警察們開始繼續訓練本身的破案技術,開始抓住那些愈來愈狡猾的小偷。隨着這些職業慣犯們的落網,警察們也練就了特別的本事,他們能很快能從一羣人中發現可疑人員,因而上前盤查,並最終逮捕嫌犯;小偷們的日子也很差過了,由於警察們的水平大大提升,若是還想之前那樣表現得鬼鬼祟祟,那麼很快就會被警察捉住。
爲了不被捕,小偷們努力表現得不那麼「可疑」,而魔高一尺、道高一丈,警察也在不斷提升本身的水平,爭取將小偷和無辜的普通羣衆區分開。隨着警察和小偷之間的這種「交流」與「切磋」,小偷們都變得很是謹慎,他們有着極高的偷竊技巧,表現得跟普通羣衆如出一轍,而警察們都練就了「火眼金睛」,一旦發現可疑人員,就能立刻發現並及時控制——最終,咱們同時獲得了最強的小偷和最強的警察。
大白話GANs訓練過程
類比上面的過程,生成對抗網絡(GANs)由2個重要的部分構成:
生成器(Generator):經過機器生成數據(大部分狀況下是圖像),目的是「騙過」判別器
判別器(Discriminator):判斷這張圖像是真實的仍是機器生成的,目的是找出生成器作的「假數據」
下面詳細介紹一下過程:
第一階段:固定「判別器D」,訓練「生成器G」
咱們使用一個還 OK 判別器,讓一個「生成器G」不斷生成「假數據」,而後給這個「判別器D」去判斷。
一開始,「生成器G」還很弱,因此很容易被揪出來。
可是隨着不斷的訓練,「生成器G」技能不斷提高,最終騙過了「判別器D」。
到了這個時候,「判別器D」基本屬於瞎猜的狀態,判斷是否爲假數據的機率爲50%。
第二階段:固定「生成器G」,訓練「判別器D」
當經過了第一階段,繼續訓練「生成器G」就沒有意義了。這個時候咱們固定「生成器G」,而後開始訓練「判別器D」。
「判別器D」經過不斷訓練,提升了本身的鑑別能力,最終他能夠準確的判斷出全部的假圖片。
到了這個時候,「生成器G」已經沒法騙過「判別器D」。
循環階段一和階段二
經過不斷的循環,「生成器G」和「判別器D」的能力都愈來愈強。
最終咱們獲得了一個效果很是好的「生成器G」,咱們就能夠用它來生成咱們想要的圖片了。
下面的實際應用部分會展現不少「驚豔」的案例。
從這個過程來看,GANs有什麼優缺點呢?
三個優點:
能更好建模數據分佈(圖像更銳利、清晰)
理論上,GANs 能訓練任何一種生成器網絡。其餘的框架須要生成器網絡有一些特定的函數形式,好比輸出層是高斯的。
無需利用馬爾科夫鏈反覆採樣,無需在學習過程當中進行推斷,沒有複雜的變分下界,避開近似計算棘手的機率的難題。
兩個缺陷:
難訓練,不穩定。生成器和判別器之間須要很好的同步,可是在實際訓練中很容易D收斂,G發散。D/G 的訓練須要精心的設計。
模式缺失(Mode Collapse)問題。GANs的學習過程可能出現模式缺失,生成器開始退化,老是生成一樣的樣本點,沒法繼續學習。
讓咱們用MNIST手寫數字數據集探索一個具體的例子,以進一步描述上面的過程,Mnist數據以下圖所示:
咱們從以下結構網絡生成手寫數字:
GAN步驟:
開始時期生成器接收隨機數並返回圖像;
將生成的圖像與實際圖像流一塊兒反饋給判別器;
鑑別器對假圖像和真實圖像進行判別並返回機率;
GAN過程當中,開始時期從隨機數開始,例如:
注:可能剛開始生成的圖像很糟糕,可是通過鑑別器把關,不停的迭代,會獲得一個不錯的結果。
再來看一個知乎上(@陳琛)的一個例子:垃圾郵件分類。
不知道你們有印象沒,垃圾郵件識別,咱們在最開始的教程裏也有提到過,如今從另外一個角度再來看看。
假設有一個叫Gary的營銷人員試圖騙過David的垃圾郵件分類器來發送垃圾郵件。Gary但願能儘量地發送多的垃圾郵件,David但願儘量少的垃圾郵件經過。理想狀況下會達到納什均衡,儘管咱們誰都不想收到垃圾郵件。
想了解納什均衡,能夠參看這篇博客。
在收到郵件後,David能夠查看spam filter的效果並經過」誤報」或」漏報」來懲罰spam filter。
假設Gary經過本身發送給本身能夠驗證他的垃圾郵件哪些經過了,那麼Gary和David就能夠經過混淆矩陣(confusion matrix,名字聽起來高大上,其實就是個表格而已)來評價本身的工做作的如何:
下面是Gary和David獲得的混淆矩陣:
經此以後,Gary和David都知道出了什麼問題,並從錯誤中學習。Gary會基於以前的成功經驗嘗試其餘的方法來生成更好的垃圾郵件。David會看一下spam filter哪裏出錯了並改進過濾機制。
而後不斷地重複這個過程,直到達到某種納什均衡(固然,有可能最終致使模型崩潰,由於某一方找到了完美的假裝方法或者分辨垃圾郵件的方法)。
下面來詳細看一下混淆矩陣的四個象限:
True Positive:郵件是Gary生成的垃圾郵件而且被David斷定爲垃圾郵件。
generator:被抓包,工做作的不夠好,須要優化。
discriminator:當前不須要作什麼。
False Negative:郵件不是垃圾郵件,可是被David斷定爲垃圾郵件。
generator:當前不須要作什麼。
discriminator:工做作的不夠好,須要優化。
False Positive:郵件是垃圾郵件,可是被David斷定爲正常郵件。
generator:當前不須要作什麼。
discriminator:工做作的不夠好,須要優化。
True Negative:郵件不是垃圾郵件,David也斷定是正常郵件。
generator:當前不須要作什麼。
discriminator:當前不須要作什麼。
基於上面討論,圖示Network如何訓練的:
訓練的步驟包括:
取batch的訓練集x,和隨機生成noise z;
計算loss;
使用back propagation更新generator和discriminator;
咱們已經分析好了,在True Positive,False Negative,False Positive狀況下須要更新:
**True Positive:**意味着generator生成的fake數據被抓包,須要對generator進行優化。須要通過參數被固定的discriminator計算loss,更新generator的權重。注意一次只能對兩個網絡中的一個進行參數調整。
**False Negative:**意味着真的訓練集被discriminator錯認爲fake數據。只更新discriminator的權重。
**False Positive:**generator生成的fake數據,被discriminator斷定爲真的訓練集。只對discriminator進行更新。
如何結合上前一節介紹的數學原理?
如今讓咱們用更數學的角度來解釋一下:
咱們有一個已知的real的分佈,generator生成了一個fake的分佈。由於這個兩個分佈不徹底相同,因此他們之間存在KL-divergence,也就是損失函數不爲0。
discriminator同時看到real的分佈和fake的分佈。若是discriminator能分清楚來自generator生成的與來自real分佈的,就會生成loss並反向傳播更新generator的權重。
generator更新完成後,生成的fake數據更符合real的分佈。
可是若是生成的data仍然不夠接近real的分佈,discriminator依然能識別出來了,所以再次對generator進行權重更新。
終於此次discriminator被騙過了,它認爲generator生成的fake數據就是符合real分佈的。這個就對應False Positive的狀況,須要對discriminator進行更新。
Loss反向傳播來更新discriminator的權重。
繼續這個過程,直到generator生成的分佈與real分佈沒法區分時,網絡達到納什均衡。
(Conditional) Synthesis—條件生成
最好玩的好比Text2Image、Image2Text。能夠基於一段文字生成一張圖片,好比這個Multi-Condition GAN(MA-GAN)的text-to-image的例子:
Data Augmentation—數據加強
GAN學習訓練集樣本的分佈,而後進行採樣生成新的樣本,咱們可使用這些樣原本加強訓練集。通常咱們都是經過對原訓練集的圖片進行旋轉和扭曲來進行加強,這裏GAN提供了一種新的方法。
Style Transfer和Manipulation-風格轉換
將一張圖片的style轉移到另一張圖像上,這與neural style transfer很是相似。Neural Style Transfer能夠認爲是把Style Image的風格加入到Content Image裏。由於只有一張Style Image,因此它其實學到的很難徹底是Style的特徵,由於一個畫家的風格很難經過一幅做品就展示出來。GAN可以很好的從多個做品中學習到畫家的真正風格特徵。
第2/3列爲neural style transfer的效果,第5列爲cycleGAN:
能夠看出對背景特別有效,好比對雲的轉換等:
GAN在動物和水果上的效果:
四季變換:
改變照片的景深:
對線稿填充變成真實的物體:
能夠利用風格轉換來渲染圖像,變成遊戲GTA風格的:
將白天變夜晚:
style transfer能夠具體見這個survey。
Image Super-Resolution
即將圖像從低分辨率LR恢復到高分辨率HR:
更多應用,可參見這篇博客。
最後,若是你想了解更多關於GANs發展史及現有的流行模型,可參看這篇博文:生成對抗網絡(GAN)的發展史。
本博客全部內容僅供學習,不爲商用,若有侵權,請聯繫博主謝謝。
[1] GANs數學原理:https://zhuanlan.zhihu.com/p/...
[2] GANs數學原理:https://blog.csdn.net/oBright...
[3] GANs原理:https://www.jianshu.com/p/bc7...
[4] 郵件識別:https://www.zhihu.com/questio...
[5] 更多應用:https://machinelearningmaster...