白話生成對抗網絡 GAN,50 行代碼玩轉 GAN 模型!【附源碼】

紅色石頭的我的網站:redstonewill.comgit

今天,紅色石頭帶你們一塊兒來了解一下現在很是火熱的深度學習模型:生成對抗網絡(Generate Adversarial Network,GAN)。GAN 很是有趣,我就以最直白的語言來說解它,最後實現一個簡單的 GAN 程序來幫助你們加深理解。github

1. 什麼是 GAN?

好了,GAN 如此強大,那它究竟是一個什麼樣的模型結構呢?咱們以前學習過的機器學習或者神經網絡模型主要能作兩件事:預測和分類,這也是咱們所熟知的。那麼是否可讓機器模型自動來生成一張圖片、一段語音?並且能夠經過調整不一樣模型輸入向量來得到特定的圖片和聲音。例如,能夠調整輸入參數,得到一張紅頭髮、藍眼睛的人臉,能夠調整輸入參數,獲得女性的聲音片斷,等等。也就是說,這樣的機器模型可以根據需求,自動生成咱們想要的東西。所以,GAN 應運而生!算法

GAN,即生成對抗網絡,主要包含兩個模塊:生成器(Generative Model)和判別器(Discriminative Model)。生成模型和判別模型之間互相博弈、學習產生至關好的輸出。以圖片爲例,生成器的主要任務是學習真實圖片集,從而使得本身生成的圖片更接近於真實圖片,以「騙過」判別器。而判別器的主要任務是找出出生成器生成的圖片,區分其與真實圖片的不一樣,進行真假判別。在整個迭代過程當中,生成器不斷努力讓生成的圖片愈來愈像真的,而判別器不斷努力識別出圖片的真假。這相似生成器與判別器之間的博弈,隨着反覆迭代,最終兩者達到了平衡:生成器生成的圖片很是接近於真實圖片,而判別器已經很難識別出真假圖片的不一樣了。其表現是對於真假圖片,判別器的機率輸出都接近 0.5。bash

對 GAN 的概念仍是有點不清楚?不要緊,舉個生動的例子來講明。網絡

最近,紅色石頭想學習繪畫,是由於看到梵大師的畫做,也想畫出相似的做品。梵大師的畫做像這樣:app

說畫就畫,紅色石頭找來一個研究梵大師做品不少年的王教授來指導我。王教授經驗豐富,眼光犀利,市面上模仿梵大師的畫做都難逃他的法眼。王教授跟我說了一句話:何時你的畫這幅畫能騙過我,你就算是成功了。dom

紅色石頭很激動,立馬給王教授畫了這幅畫:機器學習

王教授輕輕掃了一眼,滿臉黑線,氣的直哆嗦,「0 分!這也叫畫?差得太多了!」 聽了王教授的話,紅色石頭自我檢討,確實畫的不咋地,連眼睛、鼻子都沒有。因而,又 從新畫了一幅:ide

王教授一看,不到 2 秒鐘,就丟下四個字:1 分!重畫!紅色石頭一想,仍是不行,畫得太差了,就回去好好研究梵大師的畫做風格,不斷改進,從新創做,直到有一天,紅色石頭拿着新的畫做給王教授看:函數

王教授看了一看,說有點像了。我得仔細看看。最後,仍是跟我說,不行不行,細節太差!繼續從新畫吧。唉,王教授愈來愈嚴格了!紅色石頭嘆了口氣回去繼續研究,最後將自我很滿意的一幅畫交給了王教授鑑賞:

這下,王教授戴着眼鏡,仔細品析,許久以後,王教授拍着個人肩膀說,畫得很好,我已經識別不了真假了。哈哈,獲得了王教授的誇獎和確定,內心美滋滋,終於能夠創做出梵大師樣的繪畫做品了。下一步考慮轉行去。

