送書 | AI插畫師:如何用基於PyTorch的生成對抗網絡生成動漫頭像?

本文由 「AI前線」原創,原文連接: 送書 | AI插畫師:如何用基於PyTorch的生成對抗網絡生成動漫頭像?
做者|陳雲
編輯|Natalie

AI 前線導讀:」2016 年是屬於 TensorFlow 的一年,憑藉谷歌的大力推廣,TensorFlow 佔據了各大媒體的頭條。2017 年年初,PyTorch 的橫空出世吸引了研究人員極大的關注,PyTorch 簡潔優雅的設計、統一易用的接口、追風逐電的速度和變化無方的靈活性給人留下深入的印象。程序員

本文節選自《深度學習框架 PyTorch 入門與實踐》第 7 章,爲讀者講解當前最火爆的生成對抗網絡(GAN),帶領讀者從零開始實現一個動漫頭像生成器,可以利用 GAN 生成風格多變的動漫頭像。注意啦,文末有送書福利!」小程序


生成對抗網絡(Generative Adversarial Net,GAN)是近年來深度學習中一個十分熱門的方向,卷積網絡之父、深度學習元老級人物 LeCun Yan 就曾說過「GAN is the most interesting idea in the last 10 years in machine learning」。尤爲是近兩年,GAN 的論文呈現井噴的趨勢,GitHub 上有人收集了各類各樣的 GAN 變種、應用、研究論文等,其中有名稱的多達數百篇。做者還統計了 GAN 論文發表數目隨時間變化的趨勢,如圖 7-1 所示,足見 GAN 的火爆程度。網絡

圖 7-1 GAN 的論文數目逐月累加圖框架


GAN 的原理簡介dom

GAN 的開山之做是被稱爲「GAN 之父」的 Ian Goodfellow 發表於 2014 年的經典論文 Generative Adversarial Networks ,在這篇論文中他提出了生成對抗網絡,並設計了第一個 GAN 實驗——手寫數字生成。機器學習

GAN 的產生來自於一個靈機一動的想法:ide

「What I cannot create, I do not understand.」(那些我所不能創造的,我也沒有真正地理解它。)
—Richard Feynman

相似地,若是深度學習不能創造圖片,那麼它也沒有真正地理解圖片。當時深度學習已經開始在各種計算機視覺領域中攻城略地,在幾乎全部任務中都取得了突破。可是人們一直對神經網絡的黑盒模型表示質疑,因而愈來愈多的人從可視化的角度探索卷積網絡所學習的特徵和特徵間的組合,而 GAN 則從生成學習角度展現了神經網絡的強大能力。GAN 解決了非監督學習中的著名問題:給定一批樣本,訓練一個系統可以生成相似的新樣本函數

生成對抗網絡的網絡結構如圖 7-2 所示,主要包含如下兩個子網絡。工具

  • 生成器(generator):輸入一個隨機噪聲,生成一張圖片。
  • 判別器(discriminator):判斷輸入的圖片是真圖片仍是假圖片。

圖 7-2 生成對抗網絡結構圖學習

訓練判別器時,須要利用生成器生成的假圖片和來自真實世界的真圖片;訓練生成器時,只用噪聲生成假圖片。判別器用來評估生成的假圖片的質量,促使生成器相應地調整參數。

生成器的目標是儘量地生成以假亂真的圖片,讓判別器覺得這是真的圖片;判別器的目標是將生成器生成的圖片和真實世界的圖片區分開。能夠看出這兩者的目標相反,在訓練過程當中互相對抗,這也是它被稱爲生成對抗網絡的緣由。

上面的描述可能有點抽象,讓咱們用收藏齊白石做品(齊白石做品如圖 7-3 所示)的書畫收藏家和假畫販子的例子來講明。假畫販子至關因而生成器,他們但願可以模仿大師真跡僞造出以假亂真的假畫,騙過收藏家,從而賣出高價;書畫收藏家則但願將贗品和真跡區分開,讓真跡流傳於世,銷燬贗品。這裏假畫販子和收藏家所交易的畫,主要是齊白石畫的蝦。齊白石畫蝦能夠說是畫壇一絕,從來爲世人所追捧。

圖 7-3 齊白石畫蝦圖真跡

