機器學習分享——手把手帶你寫一個GAN

GANgit

今天讓咱們從這幾方面來探索:
GAN能用來作什麼
GAN的原理
GAN的代碼實現網絡

用途ide

GAN自2014年誕生以來, 就一直備受關注, 著名的應用也隨即產出, 好比比較著名的GAN的應用有Pix2Pix,CycleGAN等, 你們也將它用於各個地方。函數

  1. 缺失/模糊像素的補充
  2. 圖片修復
  3. ……

我以爲還有一個比較重要的用途, 不少人都會缺乏數據集, 那麼就經過GAN去生成數據集了, 經過調節部分參數來進行數據集的產生的類似度。學習

原理編碼

GAN的基本原理其實很是簡單,這裏以生成圖片爲例進行說明。假設咱們有兩個網絡,G(Generator) 和 D(Discriminator)。正如它的名字所暗示的那樣, 它們的功能分別是:人工智能

G是一個生成圖片的網絡, 它接收一個隨機的噪聲(隨機生成的圖片)z, 經過這個噪聲生成圖片,記作G(z). D是一個判別網絡, 判別一張圖片是否是「真實的」。它的輸入參數是x, x表明一張圖片,輸出D(x)表明x爲真實圖片的機率,若是爲>0.5,就表明是真實(類似)的圖片,反之,就表明不是真實的圖片。spa

圖片描述
咱們經過一個假產品宣傳的例子來理解:code

首先, 咱們來定義一下角色:orm

  1. 進行宣傳的'專家'(生成網絡)
  2. 正在聽講的'咱們'(判別網絡)

'專家'的手裏面拿着一堆高仿的產品, 正在進行宣講, 咱們是熟知真品的相關信息的, 經過去對比兩個產品之間的差距, 來判斷是贗品的可能性.

這時, 咱們就能夠引出來一個概念, 若是'專家'團隊比較厲害, 完美的仿造了咱們的判斷依據, 好比說產出方, 發明日期, 說明文等等, 那麼咱們就會以爲他是真的, 那麼他就是一個好的生成網絡, 反之, 咱們會判斷他是贗品.

從咱們(判別網絡)出發, 咱們的判斷條件越苛刻, 贗品和真品之間的差距會愈來愈小, 這樣的最後的產出就是真假難分, 徹底被模仿了.

相關資源

深層的原理推薦你們能夠去閱讀Generative Adversarial Networks這篇論文 損失函數等相關細節咱們在實現裏介紹。

實現

接下來咱們就以 mnist來實現 GAN吧.

  1. 首先, 咱們先下載數據集.

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('./Dataset/datasets/MNIST_data', one_hot=False)

咱們經過tensorflow去下載mnist的數據集, 而後加載到內存, one-hot參數決定咱們的label是否要通過編碼(mnist數據集是有10個類別), 可是咱們判別網絡是對比真實的和生成的之間的區別以及類似的可能性, 因此不須要執行one-hot編碼了.

這裏讀取出來的圖片已經歸一化到[0, 1]之間了.

  1. 俗話說, 知己知彼, 百戰百勝, 那咱們拿到數據集, 就先來看看它長什麼樣.

def show_images(images):

images = np.reshape(images, [images.shape[0], -1])  # images reshape to (batch_size, D)
sqrtn = int(np.ceil(np.sqrt(images.shape[0])))
sqrtimg = int(np.ceil(np.sqrt(images.shape[1])))

fig = plt.figure(figsize=(sqrtn, sqrtn))
gs = gridspec.GridSpec(sqrtn, sqrtn)
gs.update(wspace=0.05, hspace=0.05)

for i, img in enumerate(images):
    ax = plt.subplot(gs[i])
    plt.axis('off')
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_aspect('equal')
    plt.imshow(img.reshape([sqrtimg,sqrtimg]))

圖片描述

這裏有一個小問題, 若是是在Notebook中執行, 記得加上這句話, 不然須要執行兩次纔會繪製.

%matplotlib inline

  1. 數據看過了, 咱們該對它進行必定的處理了, 這裏咱們只是將數據縮放到[-1, 1]之間.

def preprocess_img(x):

return 2 * x - 1.0

def deprocess_img(x):

return (x + 1.0) / 2.0
  1. 數據處理完了, 接下來咱們要開始搭建模型了, 這一部分咱們有兩個模型, 一個生成網絡, 一個判別網絡.

生成網絡

def generator(z):

