GANgit
今天讓咱們從這幾方面來探索:
GAN能用來作什麼
GAN的原理
GAN的代碼實現網絡
用途ide
GAN自2014年誕生以來, 就一直備受關注, 著名的應用也隨即產出, 好比比較著名的GAN的應用有Pix2Pix,CycleGAN等, 你們也將它用於各個地方。函數
我以爲還有一個比較重要的用途, 不少人都會缺乏數據集, 那麼就經過GAN去生成數據集了, 經過調節部分參數來進行數據集的產生的類似度。學習
原理編碼
GAN的基本原理其實很是簡單,這裏以生成圖片爲例進行說明。假設咱們有兩個網絡,G(Generator) 和 D(Discriminator)。正如它的名字所暗示的那樣, 它們的功能分別是:人工智能
G是一個生成圖片的網絡, 它接收一個隨機的噪聲(隨機生成的圖片)z, 經過這個噪聲生成圖片,記作G(z). D是一個判別網絡, 判別一張圖片是否是「真實的」。它的輸入參數是x, x表明一張圖片,輸出D(x)表明x爲真實圖片的機率,若是爲>0.5,就表明是真實(類似)的圖片,反之,就表明不是真實的圖片。spa
咱們經過一個假產品宣傳的例子來理解:code
首先, 咱們來定義一下角色:orm
'專家'的手裏面拿着一堆高仿的產品, 正在進行宣講, 咱們是熟知真品的相關信息的, 經過去對比兩個產品之間的差距, 來判斷是贗品的可能性.
這時, 咱們就能夠引出來一個概念, 若是'專家'團隊比較厲害, 完美的仿造了咱們的判斷依據, 好比說產出方, 發明日期, 說明文等等, 那麼咱們就會以爲他是真的, 那麼他就是一個好的生成網絡, 反之, 咱們會判斷他是贗品.
從咱們(判別網絡)出發, 咱們的判斷條件越苛刻, 贗品和真品之間的差距會愈來愈小, 這樣的最後的產出就是真假難分, 徹底被模仿了.
相關資源
深層的原理推薦你們能夠去閱讀Generative Adversarial Networks這篇論文 損失函數等相關細節咱們在實現裏介紹。
實現
接下來咱們就以 mnist來實現 GAN吧.
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('./Dataset/datasets/MNIST_data', one_hot=False)
咱們經過tensorflow去下載mnist的數據集, 而後加載到內存, one-hot參數決定咱們的label是否要通過編碼(mnist數據集是有10個類別), 可是咱們判別網絡是對比真實的和生成的之間的區別以及類似的可能性, 因此不須要執行one-hot編碼了.
這裏讀取出來的圖片已經歸一化到[0, 1]之間了.
def show_images(images):
images = np.reshape(images, [images.shape[0], -1]) # images reshape to (batch_size, D) sqrtn = int(np.ceil(np.sqrt(images.shape[0]))) sqrtimg = int(np.ceil(np.sqrt(images.shape[1]))) fig = plt.figure(figsize=(sqrtn, sqrtn)) gs = gridspec.GridSpec(sqrtn, sqrtn) gs.update(wspace=0.05, hspace=0.05) for i, img in enumerate(images): ax = plt.subplot(gs[i]) plt.axis('off') ax.set_xticklabels([]) ax.set_yticklabels([]) ax.set_aspect('equal') plt.imshow(img.reshape([sqrtimg,sqrtimg]))
這裏有一個小問題, 若是是在Notebook中執行, 記得加上這句話, 不然須要執行兩次纔會繪製.
%matplotlib inline
def preprocess_img(x):
return 2 * x - 1.0
def deprocess_img(x):
return (x + 1.0) / 2.0
生成網絡
def generator(z):
with tf.variable_scope("generator"): fc1 = tf.layers.dense(inputs=z, units=1024, activation=tf.nn.relu) bn1 = tf.layers.batch_normalization(inputs=fc1, training=True) fc2 = tf.layers.dense(inputs=bn1, units=7*7*128, activation=tf.nn.relu) bn2 = tf.layers.batch_normalization(inputs=fc2, training=True) reshaped = tf.reshape(bn2, shape=[-1, 7, 7, 128]) conv_transpose1 = tf.layers.conv2d_transpose(inputs=reshaped, filters=64, kernel_size=4, strides=2, activation=tf.nn.relu, padding='same') bn3 = tf.layers.batch_normalization(inputs=conv_transpose1, training=True) conv_transpose2 = tf.layers.conv2d_transpose(inputs=bn3, filters=1, kernel_size=4, strides=2, activation=tf.nn.tanh, padding='same') img = tf.reshape(conv_transpose2, shape=[-1, 784]) return img
判別網絡
def discriminator(x):
with tf.variable_scope("discriminator"): unflatten = tf.reshape(x, shape=[-1, 28, 28, 1]) conv1 = tf.layers.conv2d(inputs=unflatten, kernel_size=5, strides=1, filters=32 ,activation=leaky_relu) maxpool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=2, strides=2) conv2 = tf.layers.conv2d(inputs=maxpool1, kernel_size=5, strides=1, filters=64,activation=leaky_relu) maxpool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=2, strides=2) flatten = tf.reshape(maxpool2, shape=[-1, 1024]) fc1 = tf.layers.dense(inputs=flatten, units=1024, activation=leaky_relu) logits = tf.layers.dense(inputs=fc1, units=1) return logits
激活函數咱們使用了leaky_relu, 他的代碼實現是
def leaky_relu(x, alpha=0.01):
activation = tf.maximum(x,alpha*x) return activation
它和 relu的區別就是, 小於0的值也會給與一點小的權重進行保留.
def gan_loss(logits_real, logits_fake):
# Target label vector for generator loss and used in discriminator loss. true_labels = tf.ones_like(logits_fake) # DISCRIMINATOR loss has 2 parts: how well it classifies real images and how well it # classifies fake images. real_image_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits_real, labels=true_labels) fake_image_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits_fake, labels=1-true_labels) # Combine and average losses over the batch D_loss = real_image_loss + fake_image_loss D_loss = tf.reduce_mean(D_loss) # GENERATOR is trying to make the discriminator output 1 for all its images. # So we use our target label vector of ones for computing generator loss. G_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits_fake, labels=true_labels) # Average generator loss over the batch. G_loss = tf.reduce_mean(G_loss) return D_loss, G_loss
損失咱們分爲兩部分, 一部分是生成網絡的, 一部分是判別網絡的.
生成網絡的損失定義爲, 生成圖像的類別與真實標籤(全是1)的交叉熵損失。
判別網絡的損失定義爲, 咱們將真實圖片的標籤設置爲1, 生成圖片的標籤設置爲0, 而後由真實圖片的輸出以及生成圖片的輸出的交叉熵損失和.
T: True, G: Generate 生成損失
生成損失
真實圖片損失
總損失
def run_a_gan(sess, G_train_step, G_loss, D_train_step, D_loss, G_extra_step, D_extra_step,\
show_every=250, print_every=50, batch_size=128, num_epoch=10): # compute the number of iterations we need max_iter = int(mnist.train.num_examples*num_epoch/batch_size) for it in range(max_iter): # every show often, show a sample result if it % show_every == 0: samples = sess.run(G_sample) fig = show_images(samples[:16])
plt.show()
print() # run a batch of data through the network minibatch,minbatch_y = mnist.train.next_batch(batch_size) _1, D_loss_curr = sess.run([D_train_step, D_loss], feed_dict={x: minibatch}) _2, G_loss_curr = sess.run([G_train_step, G_loss]) if it % show_every == 0: print(_1,_2) # print loss every so often. # We want to make sure D_loss doesn't go to 0 if it % print_every == 0: print('Iter: {}, D: {:.4}, G:{:.4}'.format(it,D_loss_curr,G_loss_curr)) print('Final images') samples = sess.run(G_sample) fig = show_images(samples[:16])
plt.show()
這裏就是開始訓練了, 並展現訓練的結果.
通過學習改進:
項目地址
查看源碼(請在PC端打開)
聲明
該文章參考了天雨粟:生成對抗網絡(GAN)之MNIST數據生成。
————————————————————————————————————————————
Mo (網址:http://momodel.cn)是一個支持 Python 的人工智能建模平臺,能幫助你快速開發訓練並部署 AI 應用。