好了,例子說完了(接受你們對我繪畫天賦的吐槽)。這個例子,其實就是一個 GAN 訓練的過程。紅色石頭就是生成器,目的就是要輸出一幅畫可以騙過王教授,讓王教授真假難辨!王教授就是判別器,目的就是要識別出紅色石頭的畫做,判斷其爲假的!整個過程就是「生成 — 對抗」的博弈過程,最終,紅色石頭(生成器)輸出一幅「以假亂真」的畫做,連王教授(判別器)都難以區分了。

這就是 GAN,懂了吧。

2. GAN 模型基本結構

在認識 GAN 模型以前,咱們先來看一看 Yann LeCun 對將來深度學習重大突破技術點的我的見解:

The most important one, in my opinion, is adversarial training (also called GAN for Generative Adversarial Networks). This is an idea that was originally proposed by Ian Goodfellow when he was a student with Yoshua Bengio at the University of Montreal (he since moved to Google Brain and recently to OpenAI).

This, and the variations that are now being proposed is the most interesting idea in the last 10 years in ML, in my opinion.

Yann LeCun 認爲 GAN 極可能會給深度學習模型帶來新的重大突破,是20年來機器學習領域最酷的想法。這幾年 GAN 發展勢頭很是強勁。下面這張圖是近幾年 ICASSP 會議上全部提交的論文中包含關鍵詞 「generative」、「adversarial」 和 「reinforcement」 的論文數量統計。

數據代表,2018 年,包含關鍵詞 「generative」 和 「adversarial」 的論文數量發生井噴式增加。不難預見, 將來幾年關於 GAN 的論文會更多。

下面來介紹一下 GAN 的基本結構,咱們已經知道了 GAN 由生成器和判別器組成,各用 G 和 D 表示。以生成圖片應用爲例,其模型結構以下所示:

GAN 基本模型由 輸入 Vector、G 網絡、D 網絡組成。其中,G 和 D 通常都是由神經網絡組成。G 的輸出是一幅圖片,只不過是以全鏈接形式。G 的輸出是 D 的輸入,D 的輸入還包含真實樣本集。這樣, D 對真實樣本儘可能輸出 score 高一些,對 G 產生的樣本儘可能輸出 score 低一些。每次循環迭代,G 網絡不斷優化網絡參數,使 D 沒法區分真假;而 D 網絡也在不斷優化網絡參數,提升辨識度,讓真假樣本的 score 有差距。

最終,通過屢次訓練迭代,GAN 模型創建:

最終的 GAN 模型中,G 生成的樣本以假亂真,D 輸出的 score 接近 0.5,即表示真假樣本難以區分,訓練成功。

這裏,重點要講解一下輸入 vector。輸入向量是用來作什麼的呢?其實,輸入 vector 中的每一維度均可以表明輸出圖片的某個特徵。好比說,輸入 vector 的第一個維度數值大小能夠調節生成圖片的頭髮顏色,數值大一些是紅色,數值小一些是黑色;輸入 vector 的第二個維度數值大小能夠調節生成圖片的膚色;輸入 vector 的第三個維度數值大小能夠調節生成圖片的表情情緒,等等。

GAN 的強大之處也正是在於此,經過調節輸入 vector,就能夠生成具備不一樣特徵的圖片。而這些生成的圖片不是真實樣本集裏有的,而是即合理而又沒有見過的圖片。是否是頗有意思呢?下面這張圖反映的是不一樣的 vector 生成不一樣的圖片。

說完了 GAN 的模型以後,咱們再來簡單看一下 GAN 的算法原理。既然有兩個模塊:G 和 D,每一個模塊都有相應的網絡參數。

先來看 D 模塊,它的目標是讓真實樣本 score 越大越好,讓 G 產生的樣本 score 越小越好。那麼能夠獲得 D 的損失函數爲:

其中,x 是真實樣本,G(z) 是 G 生成樣本。咱們但願 D(x) 越大越好,D(G(z)) 越小越好,也就是但願 -D(x) 越小越好,-log(1-D(G(z))) 越小越好。從損失函數的角度來講,可以獲得上式。

再來看 G 模塊,它的目標就是但願其生成的模型可以在 D 中獲得越高的分數越好。那麼能夠獲得 G 的損失函數爲:

