不要慫,就是GAN (生成式對抗網絡) (四):訓練和測試 GAN

在 /home/your_name/TensorFlow/DCGAN/ 下新建文件 train.py,同時新建文件夾 logs 和文件夾 samples,前者用來保存訓練過程當中的日誌和模型,後者用來保存訓練過程當中採樣器的採樣圖片,在 train.py 中輸入以下代碼:html

# -*- coding: utf-8 -*-
import tensorflow as tf
import os

from read_data import *
from utils import *
from ops import *
from model import *
from model import BATCH_SIZE


def train():

    # 設置 global_step ,用來記錄訓練過程當中的 step        
    global_step = tf.Variable(0, name = 'global_step', trainable = False)
    # 訓練過程當中的日誌保存文件
    train_dir = '/home/your_name/TensorFlow/DCGAN/logs'

    # 放置三個 placeholder,y 表示約束條件,images 表示送入判別器的圖片,
    # z 表示隨機噪聲
    y= tf.placeholder(tf.float32, [BATCH_SIZE, 10], name='y')
    images = tf.placeholder(tf.float32, [64, 28, 28, 1], name='real_images')
    z = tf.placeholder(tf.float32, [None, 100], name='z')

    # 由生成器生成圖像 G
    G = generator(z, y)
    # 真實圖像送入判別器
    D, D_logits  = discriminator(images, y)
    # 採樣器採樣圖像
    samples = sampler(z, y)
    # 生成圖像送入判別器
    D_, D_logits_ = discriminator(G, y, reuse = True)
    
    # 損失計算
    d_loss_real = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(D_logits, tf.ones_like(D)))
    d_loss_fake = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(D_logits_, tf.zeros_like(D_)))
    d_loss = d_loss_real + d_loss_fake
    g_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(D_logits_, tf.ones_like(D_)))

    # 總結操做
    z_sum = tf.histogram_summary("z", z)
    d_sum = tf.histogram_summary("d", D)
    d__sum = tf.histogram_summary("d_", D_)
    G_sum = tf.image_summary("G", G)

    d_loss_real_sum = tf.scalar_summary("d_loss_real", d_loss_real)
    d_loss_fake_sum = tf.scalar_summary("d_loss_fake", d_loss_fake)
    d_loss_sum = tf.scalar_summary("d_loss", d_loss)                                                
    g_loss_sum = tf.scalar_summary("g_loss", g_loss)
    
    # 合併各自的總結
    g_sum = tf.merge_summary([z_sum, d__sum, G_sum, d_loss_fake_sum, g_loss_sum])
    d_sum = tf.merge_summary([z_sum, d_sum, d_loss_real_sum, d_loss_sum])

    # 生成器和判別器要更新的變量,用於 tf.train.Optimizer 的 var_list
    t_vars = tf.trainable_variables()
    d_vars = [var for var in t_vars if 'd_' in var.name]
    g_vars = [var for var in t_vars if 'g_' in var.name]

    saver = tf.train.Saver()
    
    # 優化算法採用 Adam
    d_optim = tf.train.AdamOptimizer(0.0002, beta1 = 0.5) \
                .minimize(d_loss, var_list = d_vars, global_step = global_step)
    g_optim = tf.train.AdamOptimizer(0.0002, beta1 = 0.5) \
                .minimize(g_loss, var_list = g_vars, global_step = global_step)
        
    
    os.environ['CUDA_VISIBLE_DEVICES'] = str(0)
    config = tf.ConfigProto()
    config.gpu_options.per_process_gpu_memory_fraction = 0.2
    sess = tf.InteractiveSession(config=config)

    init = tf.initialize_all_variables()   
    writer = tf.train.SummaryWriter(train_dir, sess.graph)
    
    # 這個本身理解吧
    data_x, data_y = read_data()
    sample_z = np.random.uniform(-1, 1, size=(BATCH_SIZE, 100))
#    sample_images = data_x[0: 64]
    sample_labels = data_y[0: 64]
    sess.run(init)    
    
    # 循環 25 個 epoch 訓練網絡
    for epoch in range(25):
        batch_idxs = 1093
        for idx in range(batch_idxs):        
            batch_images = data_x[idx*64: (idx+1)*64]
            batch_labels = data_y[idx*64: (idx+1)*64]
            batch_z = np.random.uniform(-1, 1, size=(BATCH_SIZE, 100))            
            
            # 更新 D 的參數
            _, summary_str = sess.run([d_optim, d_sum], 
                                      feed_dict = {images: batch_images, 
                                                   z: batch_z, 
                                                   y: batch_labels})
            writer.add_summary(summary_str, idx+1)

            # 更新 G 的參數
            _, summary_str = sess.run([g_optim, g_sum], 
                                      feed_dict = {z: batch_z, 
                                                   y: batch_labels})
            writer.add_summary(summary_str, idx+1)

            # 更新兩次 G 的參數確保網絡的穩定
            _, summary_str = sess.run([g_optim, g_sum], 
                                      feed_dict = {z: batch_z,
                                                   y: batch_labels})
            writer.add_summary(summary_str, idx+1)
            
            # 計算訓練過程當中的損失,打印出來
            errD_fake = d_loss_fake.eval({z: batch_z, y: batch_labels})
            errD_real = d_loss_real.eval({images: batch_images, y: batch_labels})
            errG = g_loss.eval({z: batch_z, y: batch_labels})

            if idx % 20 == 0:
                print("Epoch: [%2d] [%4d/%4d] d_loss: %.8f, g_loss: %.8f" \
                        % (epoch, idx, batch_idxs, errD_fake+errD_real, errG))
            
            # 訓練過程當中,用採樣器採樣,而且保存採樣的圖片到 
            # /home/your_name/TensorFlow/DCGAN/samples/
            if idx % 100 == 1:
                sample = sess.run(samples, feed_dict = {z: sample_z, y: sample_labels})
                samples_path = '/home/your_name/TensorFlow/DCGAN/samples/'
                save_images(sample, [8, 8], 
                            samples_path + 'test_%d_epoch_%d.png' % (epoch, idx))
                print 'save down'
            
            # 每過 500 次迭代,保存一次模型
            if idx % 500 == 2:
                checkpoint_path = os.path.join(train_dir, 'DCGAN_model.ckpt')
                saver.save(sess, checkpoint_path, global_step = idx+1)
                
    sess.close()