with tf.variable_scope("generator"):

    fc1 = tf.layers.dense(inputs=z, units=1024, activation=tf.nn.relu)
    bn1 = tf.layers.batch_normalization(inputs=fc1, training=True)
    fc2 = tf.layers.dense(inputs=bn1, units=7*7*128, activation=tf.nn.relu)
    bn2 = tf.layers.batch_normalization(inputs=fc2, training=True)
    reshaped = tf.reshape(bn2, shape=[-1, 7, 7, 128])
    conv_transpose1 = tf.layers.conv2d_transpose(inputs=reshaped, filters=64, kernel_size=4, strides=2, activation=tf.nn.relu,
                                                padding='same')
    bn3 = tf.layers.batch_normalization(inputs=conv_transpose1, training=True)
    conv_transpose2 = tf.layers.conv2d_transpose(inputs=bn3, filters=1, kernel_size=4, strides=2, activation=tf.nn.tanh,
                                    padding='same')

    img = tf.reshape(conv_transpose2, shape=[-1, 784])
    return img

判別網絡

def discriminator(x):

with tf.variable_scope("discriminator"):

    unflatten = tf.reshape(x, shape=[-1, 28, 28, 1])
    conv1 = tf.layers.conv2d(inputs=unflatten, kernel_size=5, strides=1, filters=32 ,activation=leaky_relu)
    maxpool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=2, strides=2)
    conv2 = tf.layers.conv2d(inputs=maxpool1, kernel_size=5, strides=1, filters=64,activation=leaky_relu)
    maxpool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=2, strides=2)
    flatten = tf.reshape(maxpool2, shape=[-1, 1024])
    fc1 = tf.layers.dense(inputs=flatten, units=1024, activation=leaky_relu)
    logits = tf.layers.dense(inputs=fc1, units=1)

    return logits

激活函數咱們使用了leaky_relu, 他的代碼實現是

def leaky_relu(x, alpha=0.01):

activation = tf.maximum(x,alpha*x)
return activation

它和 relu的區別就是, 小於0的值也會給與一點小的權重進行保留.

  1. 創建損失函數

def gan_loss(logits_real, logits_fake):

# Target label vector for generator loss and used in discriminator loss.
true_labels = tf.ones_like(logits_fake)

# DISCRIMINATOR loss has 2 parts: how well it classifies real images and how well it
# classifies fake images.
real_image_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits_real, labels=true_labels)
fake_image_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits_fake, labels=1-true_labels)

# Combine and average losses over the batch
D_loss = real_image_loss + fake_image_loss
D_loss = tf.reduce_mean(D_loss)

# GENERATOR is trying to make the discriminator output 1 for all its images.
# So we use our target label vector of ones for computing generator loss.
G_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits_fake, labels=true_labels)

# Average generator loss over the batch.
G_loss = tf.reduce_mean(G_loss)

return D_loss, G_loss

損失咱們分爲兩部分, 一部分是生成網絡的, 一部分是判別網絡的.

生成網絡的損失定義爲, 生成圖像的類別與真實標籤(全是1)的交叉熵損失。
圖片描述

判別網絡的損失定義爲, 咱們將真實圖片的標籤設置爲1, 生成圖片的標籤設置爲0, 而後由真實圖片的輸出以及生成圖片的輸出的交叉熵損失和.

T: True, G: Generate 生成損失

生成損失

圖片描述
真實圖片損失

圖片描述
總損失
圖片描述

  1. 訓練

def run_a_gan(sess, G_train_step, G_loss, D_train_step, D_loss, G_extra_step, D_extra_step,\

show_every=250, print_every=50, batch_size=128, num_epoch=10):
# compute the number of iterations we need
max_iter = int(mnist.train.num_examples*num_epoch/batch_size)
for it in range(max_iter):
    # every show often, show a sample result
    if it % show_every == 0:
        samples = sess.run(G_sample)
        fig = show_images(samples[:16])

plt.show()

print()
    # run a batch of data through the network
    minibatch,minbatch_y = mnist.train.next_batch(batch_size)
    _1, D_loss_curr = sess.run([D_train_step, D_loss], feed_dict={x: minibatch})
    _2, G_loss_curr = sess.run([G_train_step, G_loss])
    if it % show_every == 0:
        print(_1,_2)
    # print loss every so often.
    # We want to make sure D_loss doesn't go to 0
    if it % print_every == 0:
        print('Iter: {}, D: {:.4}, G:{:.4}'.format(it,D_loss_curr,G_loss_curr))
print('Final images')
samples = sess.run(G_sample)

fig = show_images(samples[:16])

plt.show()

這裏就是開始訓練了, 並展現訓練的結果.

  1. 查看結果 剛開始的時候, 還沒學會怎麼模仿:

圖片描述
通過學習改進:

圖片描述

圖片描述

項目地址
查看源碼(請在PC端打開)

聲明
該文章參考了天雨粟:生成對抗網絡(GAN)之MNIST數據生成。

————————————————————————————————————————————
Mo (網址:http://momodel.cn)是一個支持 Python 的人工智能建模平臺,能幫助你快速開發訓練並部署 AI 應用。

相關文章
相關標籤/搜索