知道了損失函數以後,接下來就可使用各類優化算法來訓練模型了。

3. 動手寫個 GAN 模型

接下來,我將使用 PyTorch 實現一個簡單的 GAN 模型。仍然以繪畫創做爲例,假設咱們要創造以下「名畫」(以正弦圖形爲例):

生成該「藝術畫做」的代碼以下:

def artist_works():    # painting from the famous artist (real target)
   r = 0.02 * np.random.randn(1, ART_COMPONENTS)
   paintings = np.sin(PAINT_POINTS * np.pi) + r
   paintings = torch.from_numpy(paintings).float()
   return paintings
複製代碼

而後,分別定義 G 網絡和 D 網絡模型:

G = nn.Sequential(                  # Generator
   nn.Linear(N_IDEAS, 128),        # random ideas (could from normal distribution)
   nn.ReLU(),
   nn.Linear(128, ART_COMPONENTS), # making a painting from these random ideas
)

D = nn.Sequential(                  # Discriminator
   nn.Linear(ART_COMPONENTS, 128), # receive art work either from the famous artist or a newbie like G
   nn.ReLU(),
   nn.Linear(128, 1),
   nn.Sigmoid(),                   # tell the probability that the art work is made by artist
)
複製代碼

咱們設置 Adam 算法進行優化:

opt_D = torch.optim.Adam(D.parameters(), lr=LR_D)
opt_G = torch.optim.Adam(G.parameters(), lr=LR_G)
複製代碼

最後,構建 GAN 迭代訓練過程:

plt.ion()    # something about continuous plotting

D_loss_history = []
G_loss_history = []
for step in range(10000):
   artist_paintings = artist_works()          # real painting from artist
   G_ideas = torch.randn(BATCH_SIZE, N_IDEAS) # random ideas
   G_paintings = G(G_ideas)                   # fake painting from G (random ideas)
   
   prob_artist0 = D(artist_paintings)         # D try to increase this prob
   prob_artist1 = D(G_paintings)              # D try to reduce this prob
   
   D_loss = - torch.mean(torch.log(prob_artist0) + torch.log(1. - prob_artist1))
   G_loss = torch.mean(torch.log(1. - prob_artist1))
   
   D_loss_history.append(D_loss)
   G_loss_history.append(G_loss)
   
   opt_D.zero_grad()
   D_loss.backward(retain_graph=True)    # reusing computational graph
   opt_D.step()
   
   opt_G.zero_grad()
   G_loss.backward()
   opt_G.step()
   
   if step % 50 == 0:  # plotting
       plt.cla()
       plt.plot(PAINT_POINTS[0], G_paintings.data.numpy()[0], c='#4AD631', lw=3, label='Generated painting',)
       plt.plot(PAINT_POINTS[0], np.sin(PAINT_POINTS[0] * np.pi), c='#74BCFF', lw=3, label='standard curve')
       plt.text(-1, 0.75, 'D accuracy=%.2f (0.5 for D to converge)' % prob_artist0.data.numpy().mean(), fontdict={'size': 8})
       plt.text(-1, 0.5, 'D score= %.2f (-1.38 for G to converge)' % -D_loss.data.numpy(), fontdict={'size': 8})
       plt.ylim((-1, 1));plt.legend(loc='lower right', fontsize=10);plt.draw();plt.pause(0.01)

plt.ioff()
plt.show()
複製代碼

我採用了動態繪圖的方式,便於時刻觀察 GAN 模型訓練狀況。

迭代次數爲 1 時:

迭代次數爲 200 時:

迭代次數爲 1000 時:

迭代次數爲 10000 時:

完美!通過 10000 次迭代訓練以後,生成的曲線已經與標準曲線很是接近了。D 的 score 也如預期接近 0.5。

完整代碼有 .py 和 .ipynb 兩種版本,我已經放在了 GitHub 上,須要的請點擊下面的連接獲取。

GitHub-GAN


相關文章
相關標籤/搜索