生成對抗網絡GAN詳解與代碼

1.GAN的基本原理其實很是簡單,這裏以生成圖片爲例進行說明。假設咱們有兩個網絡,G(Generator)和D(Discriminator)。正如它的名字所暗示的那樣,它們的功能分別是:git

  • G是一個生成圖片的網絡,它接收一個隨機的噪聲z,經過這個噪聲生成圖片,記作G(z)。網絡

  • D是一個判別網絡,判別一張圖片是否是「真實的」。它的輸入參數是x,x表明一張圖片,輸出D(x)表明x爲真實圖片的機率,若是爲1,就表明100%是真實的圖片,而輸出爲0,就表明不多是真實的圖片。dom

在訓練過程當中,生成網絡G的目標就是儘可能生成真實的圖片去欺騙判別網絡D。而D的目標就是儘可能把G生成的圖片和真實的圖片分別開來。這樣,G和D構成了一個動態的「博弈過程」函數

最後博弈的結果是什麼?在最理想的狀態下,G能夠生成足以「以假亂真」的圖片G(z)。對於D來講,它難以斷定G生成的圖片到底是不是真實的,所以D(G(z)) = 0.5。工具

這樣咱們的目的就達成了:咱們獲得了一個生成式的模型G,它能夠用來生成圖片。學習

以上只是大體說了一下GAN的核心原理,如何用數學語言描述呢?這裏直接摘錄論文裏的公式:優化

(1)優化D:編碼

 

優化第一項是真是樣本x輸入的時候,結果越大越好;對於噪聲等的輸入z,生成的假樣本G(z)要越小越好spa

(2)優化G:.net

 

