不要慫,就是GAN (生成式對抗網絡) (三):判別器和生成器 TensorFlow Model

在 /home/your_name/TensorFlow/DCGAN/ 下新建文件 utils.py,輸入以下代碼:html

import scipy.misc
import numpy as np

# 保存圖片函數
def save_images(images, size, path):
    
    """
    Save the samples images
    The best size number is
            int(max(sqrt(image.shape[0]),sqrt(image.shape[1]))) + 1
    example:
        The batch_size is 64, then the size is recommended [8, 8]
        The batch_size is 32, then the size is recommended [6, 6]
    """

    # 圖片歸一化,主要用於生成器輸出是 tanh 形式的歸一化
    img = (images + 1.0) / 2.0
    h, w = img.shape[1], img.shape[2]

    # 產生一個大畫布,用來保存生成的 batch_size 個圖像
    merge_img = np.zeros((h * size[0], w * size[1], 3))

    # 循環使得畫布特定地方值爲某一幅圖像的值
    for idx, image in enumerate(images):
        i = idx % size[1]
        j = idx // size[1]
        merge_img[j*h:j*h+h, i*w:i*w+w, :] = image
    
    # 保存畫布
    return scipy.misc.imsave(path, merge_img)

 

這個函數的做用是在訓練的過程當中保存採樣生成的圖片。git

 

在 /home/your_name/TensorFlow/DCGAN/ 下新建文件 model.py,定義生成器,判別器和訓練過程當中的採樣網絡,在 model.py 輸入以下代碼:github

import tensorflow as tf
from ops import *

BATCH_SIZE = 64

# 定義生成器
def generator(z, y, train = True):
    # y 是一個 [BATCH_SIZE, 10] 維的向量,把 y 轉成四維張量
    yb = tf.reshape(y, [BATCH_SIZE, 1, 1, 10], name = 'yb')
    # 把 y 做爲約束條件和 z 拼接起來
    z = tf.concat(1, [z, y], name = 'z_concat_y')
    # 通過一個全鏈接,BN 和激活層 ReLu
    h1 = tf.nn.relu(batch_norm_layer(fully_connected(z, 1024, 'g_fully_connected1'), 
                                     is_train = train, name = 'g_bn1'))
    # 把約束條件和上一層拼接起來
    h1 = tf.concat(1, [h1, y], name = 'active1_concat_y')
    
    h2 = tf.nn.relu(batch_norm_layer(fully_connected(h1, 128 * 49, 'g_fully_connected2'), 
                                     is_train = train, name = 'g_bn2'))
    h2 = tf.reshape(h2, [64, 7, 7, 128], name = 'h2_reshape')
    # 把約束條件和上一層拼接起來
    h2 = conv_cond_concat(h2, yb, name = 'active2_concat_y')

    h3 = tf.nn.relu(batch_norm_layer(deconv2d(h2, [64,14,14,128], 
                                              name = 'g_deconv2d3'), 
                                              is_train = train, name = 'g_bn3'))
    h3 = conv_cond_concat(h3, yb, name = 'active3_concat_y')
    
    # 通過一個 sigmoid 函數把值歸一化爲 0~1 之間,
    h4 = tf.nn.sigmoid(deconv2d(h3, [64, 28, 28, 1], 
                                name = 'g_deconv2d4'), name = 'generate_image')
    
    return h4

# 定義判別器    
def discriminator(image, y, reuse = False):
    
    # 由於真實數據和生成數據都要通過判別器,因此須要指定 reuse 是否可用
    if reuse:
        tf.get_variable_scope().reuse_variables()

    # 同生成器同樣,判別器也須要把約束條件串聯進來
    yb = tf.reshape(y, [BATCH_SIZE, 1, 1, 10], name = 'yb')
    x = conv_cond_concat(image, yb, name = 'image_concat_y')
    
    # 卷積,激活,串聯條件。
    h1 = lrelu(conv2d(x, 11, name = 'd_conv2d1'), name = 'lrelu1')
    h1 = conv_cond_concat(h1, yb, name = 'h1_concat_yb')
    
    h2 = lrelu(batch_norm_layer(conv2d(h1, 74, name = 'd_conv2d2'), 
                                name = 'd_bn2'), name = 'lrelu2')
    h2 = tf.reshape(h2, [BATCH_SIZE, -1], name = 'reshape_lrelu2_to_2d')
    h2 = tf.concat(1, [h2, y], name = 'lrelu2_concat_y')

    h3 = lrelu(batch_norm_layer(fully_connected(h2, 1024, name = 'd_fully_connected3'), 
                                name = 'd_bn3'), name = 'lrelu3')
    h3 = tf.concat(1,[h3, y], name = 'lrelu3_concat_y')
    
    # 全鏈接層,輸出覺得 loss 值
    h4 = fully_connected(h3, 1, name = 'd_result_withouts_sigmoid')
    
    return tf.nn.sigmoid(h4, name = 'discriminator_result_with_sigmoid'), h4
    
# 定義訓練過程當中的採樣函數    
def sampler(z, y, train = True):
    tf.get_variable_scope().reuse_variables()
    return generator(z, y, train = train)

 

能夠看到,生成器由 7 × 7  變爲 14 × 14 再變爲 28 × 28大小,每一層都加入了約束條件 y,完美的詮釋了論文所給出的網絡,之因此要加入 is_train 參數,是因爲 Batch_norm 層中訓練和測試的時候的過程是不一樣的,用這個參數區分訓練和測試,生成器的最後一層,用了一個 sigmoid 函數把值歸一化到 0~1 之間,若是是不加約束的網絡,則用 tanh 函數,因此在 save_images 函數中要用到語句:img = (images + 1.0) / 2.0。網絡

sampler 函數的做用是在訓練過程當中對生成器生成的圖片進行採樣,因此這個函數必須指定 reuse 可用,關於 reuse 說明,請看:http://www.cnblogs.com/Charles-Wan/p/6200446.html。函數

 

 

參考資料:測試

1. https://github.com/carpedm20/DCGAN-tensorflowspa

相關文章
相關標籤/搜索