GAN網絡通俗解釋(圖畫版)

摘要: 最通俗的GAN網絡介紹!後端

在本教程中,你將瞭解什麼是生成敵對網絡(GAN),而且在整個過程當中不涉及負責的數學細節。以後,你還將學習如何編寫一個能夠建立數字的簡單GAN!網絡

什麼是GAN(插畫版介紹)

理解GAN的最簡單方法是經過一個簡單的比喻:dom

假設有一家商店它們從顧客那裏購買某些種類的葡萄酒,用於之後再銷售。ide

clipboard.png

然而,有些惡意的顧客爲了得到金錢而出售假酒。在這種狀況下,店主必須可以區分假酒和正品葡萄酒。函數

clipboard.png

你能夠想象,最初,僞造者在嘗試出售假酒時可能會犯不少錯誤,而且店主很容易認定該酒不是真的。因爲這些失敗,僞造者會繼續嘗試使用不一樣的技術來模擬真正的葡萄酒,最終纔有可能成功。如今,僞造者知道某些技術已經超過了店主的認識假酒的能力,他能夠開始進一步生產基於這些技術的假酒。學習

同時,店主可能會從其餘店主或葡萄酒專家那裏獲得一些反饋,說明他擁有的一些葡萄酒不是原裝的。這意味着店主必須改善他是如何肯定葡萄酒是僞造的仍是真實的。僞造者的目標是製造與真實葡萄酒沒法區分的葡萄酒,而店主的目標是準確地分辨葡萄酒是否真實。優化

這種來回的競爭博弈就是GAN網絡背後的主要思想。ui

生成敵對網絡的組成部分

用上面的例子,咱們能夠想出一個GAN的體系結構。阿里雲

clipboard.png

GAN網絡中有兩個主要組件:生成器和鑑別器。這個例子中的店主被稱爲鑑別器網絡,而且一般是卷積神經網絡(由於GAN主要用於圖像任務),其主要功能是判斷圖像是真實的機率。編碼

僞造者被稱爲生成網絡,而且一般也是卷積神經網絡(具備解卷積層)。該網絡須要一些噪聲矢量並輸出圖像。在訓練生成網絡時,它會學習圖像的哪些區域進行改進/更改,以便鑑別器將難以將其生成的圖像與真實圖像區分開來。

生成網絡不斷生成更接近真實圖像的圖像,而辨別網絡試圖肯定真實圖像和假圖像之間的差別。最終的目標是創建一個可生成與真實圖像沒法區分的圖像的生成網絡。

一個簡單的Keras生成對抗網絡

如今你已經瞭解了GAN是什麼以及它們的主要組成部分,如今咱們能夠開始編寫一個很是簡單的代碼。本教程將使用Keras,若是你不熟悉此Python庫,則應在繼續以前閱讀翻譯小組其餘文章。本教程是基於這裏開發的很是酷且易於理解的GAN。

你須要作的第一件事是經過如下方式安裝如下軟件包pip:

- keras
- matplotlib
- tensorflow
- tqdm

你將matplotlib用於繪製tensorflow——Keras後端庫,並用tqdm爲每一個時期(迭代)顯示一個奇特的進度條。

下一步是建立一個Python腳本。在這個腳本中,你首先須要導入你將要使用的全部模塊和函數,在使用它們時將給出每一個解釋。

import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

from keras.layers import Input
from keras.models import Model, Sequential
from keras.layers.core import Dense, Dropout
from keras.layers.advanced_activations import LeakyReLU
from keras.datasets import mnist
from keras.optimizers import Adam
from keras import initializers

你如今想要設置一些變量值:

# Let Keras know that we are using tensorflow as our backend engine
os.environ["KERAS_BACKEND"] = "tensorflow"
# To make sure that we can reproduce the experiment and get the same results
np.random.seed(10)
# The dimension of our random noise vector.
random_dim = 100

在開始構建鑑別器和生成器以前,你應該首先收集並預處理數據。你將使用如今最流行的MNIST數據集,該數據集具備一組從0到9範圍內的單個數字的圖像。

clipboard.png

def load_minst_data(): # load the data (x_train, y_train), (x_test, y_test) = mnist.load_data() # normalize our inputs to be in the range[-1, 1] x_train = (x_train.astype(np.float32) - 127.5)/127.5 # convert x_train with a shape of (60000, 28, 28) to (60000, 784) so we have # 784 columns per row x_train = x_train.reshape(60000, 784) return (x_train, y_train, x_test, y_test)

請注意,mnist.load_data()這個函數是Keras的一部分,它容許你輕鬆將MNIST數據集導入你的工做區。

如今,你能夠建立你的生成器和鑑別器網絡。你能夠爲這兩個網絡使用Adam優化器。對於生成器和鑑別器,你將建立一個帶有三個隱藏層的神經網絡,激活函數爲Leaky Relu。你還應該爲鑑別器添加Drop-out圖層,以提升其對未見圖像的魯棒性。

def get_optimizer():
    return Adam(lr=0.0002, beta_1=0.5) 