優化生成器時和真是樣本不要緊,故不須要考慮;這時候只有假樣本,但生成器但願假樣本越逼真越好(接近1),故D(G(z)越大越好,則最小化1-D(G(z))

 

 2.GAN的特色:

    (1)相比較傳統的模型,他存在兩個不一樣的網絡,而不是單一的網絡,而且訓練方式採用的是對抗訓練方式

    (2)GAN中G的梯度更新信息來自判別器D,而不是來自數據樣本

3. GAN 的優勢:

    (1) GAN是一種生成式模型,相比較其餘生成模型(玻爾茲曼機和GSNs)只用到了反向傳播,而不須要複雜的馬爾科夫鏈

    (2)相比其餘全部模型, GAN能夠產生更加清晰,真實的樣本

    (3)GAN採用的是一種無監督的學習方式訓練,能夠被普遍用在無監督學習和半監督學習領域

    (4)相比於變分自編碼器, GANs沒有引入任何決定性偏置( deterministic bias),變分方法引入決定性偏置,由於他們優化對數似然的下界,而不是似然度自己,這看起來致使了VAEs生成的實例比GANs更模糊

    (5)相比VAE, GANs沒有變分下界,若是鑑別器訓練良好,那麼生成器能夠完美的學習到訓練樣本的分佈.換句話說,GANs是漸進一致的,可是VAE是有誤差的

    (6)GAN應用到一些場景上,好比圖片風格遷移,超分辨率,圖像補全,去噪,避免了損失函數設計的困難,無論三七二十一,只要有一個的基準,直接上判別器,剩下的就交給對抗訓練了。

 4. GAN的缺點:

    (1)訓練GAN須要達到納什均衡,有時候能夠用梯度降低法作到,有時候作不到.咱們尚未找到很好的達到納什均衡的方法,因此訓練GAN相比VAE或者PixelRNN是不穩定的,但我認爲在實踐中它仍是比訓練玻爾茲曼機穩定的多

    (2)GAN不適合處理離散形式的數據,好比文本

    (3)GAN存在訓練不穩定、梯度消失、模式崩潰的問題(目前已解決)

5.爲何GAN中的優化器不經常使用SGD

    (1)SGD容易震盪,容易使GAN訓練不穩定,

    (2)GAN的目的是在高維非凸的參數空間中找到納什均衡點,GAN的納什均衡點是一個鞍點,可是SGD只會找到局部極小值,由於SGD解決的是一個尋找最小值的問題,GAN是一個博弈問題。

6.訓練GAN的一些技巧

(1). 輸入規範化到(-1,1)之間,最後一層的激活函數使用tanh(BEGAN除外)

(2). 使用wassertein GAN的損失函數

(3). 若是有標籤數據的話,儘可能使用標籤,也有人提出使用反轉標籤效果很好,另外使用標籤平滑,單邊標籤平滑或者雙邊標籤平滑

(4). 使用mini-batch norm, 若是不用batch norm 可使用instance norm 或者weight norm

(5). 避免使用RELU和pooling層,減小稀疏梯度的可能性,可使用leakrelu激活函數

(6). 優化器儘可能選擇ADAM,學習率不要設置太大,初始1e-4能夠參考,另外能夠隨着訓練進行不斷縮小學習率,

(7). 給D的網絡層增長高斯噪聲,至關因而一種正則

 

7.GAN實戰

import tensorflow as tf #導入tensorflow
from tensorflow.examples.tutorials.mnist import input_data #導入手寫數字數據集
import numpy as np #導入numpy
import matplotlib.pyplot as plt #plt是繪圖工具,在訓練過程當中用於輸出可視化結果
import matplotlib.gridspec as gridspec #gridspec是圖片排列工具,在訓練過程當中用於輸出可視化結果
import os #導入os
 
    
def xavier_init(size): #初始化參數時使用的xavier_init函數
    in_dim = size[0] 
    xavier_stddev = 1. / tf.sqrt(in_dim / 2.) #初始化標準差
    return tf.random_normal(shape=size, stddev=xavier_stddev) #返回初始化的結果

X = tf.placeholder(tf.float32, shape=[None, 784]) #X表示真的樣本(即真實的手寫數字)

D_W1 = tf.Variable(xavier_init([784, 128])) #表示使用xavier方式初始化的判別器的D_W1參數,是一個784行128列的矩陣
D_b1 = tf.Variable(tf.zeros(shape=[128])) #表示全零方式初始化的判別器的D_1參數,是一個長度爲128的向量 
D_W2 = tf.Variable(xavier_init([128, 1])) #表示使用xavier方式初始化的判別器的D_W2參數,是一個128行1列的矩陣
D_b2 = tf.Variable(tf.zeros(shape=[1])) ##表示全零方式初始化的判別器的D_1參數,是一個長度爲1的向量
theta_D = [D_W1, D_W2, D_b1, D_b2] #theta_D表示判別器的可訓練參數集合

Z = tf.placeholder(tf.float32, shape=[None, 100]) #Z表示生成器的輸入(在這裏是噪聲),是一個N列100行的矩陣
 
G_W1 = tf.Variable(xavier_init([100, 128])) #表示使用xavier方式初始化的生成器的G_W1參數,是一個100行128列的矩陣
G_b1 = tf.Variable(tf.zeros(shape=[128])) #表示全零方式初始化的生成器的G_b1參數,是一個長度爲128的向量 
G_W2 = tf.Variable(xavier_init([128, 784])) #表示使用xavier方式初始化的生成器的G_W2參數,是一個128行784列的矩陣
G_b2 = tf.Variable(tf.zeros(shape=[784])) #表示全零方式初始化的生成器的G_b2參數,是一個長度爲784的向量
theta_G = [G_W1, G_W2, G_b1, G_b2] #theta_G表示生成器的可訓練參數集合

def sample_Z(m, n): #生成維度爲[m, n]的隨機噪聲做爲生成器G的輸入
    return np.random.uniform(-1., 1., size=[m, n])

def generator(z): #生成器,z的維度爲[N, 100]
    G_h1 = tf.nn.relu(tf.matmul(z, G_W1) + G_b1) #輸入的隨機噪聲乘以G_W1矩陣加上偏置G_b1,G_h1維度爲[N, 128]
    G_log_prob = tf.matmul(G_h1, G_W2) + G_b2 #G_h1乘以G_W2矩陣加上偏置G_b2,G_log_prob維度爲[N, 784]
    G_prob = tf.nn.sigmoid(G_log_prob) #G_log_prob通過一個sigmoid函數,G_prob維度爲[N, 784] 
    return G_prob #返回G_prob

def discriminator(x): #判別器,x的維度爲[N, 784]
    D_h1 = tf.nn.relu(tf.matmul(x, D_W1) + D_b1) #輸入乘以D_W1矩陣加上偏置D_b1,D_h1維度爲[N, 128]
    D_logit = tf.matmul(D_h1, D_W2) + D_b2 #D_h1乘以D_W2矩陣加上偏置D_b2,D_logit維度爲[N, 1]
    D_prob = tf.nn.sigmoid(D_logit) #D_logit通過一個sigmoid函數,D_prob維度爲[N, 1]
    return D_prob, D_logit #返回D_prob, D_logit

G_sample = generator(Z) #取得生成器的生成結果
D_real, D_logit_real = discriminator(X) #取得判別器判別的真實手寫數字的結果
D_fake, D_logit_fake = discriminator(G_sample) #取得判別器判別的生成的手寫數字的結果

#對判別器對真實樣本的判別結果計算偏差(將結果與1比較)
D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_real, targets=tf.ones_like(D_logit_real))) 
#對判別器對虛假樣本(即生成器生成的手寫數字)的判別結果計算偏差(將結果與0比較)
D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, targets=tf.zeros_like(D_logit_fake))) 
#判別器的偏差
D_loss = D_loss_real + D_loss_fake 
#生成器的偏差(將判別器返回的對虛假樣本的判別結果與1比較)
G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, targets=tf.ones_like(D_logit_fake))) 

mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True) #mnist是手寫數字數據集

D_solver = tf.train.AdamOptimizer().minimize(D_loss, var_list=theta_D) #判別器的訓練器
G_solver = tf.train.AdamOptimizer().minimize(G_loss, var_list=theta_G) #生成器的訓練器

mb_size = 128 #訓練的batch_size
Z_dim = 100 #生成器輸入的隨機噪聲的列的維度
  
sess = tf.Session() #會話層
sess.run(tf.initialize_all_variables()) #初始化全部可訓練參數

def plot(samples): #保存圖片時使用的plot函數
    fig = plt.figure(figsize=(4, 4)) #初始化一個4行4列包含16張子圖像的圖片
    gs = gridspec.GridSpec(4, 4) #調整子圖的位置
    gs.update(wspace=0.05, hspace=0.05) #置子圖間的間距
    for i, sample in enumerate(samples): #依次將16張子圖填充進須要保存的圖像
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(sample.reshape(28, 28), cmap='Greys_r') 
    return fig


path = '/data/User/zcc/' #保存可視化結果的路徑
i = 0 #訓練過程當中保存的可視化結果的索引 
for it in range(1000000): #訓練100萬次
    if it % 1000 == 0: #每訓練1000次就保存一下結果
        samples = sess.run(G_sample, feed_dict={Z: sample_Z(16, Z_dim)})
        fig = plot(samples) #經過plot函數生成可視化結果
        plt.savefig(path+'out/{}.png'.format(str(i).zfill(3)), bbox_inches='tight') #保存可視化結果
        i += 1
        plt.close(fig)
 
    X_mb, _ = mnist.train.next_batch(mb_size) #獲得訓練一個batch所需的真實手寫數字(做爲判別器的輸入)
 
    #下面是獲得訓練一次的結果,經過sess來run出來
    _, D_loss_curr, D_loss_real, D_loss_fake, D_loss = sess.run([D_solver, D_loss, D_loss_real, D_loss_fake, D_loss], feed_dict={X: X_mb, Z: sample_Z(mb_size, Z_dim)})
    _, G_loss_curr = sess.run([G_solver, G_loss], feed_dict={Z: sample_Z(mb_size, Z_dim)})
 
    if it % 1000 == 0: #每訓練1000次輸出一下結果
        print('Iter: {}'.format(it))
        print('D loss: {:.4}'. format(D_loss_curr))
        print('G_loss: {:.4}'.format(G_loss_curr))
        print()

 參考博客:

https://blog.csdn.net/m0_37407756/article/details/75309670

http://www.javashuo.com/article/p-fsbucsnk-ms.html

相關文章
相關標籤/搜索