在這個例子中,一開始假畫販子和書畫收藏家都是新手,他們對真跡和贗品的概念都很模糊。假畫販子仿造出來的假畫幾乎都是隨機塗鴉,而書畫收藏家的鑑定能力不好,有很多贗品被他當成真跡,也有許多真跡被當成贗品。

首先,書畫收藏家收集了一大堆市面上的贗品和齊白石大師的真跡,仔細研究對比,初步學習了畫中蝦的結構,明白畫中的生物形狀彎曲,而且有一對相似鉗子的「螯足」,對於不符合這個條件的假畫所有過濾掉。當收藏家用這個標準到市場上進行鑑定時,假畫基本沒法騙過收藏家,假畫販子損失慘重。可是假畫販子本身仿造的贗品中,仍是有一些矇騙過關,這些矇騙過關的贗品中都有彎曲的形狀,而且有一對相似鉗子的「螯足」。因而假畫販子開始修改仿造的手法,在仿造的做品中加入彎曲的形狀和一對相似鉗子的「螯足」。除了這些特色,其餘地方例如顏色、線條都是隨機畫的。假畫販子製造出的初版贗品如圖 7-4 所示。

圖 7-4 假畫販子製造的初版贗品

當假畫販子把這些畫拿到市面上去賣時,很容易就騙過了收藏家,由於畫中有一隻彎曲的生物,生物前面有一對相似鉗子的東西,符合收藏家認定的真跡的標準,因此收藏家就把它當成真跡買回來。隨着時間的推移,收藏家買回愈來愈多的假畫,損失慘重,因而他又閉門研究贗品和真跡之間的區別,通過反覆比較對比,他發現齊白石畫蝦的真跡中除了有彎曲的形狀,蝦的觸鬚蔓長,通身做半透明狀,而且畫的蝦的細節十分豐富,蝦的每一節之間均呈白色狀。

收藏家學成以後,從新出山,而假畫販子的仿造技法沒有提高,所製造出來的贗品被收藏家輕鬆識破。因而假畫販子也開始嘗試不一樣的畫蝦手法,大多都是徒勞無功,不過在衆多嘗試之中,仍是有一些贗品騙過了收藏家的眼睛。假畫販子發現這些仿製的贗品觸鬚蔓長,通身做半透明狀,而且畫的蝦的細節十分豐富,如圖 7-5 所示。因而假畫販子開始大量仿造這種畫,並拿到市面上銷售,許多都成功地騙過了收藏家。

圖 7-5 假畫販子製造的第二版贗品

收藏家再度損失慘重,被迫關門研究齊白石的真跡和贗品之間的區別,學習齊白石真跡的特色,提高本身的鑑定能力。就這樣,經過收藏家和假畫販子之間的博弈,收藏家從零開始慢慢提高了本身對真跡和贗品的鑑別能力,而假畫販子也不斷地提升本身仿造齊白石真跡的水平。收藏家利用假畫販子提供的贗品,做爲和真跡的對比,對齊白石畫蝦真跡有了更好的鑑賞能力;而假畫販子也不斷嘗試,提高仿造水平,提高仿造假畫的質量,即便最後製造出來的仍屬於贗品,可是和真跡相比也很接近了。收藏家和假畫販子兩者之間互相博弈對抗,同時又不斷促使着對方學習進步,達到共同提高的目的。

在這個例子中,假畫販子至關於一個生成器,收藏家至關於一個判別器。一開始生成器和判別器的水平都不好,由於兩者都是隨機初始化的。訓練過程分爲兩步交替進行,第一步是訓練判別器(只修改判別器的參數,固定生成器),目標是把真跡和贗品區分開;第二步是訓練生成器(只修改生成器的參數,固定判別器),爲的是生成的假畫可以被判別器判別爲真跡(被收藏家認爲是真跡)。這兩步交替進行,進而分類器和判別器都達到了一個很高的水平。訓練到最後,生成器生成的蝦的圖片(如圖 7-6 所示)和齊白石的真跡幾乎沒有差異。

圖 7-6 生成器生成的蝦

