摘要: 最通俗的GAN網絡介紹!後端
在本教程中,你將瞭解什麼是生成敵對網絡(GAN),而且在整個過程當中不涉及負責的數學細節。以後,你還將學習如何編寫一個能夠建立數字的簡單GAN!網絡
理解GAN的最簡單方法是經過一個簡單的比喻:dom
假設有一家商店它們從顧客那裏購買某些種類的葡萄酒,用於之後再銷售。ide
然而,有些惡意的顧客爲了得到金錢而出售假酒。在這種狀況下,店主必須可以區分假酒和正品葡萄酒。函數
你能夠想象,最初,僞造者在嘗試出售假酒時可能會犯不少錯誤,而且店主很容易認定該酒不是真的。因爲這些失敗,僞造者會繼續嘗試使用不一樣的技術來模擬真正的葡萄酒,最終纔有可能成功。如今,僞造者知道某些技術已經超過了店主的認識假酒的能力,他能夠開始進一步生產基於這些技術的假酒。學習
同時,店主可能會從其餘店主或葡萄酒專家那裏獲得一些反饋,說明他擁有的一些葡萄酒不是原裝的。這意味着店主必須改善他是如何肯定葡萄酒是僞造的仍是真實的。僞造者的目標是製造與真實葡萄酒沒法區分的葡萄酒,而店主的目標是準確地分辨葡萄酒是否真實。優化
這種來回的競爭博弈就是GAN網絡背後的主要思想。ui
用上面的例子,咱們能夠想出一個GAN的體系結構。阿里雲
GAN網絡中有兩個主要組件:生成器和鑑別器。這個例子中的店主被稱爲鑑別器網絡,而且一般是卷積神經網絡(由於GAN主要用於圖像任務),其主要功能是判斷圖像是真實的機率。編碼
僞造者被稱爲生成網絡,而且一般也是卷積神經網絡(具備解卷積層)。該網絡須要一些噪聲矢量並輸出圖像。在訓練生成網絡時,它會學習圖像的哪些區域進行改進/更改,以便鑑別器將難以將其生成的圖像與真實圖像區分開來。
生成網絡不斷生成更接近真實圖像的圖像,而辨別網絡試圖肯定真實圖像和假圖像之間的差別。最終的目標是創建一個可生成與真實圖像沒法區分的圖像的生成網絡。
如今你已經瞭解了GAN是什麼以及它們的主要組成部分,如今咱們能夠開始編寫一個很是簡單的代碼。本教程將使用Keras,若是你不熟悉此Python庫,則應在繼續以前閱讀翻譯小組其餘文章。本教程是基於這裏開發的很是酷且易於理解的GAN。
你須要作的第一件事是經過如下方式安裝如下軟件包pip:
- keras - matplotlib - tensorflow - tqdm
你將matplotlib用於繪製tensorflow——Keras後端庫,並用tqdm爲每一個時期(迭代)顯示一個奇特的進度條。
下一步是建立一個Python腳本。在這個腳本中,你首先須要導入你將要使用的全部模塊和函數,在使用它們時將給出每一個解釋。
import os import numpy as np import matplotlib.pyplot as plt from tqdm import tqdm from keras.layers import Input from keras.models import Model, Sequential from keras.layers.core import Dense, Dropout from keras.layers.advanced_activations import LeakyReLU from keras.datasets import mnist from keras.optimizers import Adam from keras import initializers
你如今想要設置一些變量值:
# Let Keras know that we are using tensorflow as our backend engine os.environ["KERAS_BACKEND"] = "tensorflow" # To make sure that we can reproduce the experiment and get the same results np.random.seed(10) # The dimension of our random noise vector. random_dim = 100
在開始構建鑑別器和生成器以前,你應該首先收集並預處理數據。你將使用如今最流行的MNIST數據集,該數據集具備一組從0到9範圍內的單個數字的圖像。
def load_minst_data(): # load the data (x_train, y_train), (x_test, y_test) = mnist.load_data() # normalize our inputs to be in the range[-1, 1] x_train = (x_train.astype(np.float32) - 127.5)/127.5 # convert x_train with a shape of (60000, 28, 28) to (60000, 784) so we have # 784 columns per row x_train = x_train.reshape(60000, 784) return (x_train, y_train, x_test, y_test)
請注意,mnist.load_data()這個函數是Keras的一部分,它容許你輕鬆將MNIST數據集導入你的工做區。
如今,你能夠建立你的生成器和鑑別器網絡。你能夠爲這兩個網絡使用Adam優化器。對於生成器和鑑別器,你將建立一個帶有三個隱藏層的神經網絡,激活函數爲Leaky Relu。你還應該爲鑑別器添加Drop-out圖層,以提升其對未見圖像的魯棒性。
def get_optimizer(): return Adam(lr=0.0002, beta_1=0.5) def get_generator(optimizer): generator = Sequential() generator.add(Dense(256, input_dim=random_dim, kernel_initializer=initializers.RandomNormal(stddev=0.02))) generator.add(LeakyReLU(0.2)) generator.add(Dense(512)) generator.add(LeakyReLU(0.2)) generator.add(Dense(1024)) generator.add(LeakyReLU(0.2)) generator.add(Dense(784, activation='tanh')) generator.compile(loss='binary_crossentropy', optimizer=optimizer) return generator def get_discriminator(optimizer): discriminator = Sequential() discriminator.add(Dense(1024, input_dim=784, kernel_initializer=initializers.RandomNormal(stddev=0.02))) discriminator.add(LeakyReLU(0.2)) discriminator.add(Dropout(0.3)) discriminator.add(Dense(512)) discriminator.add(LeakyReLU(0.2)) discriminator.add(Dropout(0.3)) discriminator.add(Dense(256)) discriminator.add(LeakyReLU(0.2)) discriminator.add(Dropout(0.3)) discriminator.add(Dense(1, activation='sigmoid')) discriminator.compile(loss='binary_crossentropy', optimizer=optimizer) return discriminator
終於到了將生成器和鑑別器放在一塊兒的時候了!
def get_gan_network(discriminator, random_dim, generator, optimizer): # We initially set trainable to False since we only want to train either the # generator or discriminator at a time discriminator.trainable = False # gan input (noise) will be 100-dimensional vectors gan_input = Input(shape=(random_dim,)) # the output of the generator (an image) x = generator(gan_input) # get the output of the discriminator (probability if the image is real or not) gan_output = discriminator(x) gan = Model(inputs=gan_input, outputs=gan_output) gan.compile(loss='binary_crossentropy', optimizer=optimizer) return gan
爲了保持整個過程的完整性,你能夠建立一個功能,每20個紀元保存你生成的圖像。因爲這不是本教程的核心,因此你不須要徹底理解該功能。
def plot_generated_images(epoch, generator, examples=100, dim=(10, 10), figsize=(10, 10)): noise = np.random.normal(0, 1, size=[examples, random_dim]) generated_images = generator.predict(noise) generated_images = generated_images.reshape(examples, 28, 28) plt.figure(figsize=figsize) for i in range(generated_images.shape[0]): plt.subplot(dim[0], dim[1], i+1) plt.imshow(generated_images[i], interpolation='nearest', cmap='gray_r') plt.axis('off') plt.tight_layout() plt.savefig('gan_generated_image_epoch_%d.png' % epoch)
你如今已經編碼了大部分網絡,剩下的就是訓練這個網絡,並看看你建立的圖像。
def train(epochs=1, batch_size=128): # Get the training and testing data x_train, y_train, x_test, y_test = load_minst_data() # Split the training data into batches of size 128 batch_count = x_train.shape[0] / batch_size # Build our GAN netowrk adam = get_optimizer() generator = get_generator(adam) discriminator = get_discriminator(adam) gan = get_gan_network(discriminator, random_dim, generator, adam) for e in xrange(1, epochs+1): print '-'*15, 'Epoch %d' % e, '-'*15 for _ in tqdm(xrange(batch_count)): # Get a random set of input noise and images noise = np.random.normal(0, 1, size=[batch_size, random_dim]) image_batch = x_train[np.random.randint(0, x_train.shape[0], size=batch_size)] # Generate fake MNIST images generated_images = generator.predict(noise) X = np.concatenate([image_batch, generated_images]) # Labels for generated and real data y_dis = np.zeros(2*batch_size) # One-sided label smoothing y_dis[:batch_size] = 0.9 # Train discriminator discriminator.trainable = True discriminator.train_on_batch(X, y_dis) # Train generator noise = np.random.normal(0, 1, size=[batch_size, random_dim]) y_gen = np.ones(batch_size) discriminator.trainable = False gan.train_on_batch(noise, y_gen) if e == 1 or e % 20 == 0: plot_generated_images(e, generator) if __name__ == '__main__': train(400, 128)
訓練400個紀元後,你能夠查看生成的圖像。查看第一個紀元後產生的圖像,能夠看到它沒有任何真實的結構,在40個紀元後查看圖像,數字開始成形,最後,400個紀元後產生的圖像顯示出清晰的數字,儘管是一對夫婦仍然沒法辨認。
1紀元(左)後的結果40個紀元後(中)的結果400個時代後的結果(右)
此代碼在CPU上每一個紀元大約須要2分鐘,這是選擇此代碼的主要緣由。你能夠嘗試使用更多的紀元,並經過向生成器和鑑別器添加更多(和不一樣的)圖層。可是,當使用更復雜和更深的體系結構時,若是僅使用CPU,則運行時也會增長。
恭喜,你已經完成了本教程的最後部分,你已經以直觀的方式學習生成敵對網絡(GAN)的基礎知識!
本文由@阿里云云棲社區組織翻譯。
文章原標題《demystifying-generative-adversarial-networks》,
譯者:虎說八道,審校:袁虎。
本文爲雲棲社區原創內容,未經容許不得轉載