GAN,叫作生成對抗網絡 (Generative Adversarial Network) 。其基本原理是生成器網絡 G(Generator) 和判別器網絡 D(Discriminator) 相互博弈。生成器網絡 G 的主要做用是生成圖片,在輸入一個隨機編碼 (random code) z後,自動的生成假樣本 G(z) 。判別器網絡 D 的主要做用是判斷輸入是否爲真實樣本並提供反饋機制,真樣本則輸出 1 ,反之爲 0 。在兩個網絡相互博弈的過程當中,兩個網絡的能力都愈來愈高:G 生成的圖片愈來愈像真樣本,D 也愈來愈會判斷圖片的真假,而後咱們在最大化 D 的前提下,最小化 D 對 G 的判斷能力,這實際上就是最小最大值問題,或者說二人零和博弈,其目標函數表達式:
html
爲了提升 GAN 的用戶控制能力,人類進行了一些列的探索研究。好比 Pix2Pix 模型採用有條件的使用用戶輸入,使用**成對的數據 (paired data) 進行訓練; CycleGAN 模型使用不成對的數據 (unpaired data) **的就能訓練 。但不管是 Pix2Pix 仍是 CycleGAN ,都是解決了從一個領域到另外一個領域的圖像轉換問題。當有不少領域須要轉換時,對於每個領域轉換,都須要從新訓練一個模型去解決。目前,存在的模型處理多領域圖像生成任務時,學習 k 個領域之間全部映射就必須訓練 k * (k-1) 個生成器。若是訓練一對一的圖像多領域生成任務時,主要會致使兩個問題:前端
上圖中 (a) 模型說明如何訓練 12 個不一樣生成器網絡以達到 4 個不一樣領域圖像之間轉換任務。很明顯每一個生成器不可以充分利用整個訓練數據,只能從 4 個領域中 2 個領域相互學習,這樣就會生成圖片質量很差。而上圖(b)中的模型就能夠解決這些問題,該模型接受多個領域訓練數據,並僅使用一個生成器來學習多領域圖像之間映射關係。根據模型的長相將該模型稱爲星形網絡,外文名就是 StarGAN 。git
上圖是對 StarGAN 的簡單介紹,主要包含判別器 D 和生成器 G 。
(a)D 對真假圖片進行判別,真圖片判真,假圖片判假,真圖片被分類到相應域。
(b)G 接受真圖片和目標域標籤並生成假圖片;
(c)G 在給定原始域標籤的狀況下將假圖片重建爲原始圖片(重構損失);
(d)G 儘量生成與真實圖像沒法區分的圖像,而且經過 D 分類到目標域。github
首先描述 StarGAN 網絡,在一個數據集中進行多領域的圖像轉換任務;而後咱們討論瞭如何使 StarGAN 能合併包含不一樣標籤的數據集以及對其中任意的標籤屬性靈活進行圖像轉換。數組
訓練一個生成器 G ,可以多領域映射。將帶有領域標籤 c 的輸入圖像 x 轉換爲輸出圖像 y,即網絡
。隨機生成目標領域標籤 c 使得 G 可以靈活的轉換輸入圖像,同時使用 D 控制多領域。這樣 D 就在圖像源和域標籤上產生機率分佈,即 。使用對抗損失函數提升生成圖像質量,達到 D 沒法區分出來輸出圖像和生成圖像之間的差異:
app
對於一個輸入圖像 x 和目標分佈標籤 c ,咱們的目標是將 x 轉換爲輸出圖像 y後可以被正確分類爲目標分佈 c 。爲了實現這一目標,咱們在 D 之上添加一個輔助分類器,並在優化 G 和 D 時採用目標域分類損失函數。簡單來講,咱們將這個式子分解爲兩部分:一個真實圖像的分佈分類損失用於約束 D ,一個假圖像的分佈分類損失用於約束 G 。其表達式以下所示:
框架
經過最小化對抗損失和分類損失, G 訓練生成的圖像儘量與真實圖像同樣,而且可以被分類到正確的目標領域。然而,最小化這兩個損失函數不能保證 , 轉換後的圖像中,只改變領域差別的部分, 而保留輸入圖像中的其餘內容 。故對 G 使用循環一致性損失函數 (cycle consistency loss) ,以下:
dom
最終 G 和 D 的損失函數表示以下:
機器學習
爲了 GAN 訓練過程穩定,生成高質量的圖像,論文中採用自定義梯度懲罰來代替對抗偏差損失:
starGAN 的一個重要優點在於它可以同時合併包含不一樣標籤的不一樣數據集,使得其在測試階段可以控制全部的標籤。從多個數據集學習的問題在於標籤信息對每個數據集而言只是部分已知。在 CelebA 和 RaFD 的例子中,前一個數據集包含諸如髮色,性別等信息,但它不包含任何後一個數據集中包含的諸如開心生氣等表情標籤。這會引發問題,由於在將 G(x,c) 重構回輸入圖像 x 時須要完整的標籤信息 c' 。
爲了緩解這一問題,咱們引入了向量掩碼 m,使 StarGAN 模型可以忽略不肯定的標籤,專一於特定數據集提供的明確的已知標籤。在 StarGAN 中咱們使用 n 維的 one-hot 向量來表明 m ,n 表示數據集的數量。除此以外,咱們將標籤的同一版本定義爲一個數組:
利用多數據集訓練 StarGAN 時,咱們使用上面定義的
做爲生成器的輸入。如此,生成器學會忽略非特定的標籤,而專一於指定的標籤。除了輸入標籤 ,此處的生成器與單數據集訓練的生成器網絡結構同樣。另外一方面咱們也擴展判別器的輔助分類器的分類類別到到所屬彙集的全部標籤。最後,咱們將咱們的模型按照多任務學習的方式進行訓練,其中,判別器只將已知標籤相關的分類偏差最小化便可。以 celebA 數據爲例,下載後的數據包括 label 文件和圖像。
(1, '5_o_Clock_Shadow'), (2, 'Arched_Eyebrows'), (3, 'Attractive'), (4, 'Bags_Under_Eyes'), (5, 'Bald'), (6, 'Bangs'), (7, 'Big_Lips'), (8, 'Big_Nose'), (9, 'Black_Hair'), (10, 'Blond_Hair'), (11, 'Blurry'), (12, 'Brown_Hair'), (13, 'Bushy_Eyebrows'), (14, 'Chubby'), (15, 'Double_Chin'), (16, 'Eyeglasses'), (17, 'Goatee'), (18, 'Gray_Hair'), (19, 'Heavy_Makeup'), (20, 'High_Cheekbones'), (21, 'Male'), (22, 'Mouth_Slightly_Open'), (23, 'Mustache'), (24, 'Narrow_Eyes'), (25, 'No_Beard'), (26, 'Oval_Face'), (27, 'Pale_Skin'), (28, 'Pointy_Nose'), (29, 'Receding_Hairline'), (30, 'Rosy_Cheeks'), (31, 'Sideburns'), (32, 'Smiling'), (33, 'Straight_Hair'), (34, 'Wavy_Hair'), (35, 'Wearing_Earrings'), (36, 'Wearing_Hat'), (37, 'Wearing_Lipstick'), (38, 'Wearing_Necklace'), (39, 'Wearing_Necktie'), (40, 'Young')
000001.jpg -1 1 1 -1 -1 -1 -1 -1 -1 -1 -1 1 -1 -1 -1 -1 -1 -1 1 1 -1 1 -1 -1 1 -1 -1 1 -1 -1 -1 1 1 -1 1 -1 1 -1 -1 1
經過本文學習,您應該初步瞭解 StarGAN 模型的網絡結構和實現原理,以及關鍵部分代碼的初步實現。若是您對深度學習 Tensorflow 比較瞭解,能夠參考 Tensorflow版實現starGAN;若是您對pytorch框架比較熟悉,能夠參考 pytorch實現starGAN;若是您想更深刻的學習瞭解starGAN原理,能夠參考 論文。
若是想體驗項目效果,您能夠登錄 Mo 平臺,在 應用中心 中找到 StarGAN,能夠體驗如下五種特徵['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young'] 的風格變換。考慮到代碼較長,咱們在StarGAN 項目源碼中對相關代碼作了詳細解釋。您在學習的過程當中,遇到困難或者發現咱們的錯誤,能夠隨時聯繫咱們。
1.論文:https://arxiv.org/pdf/1711.09020.pdf
2.博客:https://blog.csdn.net/stdcoutzyx/article/details/78829232
3.博客:https://www.cnblogs.com/Thinker-pcw/p/9785379.html
4.pytorch原版github地址:https://github.com/yunjey/StarGAN
5.tensorflow版github地址:https://github.com/taki0112/StarGAN-Tensorflow
6.Celeba數據集:www.dropbox.com/s/d1kjpkqkl…
Mo(網址:momodel.cn)是一個支持 Python 的人工智能在線建模平臺,能幫助你快速開發、訓練並部署模型。
Mo 人工智能俱樂部 是由網站的研發與產品設計團隊發起、致力於下降人工智能開發與使用門檻的俱樂部。團隊具有大數據處理分析、可視化與數據建模經驗,已承擔多領域智能項目,具有從底層到前端的全線設計開發能力。主要研究方向爲大數據管理分析與人工智能技術,並以此來促進數據驅動的科學研究。
目前俱樂部每週六在杭州舉辦以機器學習爲主題的線下技術沙龍活動,不按期進行論文分享與學術交流。但願能匯聚來自各行各業對人工智能感興趣的朋友,不斷交流共同成長,推進人工智能民主化、應用普及化。