下面咱們來思考網絡結構的設計。判別器的目標是判斷輸入的圖片是真跡仍是贗品,因此能夠當作是一個二分類網絡,參考第 6 章中 Dog vs. Cat 的實驗,咱們能夠設計一個簡單的卷積網絡。生成器的目標是從噪聲中生成一張彩色圖片,這裏咱們採用普遍使用的 DCGAN(Deep Convolutional Generative Adversarial Networks)結構,即採用全卷積網絡,其結構如圖 7-7 所示。網絡的輸入是一個 100 維的噪聲,輸出是一個 3×64×64 的圖片。這裏的輸入能夠當作是一個 100×1×1 的圖片,經過上卷積慢慢增大爲 4×四、8×八、16×1六、32×32 和 64×64。上卷積,或稱轉置卷積,是一種特殊的卷積操做,相似於卷積操做的逆運算。當卷積的 stride 爲 2 時,輸出相比輸入會下采樣到一半的尺寸;而當上卷積的 stride 爲 2 時,輸出會上採樣到輸入的兩倍尺寸。這種上採樣的作法能夠理解爲圖片的信息保存於 100 個向量之中,神經網絡根據這 100 個向量描述的信息,前幾步的上採樣先勾勒出輪廓、色調等基礎信息,後幾步上採樣慢慢完善細節。網絡越深,細節越詳細。

圖 7-7 DCGAN 中生成器網絡結構圖

在 DCGAN 中,判別器的結構和生成器對稱:生成器中採用上採樣的卷積,判別器中就採用下采樣的卷積,生成器是根據噪聲輸出一張 64×64×3 的圖片,而判別器則是根據輸入的 64×64×3 的圖片輸出圖片屬於正負樣本的分數(機率)。


用 GAN 生成動漫頭像

本節將用 GAN 實現一個生成動漫人物頭像的例子。在日本的技術博客網站上 有個博主(估計是一位二次元的愛好者),利用 DCGAN 從 20 萬張動漫頭像中學習,最終可以利用程序自動生成動漫頭像,生成的圖片效果如圖 7-8 所示。源程序是利用 Chainer 框架實現的,本節咱們嘗試利用 PyTorch 實現。

圖 7-8 DCGAN 生成的動漫頭像

原始的圖片是從網站中爬取的,並利用 OpenCV 從中截取頭像,處理起來比較麻煩。這裏咱們使用知乎用戶何之源爬取並通過處理的 5 萬張圖片。能夠從本書配套程序的 README.MD 的百度網盤連接下載全部的圖片壓縮包,並解壓縮到指定的文件夾中。須要注意的是,這裏圖片的分辨率是 3×96×96,而不是論文中的 3×64×64,所以須要相應地調整網絡結構,使生成圖像的尺寸爲 96。

咱們首先來看本實驗的代碼結構。

接着來看 model.py 中是如何定義生成器的。

能夠看出生成器的搭建相對比較簡單,直接使用 nn.Sequential 將上卷積、激活、池化等操做拼接起來便可,這裏須要注意上卷積 ConvTransposed2d 的使用。當 kernel size 爲 四、stride 爲 二、padding 爲 1 時,根據公式 H_out=(H_in-1)*stride-2*padding+kernel_size,輸出尺寸恰好變成輸入的兩倍。最後一層採用 kernel size 爲 五、stride 爲 三、padding 爲 1,是爲了將 32×32 上採樣到 96×96,這是本例中圖片的尺寸,與論文中 64×64 的尺寸不同。最後一層用 Tanh 將輸出圖片的像素歸一化至 -1~1,若是但願歸一化至 0~1,則需使用 Sigmoid。接着咱們來看判別器的網絡結構。

能夠看出判別器和生成器的網絡結構幾乎是對稱的,從卷積核大小到 padding、stride 等設置,幾乎如出一轍。例如生成器的最後一個卷積層的尺度是(5,3,1),判別器的第一個卷積層的尺度也是(5,3,1)。另外,這裏須要注意的是生成器的激活函數用的是 ReLU,而判別器使用的是 LeakyReLU,兩者並沒有本質區別,這裏的選擇更可能是經驗總結。每個樣本通過判別器後,輸出一個 0~1 的數,表示這個樣本是真圖片的機率。在開始寫訓練函數前,先來看看模型的配置參數。