if __name__ == '__main__':
    train()    

 輸入完成後點擊運行,運行過程當中,能夠看到,生成的每一個圖片對應行對應列都是同樣的數字,這是由於咱們加了條件約束;採樣器 sampler 採樣的圖片被保存在 samples 文件夾下,由模糊到清晰,由剛開始的噪聲,慢慢變成手寫字符,最後徹底區分不出來是生成圖片仍是真實圖片,反正我是區分不出來,you can you up。git

 

  

   

 與此同時,要是在訓練的時候打開 TensorBoard,能夠看到 D 的分佈,大體在趨於 0.5 左右的附件徘徊,說明判別器 D 已經趨於判別不出來了,只能隨機猜想,正確率大體 0.5。github

 

 

 

 

 

 

 

 

 

 

 

 

 

 

講道理,咱們的 GAN 到這一步,已經算是完成了,測試的過程,咱們已經在訓練的時候經過採樣完成了,若是嫌不夠,非要單獨寫個測試的文件,也不是不能夠:算法

在 /home/your_name/TensorFlow/DCGAN/ 下新建文件 eval.py 和文件夾 eval,eval 文件夾用來保存測試結果圖片,在 eval.py 中輸入以下代碼:網絡

 

# -*- coding: utf-8 -*-
import tensorflow as tf
import os

from read_data import *
from utils import *
from ops import *
from model import *
from model import BATCH_SIZE


def eval():
    # 用於存放測試圖片
    test_dir = '/home/your_name/TensorFlow/DCGAN/eval/'
    # 今後處加載模型
    checkpoint_dir = '/home/your_name/TensorFlow/DCGAN/logs/'
    
    y= tf.placeholder(tf.float32, [BATCH_SIZE, 10], name='y')
    z = tf.placeholder(tf.float32, [None, 100], name='z')
    
    G = generator(z, y)    
    data_x, data_y = read_data()
    sample_z = np.random.uniform(-1, 1, size=(BATCH_SIZE, 100))
    sample_labels = data_y[120: 184]
    
    # 讀取 ckpt 須要 sess,saver
    print("Reading checkpoints...")
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    
    # saver
    saver = tf.train.Saver(tf.all_variables())
    
    # sess
    os.environ['CUDA_VISIBLE_DEVICES'] = str(0)
    config = tf.ConfigProto()
    config.gpu_options.per_process_gpu_memory_fraction = 0.2
    sess = tf.InteractiveSession(config=config)
    
    # 從保存的模型中恢復變量
    if ckpt and ckpt.model_checkpoint_path:
        ckpt_name = os.path.basename(ckpt.model_checkpoint_path)        
        saver.restore(sess, os.path.join(checkpoint_dir, ckpt_name))
    
    # 用恢復的變量進行生成器的測試
    test_sess = sess.run(G, feed_dict = {z: sample_z, y: sample_labels})
    
    # 保存測試的生成器圖片到特定文件夾
    save_images(test_sess, [8, 8], test_dir + 'test_%d.png' % 500)
    
    sess.close()


if  __name__ == '__main__':

    eval()    

 點擊運行,在 eval 文件夾下生成test_500.png 文件,能夠看到,生成器 G 已經能夠生成不錯的結果。dom

 

訓練測試完,能夠打開 TensorBoard 查看網絡的 Graph,能夠看到,因爲沒有細緻採用 namespace 和 variable_scope ,畫出來的 Graph 比較凌亂,只能依稀的看出來網絡的一些結構。函數

 

至此,咱們的 TensorFlow GAN 工做基本完成,細心的朋友會發現,咱們的程序存在如下幾個問題:測試

1)在寫 eval() 函數的時候,對於生成函數 generator(),沒有指定 train = False,也就是在 BN 層,沒有體現出訓練和測試的區別;優化

2)在個人這篇 http://www.cnblogs.com/Charles-Wan/p/6197019.html 博客中,提到了我採用了 tfrecords 進行 GAN 數據的輸入處理,可是此程序並無體現出來;spa

3)沒有細緻的採用 namespace 和 variable_scope ,畫出來的 Graph 比較凌亂;

4)程序中太多不明含義的數字,路徑名字全都採用絕對路徑;

5)訓練過程當中不能斷點續訓練等。

針對以上問題,咱們在下一節的不加約束 GAN 上將進行改進。

 

 

參考文獻:

1. https://github.com/carpedm20/DCGAN-tensorflow

相關文章
相關標籤/搜索