def get_generator(optimizer):
    generator = Sequential()
    generator.add(Dense(256, input_dim=random_dim, kernel_initializer=initializers.RandomNormal(stddev=0.02)))
    generator.add(LeakyReLU(0.2))
    generator.add(Dense(512))
    generator.add(LeakyReLU(0.2))
    generator.add(Dense(1024))
    generator.add(LeakyReLU(0.2))
    generator.add(Dense(784, activation='tanh'))
    generator.compile(loss='binary_crossentropy', optimizer=optimizer)
    return generator
def get_discriminator(optimizer):
    discriminator = Sequential()
    discriminator.add(Dense(1024, input_dim=784, kernel_initializer=initializers.RandomNormal(stddev=0.02)))
    discriminator.add(LeakyReLU(0.2))
    discriminator.add(Dropout(0.3))
    discriminator.add(Dense(512))
    discriminator.add(LeakyReLU(0.2))
    discriminator.add(Dropout(0.3))
    discriminator.add(Dense(256))
    discriminator.add(LeakyReLU(0.2))
    discriminator.add(Dropout(0.3))
    discriminator.add(Dense(1, activation='sigmoid'))
    discriminator.compile(loss='binary_crossentropy', optimizer=optimizer)
    return discriminator

終於到了將生成器和鑑別器放在一塊兒的時候了!

def get_gan_network(discriminator, random_dim, generator, optimizer):
    # We initially set trainable to False since we only want to train either the 
    # generator or discriminator at a time
    discriminator.trainable = False
    # gan input (noise) will be 100-dimensional vectors
    gan_input = Input(shape=(random_dim,))
    # the output of the generator (an image)
    x = generator(gan_input)
    # get the output of the discriminator (probability if the image is real or not)
    gan_output = discriminator(x)
    gan = Model(inputs=gan_input, outputs=gan_output)
    gan.compile(loss='binary_crossentropy', optimizer=optimizer)
    return gan

爲了保持整個過程的完整性,你能夠建立一個功能,每20個紀元保存你生成的圖像。因爲這不是本教程的核心,因此你不須要徹底理解該功能。

def plot_generated_images(epoch, generator, examples=100, dim=(10, 10), figsize=(10, 10)):
    noise = np.random.normal(0, 1, size=[examples, random_dim])
    generated_images = generator.predict(noise)
    generated_images = generated_images.reshape(examples, 28, 28)
    plt.figure(figsize=figsize)
    for i in range(generated_images.shape[0]):
        plt.subplot(dim[0], dim[1], i+1)
        plt.imshow(generated_images[i], interpolation='nearest', cmap='gray_r')
        plt.axis('off')
    plt.tight_layout()
    plt.savefig('gan_generated_image_epoch_%d.png' % epoch)

你如今已經編碼了大部分網絡,剩下的就是訓練這個網絡,並看看你建立的圖像。

def train(epochs=1, batch_size=128):
    # Get the training and testing data
    x_train, y_train, x_test, y_test = load_minst_data()
    # Split the training data into batches of size 128
    batch_count = x_train.shape[0] / batch_size
    # Build our GAN netowrk
    adam = get_optimizer()
    generator = get_generator(adam)
    discriminator = get_discriminator(adam)
    gan = get_gan_network(discriminator, random_dim, generator, adam)

    for e in xrange(1, epochs+1):
        print '-'*15, 'Epoch %d' % e, '-'*15
        for _ in tqdm(xrange(batch_count)):
            # Get a random set of input noise and images
            noise = np.random.normal(0, 1, size=[batch_size, random_dim])
            image_batch = x_train[np.random.randint(0, x_train.shape[0], size=batch_size)]

            # Generate fake MNIST images
            generated_images = generator.predict(noise)
            X = np.concatenate([image_batch, generated_images])
            # Labels for generated and real data
            y_dis = np.zeros(2*batch_size)
            # One-sided label smoothing
            y_dis[:batch_size] = 0.9
            # Train discriminator
            discriminator.trainable = True
            discriminator.train_on_batch(X, y_dis)
            # Train generator
            noise = np.random.normal(0, 1, size=[batch_size, random_dim])
            y_gen = np.ones(batch_size)
            discriminator.trainable = False
            gan.train_on_batch(noise, y_gen)

        if e == 1 or e % 20 == 0:
            plot_generated_images(e, generator)

if __name__ == '__main__':
    train(400, 128)

訓練400個紀元後,你能夠查看生成的圖像。查看第一個紀元後產生的圖像,能夠看到它沒有任何真實的結構,在40個紀元後查看圖像,數字開始成形,最後,400個紀元後產生的圖像顯示出清晰的數字,儘管是一對夫婦仍然沒法辨認。

clipboard.png

1紀元(左)後的結果40個紀元後(中)的結果400個時代後的結果(右)

此代碼在CPU上每一個紀元大約須要2分鐘,這是選擇此代碼的主要緣由。你能夠嘗試使用更多的紀元,並經過向生成器和鑑別器添加更多(和不一樣的)圖層。可是,當使用更復雜和更深的體系結構時,若是僅使用CPU,則運行時也會增長。

結論

恭喜,你已經完成了本教程的最後部分,你已經以直觀的方式學習生成敵對網絡(GAN)的基礎知識!

本文由@阿里云云棲社區組織翻譯。

文章原標題《demystifying-generative-adversarial-networks》,

譯者:虎說八道,審校:袁虎。

原文連接

本文爲雲棲社區原創內容,未經容許不得轉載

相關文章
相關標籤/搜索