這些只是模型的默認參數,還能夠利用 Fire 等工具經過命令行傳入,覆蓋默認值。另外,咱們也能夠直接使用 opt.attr,還能夠利用 IDE/IPython 提供的自動補全功能,十分方便。這裏的超參數設置大可能是照搬 DCGAN 論文的默認值,做者通過大量實驗,發現這些參數可以更快地訓練出一個不錯的模型。

當咱們下載完數據以後,須要將全部圖片放在一個文件夾,而後將該文件夾移動至 data 目錄下(請確保 data 下沒有其餘的文件夾)。這種處理方式是爲了可以直接使用 torchvision 自帶的 ImageFolder 讀取圖片,而沒必要本身寫 Dataset。數據讀取與加載的代碼以下:

可見,用 ImageFolder 配合 DataLoader 加載圖片十分方便。

在進行訓練以前,咱們還須要定義幾個變量:模型、優化器、噪聲等。


在加載預訓練模型時,最好指定 map_location。由於若是程序以前在 GPU 上運行,那麼模型就會被存成 torch.cuda.Tensor,這樣加載時會默認將數據加載至顯存。若是運行該程序的計算機中沒有 GPU,加載就會報錯,故經過指定 map_location 將 Tensor 默認加載入內存中,待有須要時再移至顯存中。

下面開始訓練網絡,訓練步驟以下。

(1)訓練判別器。

  • 固定生成器
  • 對於真圖片,判別器的輸出機率值儘量接近 1
  • 對於生成器生成的假圖片,判別器儘量輸出 0

(2)訓練生成器。

  • 固定判別器
  • 生成器生成圖片,儘量讓判別器輸出 1

(3)返回第一步,循環交替訓練。

這裏須要注意如下幾點。

  • 訓練生成器時,無須調整判別器的參數;訓練判別器時,無須調整生成器的參數。
  • 在訓練判別器時,須要對生成器生成的圖片用 detach 操做進行計算圖截斷,避免反向傳播將梯度傳到生成器中。由於在訓練判別器時咱們不須要訓練生成器,也就不須要生成器的梯度。
  • 在訓練分類器時,須要反向傳播兩次,一次是但願把真圖片判爲 1,一次是但願把假圖片判爲 0。也能夠將這二者的數據放到一個 batch 中,進行一次前向傳播和一次反向傳播便可。可是人們發現,在一個 batch 中只包含真圖片或只包含假圖片的作法最好。
  • 對於假圖片,在訓練判別器時,咱們但願它輸出爲 0;而在訓練生成器時,咱們但願它輸出爲 1。所以能夠看到一對看似矛盾的代碼:error_d_fake = criterion(fake_output, fake_labels) 和 error_g = criterion(fake_output, true_labels)。其實這也很好理解,判別器但願可以把假圖片判別爲 fake_label,而生成器則但願能把它判別爲 true_label,判別器和生成器互相對抗提高。

接下來就是一些可視化的代碼。每次可視化使用的噪聲都是固定的 fix_noises,由於這樣便於咱們比較對於相同的輸入,生成器生成的圖片是如何一步步提高的。另外,因爲咱們對輸入的圖片進行了歸一化處理(-1~1),在可視化時則須要將它還原成原來的 scale(0~1) 。

除此以外,還提供了一個函數,能加載預訓練好的模型,並利用噪聲隨機生成圖片。

完整的代碼請參考本書的附帶樣例代碼 chapter7/AnimeGAN。參照 README.MD 中的指南配置環境,並準備好數據,然後用以下命令便可開始訓練:

若是使用 visdom 的話,此時打開 http://[your ip]:8097 就能看到生成的圖像。

訓練完成後,咱們能夠利用生成網絡隨機生成動漫頭像,輸入命令以下:


實驗結果分析

實驗結果如圖 7-9 所示,分別是訓練 1 個、10 個、20 個、30 個、40 個、200 個 epoch 以後神經網絡生成的動漫頭像。須要注意的是,每次生成器輸入的噪聲都是同樣的,因此咱們能夠對比在相同的輸入下,生成圖片的質量是如何慢慢改善的。

剛開始生成的圖像比較模糊(1 個 epoch),可是能夠看出圖像已經有面部輪廓。

繼續訓練 10 個 epoch 以後,生成的圖多了不少細節信息,包括頭髮、顏色等,可是整體仍是很模糊。

