Overview: In an earlier article, we used TensorFlow to recognize handwritten digits and reached a respectable 94.09% recognition accuracy. In this article, we will use a GAN model to generate the handwritten digits we want.
GAN Overview
A Generative Adversarial Network (GAN), first proposed by Ian Goodfellow, consists of two networks: a generator network (which produces images) and a discriminator network (which judges them). The goal of a GAN is to generate images on its own; for example, after studying a collection of cat pictures, the generator can "draw" a cat picture by itself, as realistically as possible. The discriminator's job is to judge: it is shown both real images and images produced by the generator, and is trained until it can reliably pick out the "fake" images the generator produced. The generator keeps improving so it can fool the discriminator, while the discriminator keeps improving so it can spot the "fakes" more accurately. This mutually reinforcing, adversarial relationship is what gives the adversarial network its name. Figure 1 shows the structure of the GAN model.
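This adversarial game can be summarized by the minimax objective from Goodfellow's original GAN paper, where D is the discriminator, G the generator, and z the input noise:

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z(z)}\left[\log\left(1 - D(G(z))\right)\right]$$

The discriminator D maximizes this value (correctly scoring real and fake samples), while the generator G minimizes it (making its fakes indistinguishable from real data).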
Approach
Extract the images labeled 0 from the MNIST dataset and use them to train the discriminator to recognize the handwritten digit 0. Then have the generator produce a random image and feed it to the trained discriminator, training both networks repeatedly until the discriminator accepts the generated image as a digit 0.
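The extraction step above can be sketched with NumPy boolean indexing. This is a toy illustration with small stand-in arrays rather than the real MNIST download:

```python
import numpy as np

# Toy stand-ins for mnist.train.images and mnist.train.labels
images = np.random.rand(6, 784).astype(np.float32)
labels = np.array([0, 3, 0, 7, 0, 1])

# Keep only the images whose label is 0
zero_images = images[labels == 0]
print(zero_images.shape)  # (3, 784)
```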
Generating the "Fake Image"
Starting from a random noise vector, apply a fully connected layer, a Leaky ReLU activation, dropout (randomly dropping some neurons to prevent overfitting), a second fully connected layer, and finally a tanh activation to produce a 28*28 "fake image". The TensorFlow code is as follows:
def get_generator(noise_img, n_units, out_dim, reuse=False, alpha=0.01):
    with tf.variable_scope("generator", reuse=reuse):
        hidden1 = tf.layers.dense(noise_img, n_units)  # fully connected layer
        hidden1 = tf.maximum(alpha * hidden1, hidden1)  # Leaky ReLU
        hidden1 = tf.layers.dropout(hidden1, rate=0.2)  # dropout to reduce overfitting
        logits = tf.layers.dense(hidden1, out_dim)
        outputs = tf.tanh(logits)
        return logits, outputs
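The `tf.maximum(alpha * hidden1, hidden1)` line above is a hand-rolled Leaky ReLU: positive values pass through unchanged, while negative values are scaled by `alpha`. A quick NumPy check of the equivalence:

```python
import numpy as np

def leaky_relu(x, alpha=0.01):
    # max(alpha*x, x): identity for x >= 0, slope alpha for x < 0
    return np.maximum(alpha * x, x)

x = np.array([-2.0, -0.5, 0.0, 1.0, 3.0])
print(leaky_relu(x))  # [-0.02  -0.005  0.     1.     3.   ]
```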
Image Discrimination
Pass the image to be judged through a fully connected layer, a Leaky ReLU activation, a second fully connected layer, and a sigmoid activation, which outputs the verdict on the image. The TensorFlow code is as follows:
def get_discriminator(img, n_units, reuse=False, alpha=0.01):
    with tf.variable_scope("discriminator", reuse=reuse):
        hidden1 = tf.layers.dense(img, n_units)
        hidden1 = tf.maximum(alpha * hidden1, hidden1)  # Leaky ReLU
        logits = tf.layers.dense(hidden1, 1)
        outputs = tf.sigmoid(logits)
        return logits, outputs
Complete Code
The complete code for the GAN handwritten-digit example is as follows:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
import numpy as np

mnist = input_data.read_data_sets("E:/Tensor/MNIST_data/")


def get_inputs(real_size, noise_size):
    real_img = tf.placeholder(tf.float32, [None, real_size], name="real_img")
    noise_img = tf.placeholder(tf.float32, [None, noise_size], name="noise_img")
    return real_img, noise_img


# Generate an image
def get_generator(noise_img, n_units, out_dim, reuse=False, alpha=0.01):
    with tf.variable_scope("generator", reuse=reuse):
        hidden1 = tf.layers.dense(noise_img, n_units)  # fully connected layer
        hidden1 = tf.maximum(alpha * hidden1, hidden1)  # Leaky ReLU
        hidden1 = tf.layers.dropout(hidden1, rate=0.2)  # dropout to reduce overfitting
        logits = tf.layers.dense(hidden1, out_dim)
        outputs = tf.tanh(logits)
        return logits, outputs


# Judge an image
def get_discriminator(img, n_units, reuse=False, alpha=0.01):
    with tf.variable_scope("discriminator", reuse=reuse):
        hidden1 = tf.layers.dense(img, n_units)
        hidden1 = tf.maximum(alpha * hidden1, hidden1)  # Leaky ReLU
        logits = tf.layers.dense(hidden1, 1)
        outputs = tf.sigmoid(logits)
        return logits, outputs


# Size of a real image (784 = 28 * 28)
img_size = mnist.train.images[0].shape[0]
# Size of the noise vector fed to the generator
noise_size = 100
# Generator hidden-layer units
g_units = 128
# Discriminator hidden-layer units
d_units = 128
# Leaky ReLU slope
alpha = 0.01
# Learning rate
learning_rate = 0.001
# Label smoothing
smooth = 0.1

tf.reset_default_graph()
real_img, noise_img = get_inputs(img_size, noise_size)
g_logits, g_outputs = get_generator(noise_img, g_units, img_size)

d_logits_real, d_outputs_real = get_discriminator(real_img, d_units)
d_logits_fake, d_outputs_fake = get_discriminator(g_outputs, d_units, reuse=True)

# Discriminator loss: real images should score 1, fakes should score 0
d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=d_logits_real, labels=tf.ones_like(d_logits_real)
) * (1 - smooth))
d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=d_logits_fake, labels=tf.zeros_like(d_logits_fake)
))
d_loss = tf.add(d_loss_real, d_loss_fake)
# Generator loss: fakes should fool the discriminator into outputting 1
g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=d_logits_fake, labels=tf.ones_like(d_logits_fake)
) * (1 - smooth))

train_vars = tf.trainable_variables()
g_vars = [var for var in train_vars if var.name.startswith("generator")]
d_vars = [var for var in train_vars if var.name.startswith("discriminator")]

d_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(d_loss, var_list=d_vars)
g_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(g_loss, var_list=g_vars)


epochs = 10000
n_sample = 10
losses = []

# Collect all the training images labeled 0
# (MNIST has only about 5,400 zeros, fewer than epochs, so we cycle through them below)
samples = [mnist.train.images[j]
           for j in range(len(mnist.train.labels))
           if mnist.train.labels[j] == 0]

print(len(samples))
size = samples[0].size

with tf.Session() as sess:
    tf.global_variables_initializer().run()
    for e in range(epochs):
        # Rescale pixels from [0, 1] to [-1, 1] to match the generator's tanh output
        batch_images = samples[e % len(samples)] * 2 - 1
        batch_noise = np.random.uniform(-1, 1, size=noise_size)

        _ = sess.run(d_train_opt, feed_dict={real_img: [batch_images], noise_img: [batch_noise]})
        _ = sess.run(g_train_opt, feed_dict={noise_img: [batch_noise]})

    # Sample one image from the trained generator
    sample_noise = np.random.uniform(-1, 1, size=noise_size)
    g_logit, g_output = sess.run(get_generator(noise_img, g_units, img_size,
                                               reuse=True), feed_dict={
        noise_img: [sample_noise]
    })
    print(g_logit.size)
    g_output = (g_output + 1) / 2  # map the tanh output back to [0, 1]
    plt.imshow(g_output.reshape([28, 28]), cmap='Greys_r')
    plt.show()
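A note on the `smooth` parameter: the code above applies label smoothing by scaling the real-image and generator losses by `(1 - smooth)`. A common alternative formulation, shown here as a hypothetical NumPy sketch, softens the target itself, replacing the hard label 1 with `1 - smooth` so the discriminator is penalized for becoming overconfident:

```python
import numpy as np

def sigmoid_xent(logits, labels):
    # Stable sigmoid cross-entropy, same formula TensorFlow uses internally
    return np.maximum(logits, 0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))

smooth = 0.1
logits = np.array([4.0])  # discriminator is very confident the sample is real
hard = sigmoid_xent(logits, np.array([1.0]))           # hard target 1.0
smoothed = sigmoid_xent(logits, np.array([1.0 - smooth]))  # smoothed target 0.9
print(hard, smoothed)  # the smoothed target yields a larger loss at high confidence
```

Either way, the effect is a gentler training signal for the discriminator, which in practice helps keep the generator's gradients from vanishing early in training.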
Training Results
After 10000 iterations, the images produced by the generator network are already close to the shape of a handwritten zero.
This article is a first exploration of the GAN model; in the follow-up articles in this GAN series, we will dig progressively deeper into its more complex applications.