除VAE以外,生成式對抗網絡(Generative Adversarial Nets,GAN)也是一種非常流行的無監督生成式模型。
GAN中主要包括兩個核心網絡:生成器(Generator)和判別器(Discriminator)。
GAN的訓練非常困難,有很多細節需要注意,才能生成質量較高的圖片。
其中一個常用技巧是使用strides為2的卷積代替池化。

這裏我們以MNIST為例,通過TensorFlow實現GAN,由於用到了深度卷積神經網絡,所以也稱作DCGAN(Deep Convolutional GAN)。
對於一個服從隨機分佈的噪音z,生成器通過一個複雜的映射函數生成假的樣本:
$$ \hat{x}=G(z;\theta_g) $$
判別器則使用另一個複雜的映射函數,對於真實樣本或假的樣本,輸出一個0至1之間的值,值越大表示越有可能是真實的樣本:
$$ s=D(x;\theta_d) $$
總的目標函數如下:
$$ \min_{G}\max_{D} V(D,G)=\mathbb{E}_{x\sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{z\sim p_z}[\log(1-D(G(z)))] $$
加載庫
# -*- coding: utf-8 -*- import tensorflow as tf import numpy as np import matplotlib.pyplot as plt %matplotlib inline import os, imageio
加載數據
# Load MNIST with the TensorFlow tutorial helper, using the local
# 'MNIST_data' directory.
# NOTE: `tensorflow.examples.tutorials` only ships with TensorFlow 1.x.
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets(train_dir='MNIST_data')
定義一些常量、網絡輸入、輔助函數
# ---- Hyper-parameters / output location ---------------------------------
batch_size = 100        # images per training step
z_dim = 100             # length of each noise vector fed to the generator
OUTPUT_DIR = 'samples'  # sample grids and the final GIF are written here

if not os.path.exists(OUTPUT_DIR):
    os.mkdir(OUTPUT_DIR)

# ---- Graph inputs ---------------------------------------------------------
# The explicit names matter: the restore script fetches 'noise:0' and
# 'is_training:0' by name from the saved graph.
X = tf.placeholder(tf.float32, [None, 28, 28, 1], name='X')
noise = tf.placeholder(tf.float32, [None, z_dim], name='noise')
is_training = tf.placeholder(tf.bool, name='is_training')


def lrelu(x, leak=0.2):
    """Leaky ReLU computed as max(x, leak * x).

    For 0 < leak < 1 this equals x where x >= 0 and leak * x elsewhere.
    """
    # Keep the argument order: it only matters for the gradient at exact
    # ties (x == 0), but changing it would change that gradient.
    return tf.maximum(x, leak * x)


def sigmoid_cross_entropy_with_logits(x, y):
    """Element-wise sigmoid cross-entropy of logits `x` against labels `y`."""
    return tf.nn.sigmoid_cross_entropy_with_logits(labels=y, logits=x)
判別器部分
def discriminator(image, reuse=None, is_training=is_training):
    """Score a batch of images; higher means "more likely real".

    Four 5x5, stride-2, 'same'-padded convolutions (64 -> 128 -> 256 -> 512
    filters) with leaky ReLU. Every convolution after the first is followed
    by batch norm (decay 0.9). A final dense layer yields one logit per image.

    Args:
        image: 4-D image batch (the callers in this file pass 28x28x1 images).
        reuse: pass True on the second call so it shares the variables that
            the first call created under the 'discriminator' scope.
        is_training: bool placeholder/flag forwarded to batch norm.

    Returns:
        (sigmoid(logit), logit) as a tuple.
    """
    momentum = 0.9
    with tf.variable_scope('discriminator', reuse=reuse):
        # First layer: convolution + leaky ReLU, no batch norm.
        net = lrelu(tf.layers.conv2d(image, kernel_size=5, filters=64,
                                     strides=2, padding='same'))
        # Layers are created in the same order as a hand-unrolled version
        # would create them, so auto-generated variable names (conv2d_1,
        # BatchNorm, ...) and saved checkpoints stay compatible.
        for n_filters in (128, 256, 512):
            net = tf.layers.conv2d(net, kernel_size=5, filters=n_filters,
                                   strides=2, padding='same')
            net = lrelu(tf.contrib.layers.batch_norm(
                net, is_training=is_training, decay=momentum))
        flat = tf.contrib.layers.flatten(net)
        logits = tf.layers.dense(flat, units=1)
        # Returned inside the scope so the sigmoid op keeps its scoped name.
        return tf.nn.sigmoid(logits), logits
生成器部分
def generator(z, is_training=is_training):
    """Map noise vectors to 28x28x1 images with values in (-1, 1).

    A dense layer projects z to a 3x3x512 feature map. Three 5x5, stride-2
    transposed convolutions (256, 128, 64 filters) upsample it
    3 -> 6 -> 12 -> 24, each followed by batch norm (decay 0.9) and ReLU.
    A final 5x5, stride-1 'valid' transposed convolution with tanh gives
    24 + 5 - 1 = 28.

    The output layer is named 'g', so the sample tensor can later be
    fetched as 'generator/g/Tanh:0'.
    """
    momentum = 0.9

    def bn_relu(t):
        # Batch norm then ReLU, shared by every hidden layer.
        return tf.nn.relu(tf.contrib.layers.batch_norm(
            t, is_training=is_training, decay=momentum))

    with tf.variable_scope('generator', reuse=None):
        side = 3
        net = tf.layers.dense(z, units=side * side * 512)
        net = bn_relu(tf.reshape(net, shape=[-1, side, side, 512]))
        # Same layer creation order as an unrolled version, so variable
        # names (conv2d_transpose, conv2d_transpose_1, ...) are unchanged.
        for n_filters in (256, 128, 64):
            net = bn_relu(tf.layers.conv2d_transpose(
                net, kernel_size=5, filters=n_filters, strides=2,
                padding='same'))
        return tf.layers.conv2d_transpose(
            net, kernel_size=5, filters=1, strides=1, padding='valid',
            activation=tf.nn.tanh, name='g')
定義損失函數,注意這裏實現了兩個判別器,但參數是共享的
# One generator, and the discriminator applied twice: once to real images and
# once (reuse=True) to generated images, so both applications share weights.
g = generator(noise)
d_real, d_real_logits = discriminator(X)
d_fake, d_fake_logits = discriminator(g, reuse=True)

# Partition trainable variables by scope so each optimizer only touches
# its own network.
vars_g = list(filter(lambda v: v.name.startswith('generator'),
                     tf.trainable_variables()))
vars_d = list(filter(lambda v: v.name.startswith('discriminator'),
                     tf.trainable_variables()))

# Discriminator targets: real -> 1, generated -> 0.
loss_d_real = tf.reduce_mean(
    sigmoid_cross_entropy_with_logits(d_real_logits, tf.ones_like(d_real)))
loss_d_fake = tf.reduce_mean(
    sigmoid_cross_entropy_with_logits(d_fake_logits, tf.zeros_like(d_fake)))
# Generator target: its samples labelled as real (1), i.e. the
# non-saturating -log D(G(z)) form rather than log(1 - D(G(z))).
loss_g = tf.reduce_mean(
    sigmoid_cross_entropy_with_logits(d_fake_logits, tf.ones_like(d_fake)))
loss_d = loss_d_real + loss_d_fake
定義優化器,注意每個損失函數需要和它對應的可訓練參數對應上
# tf.contrib.layers.batch_norm (default updates_collections) registers its
# moving-average updates in UPDATE_OPS. Making both train ops depend on them
# ensures those statistics are refreshed on every optimizer step.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    # Adam (lr 2e-4, beta1 0.5) for both players; each minimizes only its
    # own loss over its own variables.
    optimizer_d = (tf.train.AdamOptimizer(learning_rate=0.0002, beta1=0.5)
                   .minimize(loss_d, var_list=vars_d))
    optimizer_g = (tf.train.AdamOptimizer(learning_rate=0.0002, beta1=0.5)
                   .minimize(loss_g, var_list=vars_g))
定義一個輔助函數,用於將多張圖片以網格狀拼在一起顯示
def montage(images):
    """Tile a batch of images into one square grid image.

    Args:
        images: list or array of N equally-sized images, shaped (N, H, W)
            for grayscale or (N, H, W, C) for multi-channel images.

    Returns:
        float64 array of shape (n*H + n + 1, n*W + n + 1), plus a trailing
        C axis for multi-channel input, where n = ceil(sqrt(N)). Image k is
        placed at grid row k // n, column k % n. Tiles are separated by
        1-pixel borders, and borders and unused cells are filled with 0.5.

    Raises:
        ValueError: if `images` is not 3-D or 4-D.
    """
    images = np.asarray(images)
    if images.ndim not in (3, 4):
        raise ValueError('montage expects images shaped (N, H, W) or '
                         '(N, H, W, C), got shape %r' % (images.shape,))
    n_images, img_h, img_w = images.shape[:3]
    n_plots = int(np.ceil(np.sqrt(n_images)))
    # Any channel axis is carried through so colour images also work.
    grid_shape = ((img_h + 1) * n_plots + 1,
                  (img_w + 1) * n_plots + 1) + images.shape[3:]
    m = np.full(grid_shape, 0.5)
    for k in range(n_images):
        row, col = divmod(k, n_plots)
        top = 1 + row * (img_h + 1)   # skip the leading border + earlier tiles
        left = 1 + col * (img_w + 1)
        m[top:top + img_h, left:left + img_w] = images[k]
    return m
開始訓練,每次迭代訓練D一次、訓練G兩次
# Train for 60000 iterations. Each iteration updates D once and G twice.
# The session stays open afterwards so the model can be saved.
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Fixed noise, so the periodic sample grids are comparable across time.
z_samples = np.random.uniform(-1.0, 1.0, [batch_size, z_dim]).astype(np.float32)
samples = []                 # uint8 frames for the progress GIF
loss = {'d': [], 'g': []}    # per-iteration losses, recorded before updates

for i in range(60000):
    n = np.random.uniform(-1.0, 1.0, [batch_size, z_dim]).astype(np.float32)
    batch = mnist.train.next_batch(batch_size=batch_size)[0]
    batch = np.reshape(batch, [-1, 28, 28, 1])
    # Assumes MNIST pixels are in [0, 1]; map them to [-1, 1] to match the
    # generator's tanh output range.
    batch = (batch - 0.5) * 2
    # The same feed is used for the loss read-out and all three updates.
    feed = {X: batch, noise: n, is_training: True}

    d_ls, g_ls = sess.run([loss_d, loss_g], feed_dict=feed)
    loss['d'].append(d_ls)
    loss['g'].append(g_ls)

    sess.run(optimizer_d, feed_dict=feed)
    # G gets two steps per iteration (same batch and noise) to keep pace with D.
    sess.run(optimizer_g, feed_dict=feed)
    sess.run(optimizer_g, feed_dict=feed)

    if i % 1000 == 0:
        print(i, d_ls, g_ls)
        gen_imgs = sess.run(g, feed_dict={noise: z_samples, is_training: False})
        gen_imgs = (gen_imgs + 1) / 2          # tanh range -> [0, 1]
        imgs = [img[:, :, 0] for img in gen_imgs]
        gen_imgs = montage(imgs)
        plt.axis('off')
        plt.imshow(gen_imgs, cmap='gray')
        plt.savefig(os.path.join(OUTPUT_DIR, 'sample_%d.jpg' % i))
        plt.show()
        # Store an explicit uint8 frame. Float64 frames force imageio into a
        # lossy conversion of its own (a warning, or an error with writers
        # that reject float data).
        samples.append(np.clip(np.round(gen_imgs * 255), 0, 255).astype(np.uint8))

plt.plot(loss['d'], label='Discriminator')
plt.plot(loss['g'], label='Generator')
plt.legend(loc='upper right')
plt.savefig('Loss.png')
plt.show()
imageio.mimsave(os.path.join(OUTPUT_DIR, 'samples.gif'), samples, fps=5)
生成的圖片如下。由於損失函數中並未使用逐像素比較,因此生成圖片的邊緣不會像VAE那樣出現模糊
保存模型,便於後續使用
# Write the trained variables, plus the graph definition
# ('./mnist_dcgan-60000.meta'), under the prefix './mnist_dcgan' tagged
# with step 60000. The restore script loads exactly these files.
saver = tf.train.Saver()
checkpoint_path = saver.save(sess,
                             './mnist_dcgan',
                             global_step=60000)
加載模型(如果需要的話),例如在單機上使用
# -*- coding: utf-8 -*-
# Standalone script: restore the trained DCGAN from the checkpoint in the
# current directory and display a grid of freshly generated digits.
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

batch_size = 100
z_dim = 100


def montage(images):
    """Tile a batch of images into one square grid image.

    Args:
        images: list or array of N equally-sized images, shaped (N, H, W)
            for grayscale or (N, H, W, C) for multi-channel images.

    Returns:
        float64 array of shape (n*H + n + 1, n*W + n + 1), plus a trailing
        C axis for multi-channel input, where n = ceil(sqrt(N)). Tiles are
        separated by 1-pixel borders, and borders and unused cells are 0.5.

    Raises:
        ValueError: if `images` is not 3-D or 4-D.
    """
    images = np.asarray(images)
    if images.ndim not in (3, 4):
        raise ValueError('montage expects images shaped (N, H, W) or '
                         '(N, H, W, C), got shape %r' % (images.shape,))
    n_images, img_h, img_w = images.shape[:3]
    n_plots = int(np.ceil(np.sqrt(n_images)))
    grid_shape = ((img_h + 1) * n_plots + 1,
                  (img_w + 1) * n_plots + 1) + images.shape[3:]
    m = np.full(grid_shape, 0.5)
    for k in range(n_images):
        row, col = divmod(k, n_plots)
        top = 1 + row * (img_h + 1)
        left = 1 + col * (img_w + 1)
        m[top:top + img_h, left:left + img_w] = images[k]
    return m


# Resolve the checkpoint once and derive the .meta path from it, so the
# graph definition and the restored weights cannot come from different
# checkpoints.
checkpoint = tf.train.latest_checkpoint('./')
if checkpoint is None:
    raise IOError("no checkpoint found in './'; run the training script first")

# No variable initializer is needed: every variable is loaded by restore().
with tf.Session() as sess:
    saver = tf.train.import_meta_graph(checkpoint + '.meta')
    saver.restore(sess, checkpoint)
    graph = tf.get_default_graph()
    # Tensors named explicitly when the training graph was built.
    g = graph.get_tensor_by_name('generator/g/Tanh:0')
    noise = graph.get_tensor_by_name('noise:0')
    is_training = graph.get_tensor_by_name('is_training:0')

    n = np.random.uniform(-1.0, 1.0, [batch_size, z_dim]).astype(np.float32)
    # is_training=False: batch norm uses its stored moving statistics.
    gen_imgs = sess.run(g, feed_dict={noise: n, is_training: False})

gen_imgs = (gen_imgs + 1) / 2          # tanh range -> [0, 1]
imgs = [img[:, :, 0] for img in gen_imgs]
gen_imgs = montage(imgs)
plt.axis('off')
plt.imshow(gen_imgs, cmap='gray')
plt.show()