訓練 20 個 epoch 以後,細節繼續完善,包括頭髮的紋理、眼睛的細節等,但仍是有很多塗抹的痕跡。

訓練到第 40 個 epoch 時,已經能看出明顯的面部輪廓和細節,但仍是有塗抹現象,而且有些細節不夠合理,例如眼睛一大一小,面部的輪廓扭曲嚴重。

當訓練到 200 個 epoch 以後,圖片的細節已經十分完善,線條更流暢,輪廓更清晰,雖然還有一些不合理之處,可是已經有很多圖片可以以假亂真了。

圖 7-9 GAN 生成的動漫頭像

相似的生成動漫頭像的項目還有「用 DRGAN 生成高清的動漫頭像」,效果如圖 7-10 所示。但遺憾的是,因爲論文中使用的數據涉及版權問題,未能公開。這篇論文的主要改進包括使用了更高質量的圖片數據和更深、更復雜的模型。

圖 7-10 用 DRGAN 生成的動漫頭像

本章講解的樣例程序還能夠應用到不一樣的生成圖片場景中,只要將訓練圖片改爲其餘類型的圖片便可,例如 LSUN 客房圖片集、MNIST 手寫數據集或 CIFAR10 數據集等。事實上,上述模型還有很大的改進空間。在這裏,咱們使用的全卷積網絡只有四層,模型比較淺,而在 ResNet 的論文發表以後,也有很多研究者嘗試在 GAN 的網絡結構中引入 Residual Block 結構,並取得了不錯的視覺效果。感興趣的讀者能夠嘗試將示例代碼中的單層卷積修改成 Residual Block,相信能夠取得不錯的效果。

近年來,GAN 的一個重大突破在於理論研究。論文 Towards Principled Methods for Training Generative Adversarial Networks 從理論的角度分析了 GAN 爲什麼難以訓練,做者隨後在另外一篇論文 Wasserstein GAN 中針對性地提出了一個更好的解決方案。可是 Wasserstein GAN 這篇論文在部分技術細節上的實現過於隨意,因此隨後又有人有針對性地提出 Improved Training of Wasserstein GANs,更好地訓練 WGAN。後面兩篇論文分別用 PyTorch 和 TensorFlow 實現,代碼能夠從 GitHub 上搜索到。筆者當初也嘗試用 100 行左右的代碼實現了 Wasserstein GAN,感興趣的讀者能夠去了解 。

隨着 GAN 研究的逐漸成熟,人們也嘗試把 GAN 用於工業實際問題之中,而在衆多相關論文中,最使人印象深入的就是 Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks ,論文中提出了一種新的 GAN 結構稱爲 CycleGAN。CycleGAN 利用 GAN 實現風格遷移、黑白圖像彩色化,以及馬和斑馬相互轉化等,效果十分出衆。論文的做者用 PyTorch 實現了全部代碼,並開源在 GitHub 上,感興趣的讀者能夠自行查閱。

本章主要介紹 GAN 的基本原理,並帶領讀者利用 GAN 生成動漫頭像。GAN 有許多變種,GitHub 上有許多利用 PyTorch 實現的各類 GAN,感興趣的讀者能夠自行查閱。

做者介紹

陳雲,Python 程序員、Linux 愛好者和 PyTorch 源碼貢獻者。主要研究方向包括計算機視覺和機器學習。「2017 知乎看山杯機器學習挑戰賽」一等獎,「2017 天池醫療 AI 大賽」第八名。熱衷於推廣 PyTorch,並有豐富的使用經驗,活躍於 PyTorch 論壇和知乎相關板塊。

福利!福利!咱們將給 AI 前線的粉絲送出《深度學習框架 PyTorch 入門與實踐》紙質書籍 10 本!在本文下方留言給出你想要這本書的理由,咱們會邀請你加入贈書羣,本次獲獎名單由抽獎小程序隨機抽取,2 月 6 日(週二)上午 10 點開獎,獲獎者每人得到一本。另附京東購買地址,戳「閱讀原文」!

更多幹貨內容,可關注AI前線,ID:ai-front,後臺回覆「AI」、「TF」、「大數據」可得到《AI前線》系列PDF迷你書和技能圖譜。

相關文章
相關標籤/搜索