A Keras Implementation of GANs

References:

Mainly the paper Generative Adversarial Networks (link).

To keep things easy to explain and study, I designed a very simple model here that generates samples from a Gaussian distribution. Even so, the experiments below revealed some very useful characteristics that deepen our understanding of GANs.

How GANs Work

For full details see the reference above, but here is a brief summary.
The idea behind GANs is actually quite simple. A GAN consists of two sub-networks. The first is the Generator: it takes noise samples as input and, through learned weights, transforms (i.e. generates) that noise into a meaningful signal. The second is the Discriminator: it takes a signal as input (either one produced by the Generator or a real one), learns to judge whether it is real or fake, and outputs a probability between 0 and 1. You can think of the Generator as a counterfeit money printer and the Discriminator as a bill validator: the two compete, so the printer's output becomes ever more convincing while the validator becomes ever more accurate. In the end, though, we want the Generator to become so good that the Discriminator always outputs 0.5, i.e. it can no longer tell the difference.

Training proceeds in two phases. In the first phase the Discriminator is trained: the Generator's weights are frozen and only the Discriminator's weights are updated. The loss function is:

$$ \frac{1}{m}\sum_{i=1}^{m}[\log D(x^i) + \log(1 - D(G(z^i)))] $$

Here m is the batch size, $x$ denotes a real signal, and $z$ a noise sample. At each step, m noise samples and m real samples are drawn from the noise and real distributions respectively, and the Discriminator's weights are updated by maximizing the loss function above.

In the second phase the Generator is trained, with the loss function:

$$ \frac{1}{m}\sum_{i=1}^{m}[\log(1 - D(G(z^i)))] $$

This time, however, the Generator's weights are updated by minimizing the loss.
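As a sanity check, the two loss terms can be sketched in plain NumPy. This is only an illustration with made-up names (`EPS`, `d_loss`, `g_loss`), not the Keras code used later; the epsilon clipping mirrors `K.epsilon()`:

```python
import numpy as np

EPS = 1e-7  # plays the role of K.epsilon()

def d_loss(d_real, d_fake):
    # Discriminator objective: maximize log D(x) + log(1 - D(G(z))).
    # Written here as a quantity to *minimize* by negating the sum.
    return -np.mean(np.log(np.maximum(EPS, d_real))
                    + np.log(np.maximum(EPS, 1.0 - d_fake)))

def g_loss(d_fake):
    # Generator objective: minimize log(1 - D(G(z))).
    return np.mean(np.log(np.maximum(EPS, 1.0 - d_fake)))

# A perfectly fooled discriminator outputs 0.5 everywhere:
half = np.full(4, 0.5)
print(round(d_loss(half, half), 4))  # 1.3863, i.e. 2*log(2)
```

Note that the generator improves (its loss decreases) as `d_fake` rises toward 1, i.e. as the Discriminator is increasingly fooled.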

Also, the two phases are not simply alternated one-for-one: the Discriminator is updated K times for every single Generator update.
As the experiments below show, the choice of K is critical.
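Schematically, the update schedule looks like this (a toy sketch where the hypothetical `d_step`/`g_step` callbacks stand in for the two `train_on_batch` calls in the real code):

```python
def train_gan(d_step, g_step, epochs, k=20):
    """Run k discriminator updates for every generator update."""
    for _ in range(epochs):
        for _ in range(k):   # k discriminator steps...
            d_step()
        g_step()             # ...then a single generator step

# Count how often each sub-network gets updated.
calls = {"d": 0, "g": 0}
train_gan(lambda: calls.__setitem__("d", calls["d"] + 1),
          lambda: calls.__setitem__("g", calls["g"] + 1),
          epochs=5, k=20)
print(calls)  # {'d': 100, 'g': 5}
```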

Implementation

The main tools are Python + Keras. Implementing common networks in Keras is especially easy; there are detailed demos on GitHub for MLPs, word2vec, LeNet, LSTMs, and so on. Anything slightly more complex takes some time to write yourself, but overall it is still more convenient than raw TensorFlow. The Keras source code also makes good reference material for learning TensorFlow; many of its idioms are well worth borrowing.

Enough preamble; on to the code:

The GANModel class

Only the core code is listed here.

# Loss functions designed specifically for the GAN
def log_loss_discriminator(y_true, y_pred):
    return - K.log(K.maximum(K.epsilon(), y_pred))
    
def log_loss_generator(y_true, y_pred):
    return K.log(K.maximum(K.epsilon(), 1. - y_pred))
    
class GANModel:
    def __init__(self, 
                 input_dim,
                 log_dir = None):
        '''
            __tensors[0]: the discriminator graph, which judges y (true samples)
            __tensors[1]: the generator graph, which generates from x (noise samples)
        '''
        if isinstance(input_dim, list):
            input_dim_y, input_dim_x = input_dim[0], input_dim[1]
        elif isinstance(input_dim, int):
            input_dim_x = input_dim_y = input_dim
        else:
            raise ValueError("input_dim should be list or integer, got %r" % input_dim)
        # Inputs must be named so the two signals can be fed separately later
        self.__inputs = [layers.Input(shape=(input_dim_y,), name = "y"), 
                            layers.Input(shape=(input_dim_x,), name = "x")]
        self.__tensors = [None, None] 
        self.log_dir = log_dir
        self._discriminate_layers = []
        self._generate_layers = []
        self.train_status = defaultdict(list)
        
    def add_gen_layer(self, layer):
        self._add_layer(layer, True)
    def add_discr_layer(self, layer):
        self._add_layer(layer)
    def _add_layer(self, layer, for_gen=False):
        idx = 0
        if for_gen:
            self._generate_layers.append(layer)
            idx = 1
        else:
            self._discriminate_layers.append(layer)
        
        if self.__tensors[idx] is None:
            self.__tensors[idx] = layer(self.__inputs[idx])
        else:
            self.__tensors[idx] = layer(self.__tensors[idx])
            
    def compile_discriminateor_model(self, optimizer = optimizers.Adam()):
        if len(self._discriminate_layers) <= 0:
            raise ValueError("you need to build discriminateor model before compile it")
        if len(self._generate_layers) <= 0:
            raise ValueError("you need to build generator model before compile discriminateo model")
        # Setting trainable = False freezes weight updates; this must happen before compile
        for l in self._discriminate_layers:
            l.trainable = True
        for l in self._generate_layers:
            l.trainable = False
        discriminateor_out1 = self.__tensors[0]
        discriminateor_out2 = layers.Lambda(lambda y: 1. - y)(self._discriminate_generated())
        # With two output signals, Keras applies the loss function to each output
        # and minimizes their sum. The double-underscore model is the one trained.
        self.__discriminateor_model = Model(self.__inputs, [discriminateor_out1, discriminateor_out2])
        self.__discriminateor_model.compile(optimizer, 
                                     loss = log_loss_discriminator)
       
        # This is the actual discriminator model
        self.discriminateor_model = Model(self.__inputs[0], self.__tensors[0])
        self.discriminateor_model.compile(optimizer, 
                                     loss = log_loss_discriminator)
        if self.log_dir is not None:
            # Requires pydot and graphviz; comment this out if they are not installed
            plot_model(self.__discriminateor_model, self.log_dir + "/gan_discriminateor_model.png", show_shapes = True) 
        
    def compile_generator_model(self, optimizer = optimizers.Adam()):
        if len(self._discriminate_layers) <= 0:
            raise ValueError("you need to build discriminateor model before compile generator model")
        if len(self._generate_layers) <= 0:
            raise ValueError("you need to build generator model before compile it")
        
        for l in self._discriminate_layers:
            l.trainable = False
        for l in self._generate_layers:
            l.trainable = True
              
        out = self._discriminate_generated()
        self.__generator_model = Model(self.__inputs[1], out)
        self.__generator_model.compile(optimizer, 
                                     loss = log_loss_generator)
        # This is the actual generator model
        self.generator_model = Model(self.__inputs[1], self.__tensors[1])
        if self.log_dir is not None:
            plot_model(self.__generator_model, self.log_dir + "/gan_generator_model.png", show_shapes = True) 

    def train(self, sample_list, epoch = 3, batch_size = 32, step_per = 10, plot=False):
        '''
        step_per: number of discriminator updates per generator update, i.e. K
        '''
        sample_noise, sample_true = sample_list["x"], sample_list["y"]
        sample_count = sample_noise.shape[0]
        batch_count = sample_count // batch_size 
        # A bit of a hack: a Keras model requires a target y, but a GAN has no real
        # targets, so we fabricate one to satisfy Keras's "unreasonable" demand
        psudo_y = np.ones((batch_size, ), dtype = 'float32')
        if plot:
            # plot the real data
            fig = plt.figure()
            ax = fig.add_subplot(1,1,1)
            plt.ion()
            plt.show() 
        for ei in range(epoch):
            for i in range(step_per):
                idx = random.randint(0, batch_count-1)
                batch_noise = sample_noise[idx * batch_size : (idx+1) * batch_size]
                idx = random.randint(0, batch_count-1)
                batch_sample = sample_true[idx * batch_size : (idx+1) * batch_size]
                self.__discriminateor_model.train_on_batch({
                    "y":  batch_sample,
                    "x": batch_noise}, 
                    [psudo_y, psudo_y])

            idx = random.randint(0, batch_count-1)
            batch_noise = sample_noise[idx * batch_size : (idx+1) * batch_size]
            self.__generator_model.train_on_batch(batch_noise, psudo_y)
            
            if plot:
                gen_result = self.generator_model.predict_on_batch(batch_noise)
                self.train_status["gen_result"].append(gen_result)
                dis_result = self.discriminateor_model.predict_on_batch(gen_result)
                self.train_status["dis_result"].append(dis_result)
                freq_g, bin_g = np.histogram(gen_result, density=True)
                # normalize so the frequencies sum to 1
                freq_g = freq_g * (bin_g[1] - bin_g[0])
                bin_g = bin_g[:-1]
                freq_d, bin_d = np.histogram(batch_sample, density=True)
                freq_d = freq_d * (bin_d[1] - bin_d[0])
                bin_d = bin_d[:-1]
                ax.plot(bin_g, freq_g, 'go-', markersize = 4)
                ax.plot(bin_d, freq_d, 'ko-', markersize = 8)
                gen1d = gen_result.flatten()
                dis1d = dis_result.flatten()
                si = np.argsort(gen1d)
                ax.plot(gen1d[si], dis1d[si], 'r--')
                if (ei+1) % 20 == 0:
                    ax.cla()
                plt.title("epoch = %d" % (ei+1))
                plt.pause(0.05)
        if plot:
            plt.ioff()
            plt.close()

The main section

Only the main part is listed; it shows the model structure and the parameter values used.

    step_per = 20
    sample_size = args.batch_size * 100

    # The full set of test samples
    noise_dim = 4
    signal_dim = 1
    x = np.random.uniform(-3, 3, size = (sample_size, noise_dim))
    y = np.random.normal(size = (sample_size, signal_dim))
    samples = {"x": x, 
               "y": y}
    
    gan = GANModel([signal_dim, noise_dim], args.log_dir)
    gan.add_discr_layer(layers.Dense(200, activation="relu"))
    gan.add_discr_layer(layers.Dense(50, activation="softmax"))
    gan.add_discr_layer(layers.Lambda(lambda y: K.max(y, axis=-1, keepdims=True),
                                 output_shape = (1,)))

    gan.add_gen_layer(layers.Dense(200, activation="relu"))
    gan.add_gen_layer(layers.Dense(100, activation="relu"))
    gan.add_gen_layer(layers.Dense(50, activation="relu"))
    gan.add_gen_layer(layers.Dense(signal_dim))
    
    gan.compile_generator_model()
    loger.info("compile generator finished")
    gan.compile_discriminateor_model()
    loger.info("compile discriminator finished")
    
    gan.train(samples, args.epoch, args.batch_size, step_per, plot=True)

Experimental Results

The Effect of K

In the paper, the authors note that K strongly affects training results.
With step_per = 20 as above, the results I got were quite good:
(figure: generated distribution closely tracking the real Gaussian)

As you can see, the data generated by the Generator (green line) ends up very close to the real Gaussian distribution (black line), to the point where the Discriminator can no longer tell them apart (p = 0.5).
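This matches the theoretical result in the original GAN paper: for a fixed generator, the optimal discriminator is

$$ D^*(x) = \frac{p_{data}(x)}{p_{data}(x) + p_g(x)} $$

so once the generated distribution $p_g$ equals the real distribution $p_{data}$, the best response is $D^*(x) = 1/2$ everywhere, which is exactly the p = 0.5 observed here.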

But with step_per set to 3, the results diverge badly and convergence is difficult:
(figure: diverging training result with step_per = 3)
The paper also notes that the Discriminator and the Generator must be kept in balance: in general the Discriminator should be trained several times for each Generator update. This is because the Discriminator is a prerequisite for the Generator; if D is poorly trained, the direction of G's updates will be inaccurate.

The Effect of Noise Input Dimensionality

I also found that noise_dim strongly affects the results. The runs above used noise_dim = 4; when I later set it to 1, the model seemed unable to converge to a true Gaussian, always falling slightly short.
(figure: result with noise_dim = 1, falling short of the real Gaussian)

So my guess is this: the Generator's input can be viewed as a projection of the real signal into another space, and during training the model learns the mapping between the two. Seen the other way around, the Generator decomposes the real signal into a higher-dimensional space, and naturally the higher that dimension, the better the decomposition and the closer the output gets to the real signal.
Also, from a curve-fitting perspective: the Gaussian signal in my experiment is nonlinear, while the activation functions used are (piecewise) linear, so with 1-dimensional noise we are essentially fitting a nonlinear function with a pile of linear pieces, which can only work well in a higher-dimensional space.

Training a stable GAN is a very tricky process; fortunately, experts have already done a lot of exploration in this area. See here for details.

Full Code

# demo_gan.py
# -*- encoding: utf8 -*-
'''
GAN demo
'''
import os
from os import path
import argparse
import logging
import traceback
import random
import pickle
import numpy as np
import tensorflow as tf
from keras import optimizers 
from keras import layers
from keras import callbacks, regularizers, activations
from keras.engine import Model
from keras.utils.vis_utils import plot_model
import keras.backend as K
from collections import defaultdict
from matplotlib import pyplot as plt
import app_logger

loger = logging.getLogger(__name__)

# Note that y_pred must not be negative, since it is a probability;
# choose the final activation function accordingly
def log_loss_discriminator(y_true, y_pred):
    return - K.log(K.maximum(K.epsilon(), y_pred))
    
def log_loss_generator(y_true, y_pred):
    return K.log(K.maximum(K.epsilon(), 1. - y_pred))

class GANModel:
    def __init__(self, 
                 input_dim,
                 log_dir = None):
        '''
            __tensors[0]: the discriminator graph
            __tensors[1]: the generator graph
        '''
        # the discriminator judges y (true samples)
        # the generator generates from x (noise samples)
        if isinstance(input_dim, list):
            input_dim_y, input_dim_x = input_dim[0], input_dim[1]
        elif isinstance(input_dim, int):
            input_dim_x = input_dim_y = input_dim
        else:
            raise ValueError("input_dim should be list or integer, got %r" % input_dim)
    
        self.__inputs = [layers.Input(shape=(input_dim_y,), name = "y"), 
                            layers.Input(shape=(input_dim_x,), name = "x")]
        self.__tensors = [None, None] 
        self.log_dir = log_dir
        self._discriminate_layers = []
        self._generate_layers = []
        self.train_status = defaultdict(list)
        
    def add_gen_layer(self, layer):
        self._add_layer(layer, True)
    def add_discr_layer(self, layer):
        self._add_layer(layer)
    def _add_layer(self, layer, for_gen=False):
        idx = 0
        if for_gen:
            self._generate_layers.append(layer)
            idx = 1
        else:
            self._discriminate_layers.append(layer)
        
        if self.__tensors[idx] is None:
            self.__tensors[idx] = layer(self.__inputs[idx])
        else:
            self.__tensors[idx] = layer(self.__tensors[idx])
            
    def compile_discriminateor_model(self, optimizer = optimizers.Adam()):
        if len(self._discriminate_layers) <= 0:
            raise ValueError("you need to build discriminateor model before compile it")
        if len(self._generate_layers) <= 0:
            raise ValueError("you need to build generator model before compile discriminateo model")
        
        for l in self._discriminate_layers:
            l.trainable = True
        for l in self._generate_layers:
            l.trainable = False
        discriminateor_out1 = self.__tensors[0]
        discriminateor_out2 = layers.Lambda(lambda y: 1. - y)(self._discriminate_generated())
        self.__discriminateor_model = Model(self.__inputs, [discriminateor_out1, discriminateor_out2])
        self.__discriminateor_model.compile(optimizer, 
                                     loss = log_loss_discriminator)
       
        # This is the discriminator model we actually need
        self.discriminateor_model = Model(self.__inputs[0], self.__tensors[0])
        self.discriminateor_model.compile(optimizer, 
                                     loss = log_loss_discriminator)
        #if self.log_dir is not None:
        #    plot_model(self.__discriminateor_model, self.log_dir + "/gan_discriminateor_model.png", show_shapes = True) 
        
    def compile_generator_model(self, optimizer = optimizers.Adam()):
        if len(self._discriminate_layers) <= 0:
            raise ValueError("you need to build discriminateor model before compile generator model")
        if len(self._generate_layers) <= 0:
            raise ValueError("you need to build generator model before compile it")
        
        for l in self._discriminate_layers:
            l.trainable = False
        for l in self._generate_layers:
            l.trainable = True
              
        out = self._discriminate_generated()
        self.__generator_model = Model(self.__inputs[1], out)
        self.__generator_model.compile(optimizer, 
                                     loss = log_loss_generator)
        # This is the generator model we actually need
        self.generator_model = Model(self.__inputs[1], self.__tensors[1])
        #if self.log_dir is not None:
        #    plot_model(self.__generator_model, self.log_dir + "/gan_generator_model.png", show_shapes = True) 

    def train(self, sample_list, epoch = 3, batch_size = 32, step_per = 10, plot=False):
        '''
        step_per: number of discriminator updates per generator update
        '''
        sample_noise, sample_true = sample_list["x"], sample_list["y"]
        sample_count = sample_noise.shape[0]
        batch_count = sample_count // batch_size 
        psudo_y = np.ones((batch_size, ), dtype = 'float32')
        if plot:
            # plot the real data
            fig = plt.figure()
            ax = fig.add_subplot(1,1,1)
            plt.ion()
            plt.show() 
        for ei in range(epoch):
            for i in range(step_per):
                idx = random.randint(0, batch_count-1)
                batch_noise = sample_noise[idx * batch_size : (idx+1) * batch_size]
                idx = random.randint(0, batch_count-1)
                batch_sample = sample_true[idx * batch_size : (idx+1) * batch_size]
                self.__discriminateor_model.train_on_batch({
                    "y":  batch_sample,
                    "x": batch_noise}, 
                    [psudo_y, psudo_y])

            idx = random.randint(0, batch_count-1)
            batch_noise = sample_noise[idx * batch_size : (idx+1) * batch_size]
            self.__generator_model.train_on_batch(batch_noise, psudo_y)
            
            if plot:
                gen_result = self.generator_model.predict_on_batch(batch_noise)
                self.train_status["gen_result"].append(gen_result)
                dis_result = self.discriminateor_model.predict_on_batch(gen_result)
                self.train_status["dis_result"].append(dis_result)
                freq_g, bin_g = np.histogram(gen_result, density=True)
                # normalize so the frequencies sum to 1
                freq_g = freq_g * (bin_g[1] - bin_g[0])
                bin_g = bin_g[:-1]
                freq_d, bin_d = np.histogram(batch_sample, density=True)
                freq_d = freq_d * (bin_d[1] - bin_d[0])
                bin_d = bin_d[:-1]
                ax.plot(bin_g, freq_g, 'go-', markersize = 4)
                ax.plot(bin_d, freq_d, 'ko-', markersize = 8)
                gen1d = gen_result.flatten()
                dis1d = dis_result.flatten()
                si = np.argsort(gen1d)
                ax.plot(gen1d[si], dis1d[si], 'r--')
                if (ei+1) % 20 == 0:
                    ax.cla()
                plt.title("epoch = %d" % (ei+1))
                plt.pause(0.05)
        if plot:
            plt.ioff()
            plt.close()
            
            
    def save_model(self, path_dir):
        self.generator_model.save(path_dir + "/gan_generator.h5")
        self.discriminateor_model.save(path_dir + "/gan_discriminateor.h5")
    
    def load_model(self, path_dir):
        from keras.models import load_model
        custom_obj = {
            "log_loss_discriminator": log_loss_discriminator,
            "log_loss_generator": log_loss_generator}
        self.generator_model = load_model(path_dir + "/gan_generator.h5", custom_obj)
        self.discriminateor_model = load_model(path_dir + "/gan_discriminateor.h5", custom_obj)
    
    def _discriminate_generated(self):
        # Must be rebuilt each time so the discriminator is applied to the current generator output
        disc_t = self.__tensors[1]
        for l in self._discriminate_layers:
            disc_t = l(disc_t)            
        return disc_t
    
if __name__ == "__main__":
    parser = argparse.ArgumentParser("""gan model demo (gaussian sample)""")
    parser.add_argument("-m", "--model_dir")
    parser.add_argument("-log", "--log_dir")
    parser.add_argument("-b", "--batch_size", type = int, default = 32)
    parser.add_argument("-log_lvl", "--log_lvl", default = "info",
                        help = "one of INFO, DEBUG, WARN, ERROR")
    parser.add_argument("-e", "--epoch", type = int, default = 10)
    
    args = parser.parse_args()
    
    log_lvl = {"info": logging.INFO,
               "debug": logging.DEBUG,
               "warn": logging.WARN,
               "warning": logging.WARN,
               "error": logging.ERROR,
               "err": logging.ERROR}[args.log_lvl.lower()]
    app_logger.init(log_lvl)
        
    loger.info("args: %r" % args)
    step_per = 20
    sample_size = args.batch_size * 100

    # The full set of test samples
    noise_dim = 4
    signal_dim = 1
    x = np.random.uniform(-3, 3, size = (sample_size, noise_dim))
    y = np.random.normal(size = (sample_size, signal_dim))
    samples = {"x": x, 
               "y": y}
    
    gan = GANModel([signal_dim, noise_dim], args.log_dir)
    gan.add_discr_layer(layers.Dense(200, activation="relu"))
    gan.add_discr_layer(layers.Dense(50, activation="softmax"))
    gan.add_discr_layer(layers.Lambda(lambda y: K.max(y, axis=-1, keepdims=True),
                                 output_shape = (1,)))

    gan.add_gen_layer(layers.Dense(200, activation="relu"))
    gan.add_gen_layer(layers.Dense(100, activation="relu"))
    gan.add_gen_layer(layers.Dense(50, activation="relu"))
    gan.add_gen_layer(layers.Dense(signal_dim))
    
    gan.compile_generator_model()
    loger.info("compile generator finished")
    gan.compile_discriminateor_model()
    loger.info("compile discriminator finished")
    
    gan.train(samples, args.epoch, args.batch_size, step_per, plot=True)
    gen_results = gan.train_status["gen_result"]
    dis_results = gan.train_status["dis_result"]

    gen_result = gen_results[-1]
    dis_result = dis_results[-1]
    freq_g, bin_g = np.histogram(gen_result, density=True)
    # normalize so the frequencies sum to 1
    freq_g = freq_g * (bin_g[1] - bin_g[0])
    bin_g = bin_g[:-1]
    freq_d, bin_d = np.histogram(y, bins = 100, density=True)
    freq_d = freq_d * (bin_d[1] - bin_d[0])
    bin_d = bin_d[:-1]
    plt.plot(bin_g, freq_g, 'go-', markersize = 4)
    plt.plot(bin_d, freq_d, 'ko-', markersize = 8)
    gen1d = gen_result.flatten()
    dis1d = dis_result.flatten()
    si = np.argsort(gen1d)
    plt.plot(gen1d[si], dis1d[si], 'r--')
    plt.savefig("img/gan_results.png")
    if not path.exists(args.model_dir):
        os.mkdir(args.model_dir)
    gan.save_model(args.model_dir)


# app_logger.py
import logging

def init(lvl=logging.DEBUG):
    log_handler = logging.StreamHandler()
    # create formatter
    formatter = logging.Formatter('[%(asctime)s] %(levelname)s %(filename)s:%(funcName)s:%(lineno)d > %(message)s')
    log_handler.setFormatter(formatter)
    logging.basicConfig(level = lvl, handlers = [log_handler])