GAN模型生成手寫字

概述:在前期的文章中,咱們用TensorFlow完成了對手寫數字的識別,獲得了94.09%的識別準確度,效果還算不錯。在這篇文章中,筆者將帶領你們用GAN模型,生成咱們想要的手寫數字。python

GAN簡介git

對抗性生成網絡(GenerativeAdversarial Network),由 Ian Goodfellow 首先提出,由兩個網絡組成,分別是generator網絡(用於生成)和discriminator網絡(用於判別)。GAN網絡的目的就是使其本身生成一副圖片,好比說通過對一系列貓的圖片的學習,generator網絡能夠本身「繪製」出一張貓的圖片,且儘可能真實。discriminator網絡則是用來進行判斷的,將一張真實的圖片和一張由generator網絡生成的照片同時交給discriminator網絡,不斷訓練discriminator網絡,使其能夠準確將discriminator網絡生成的「假圖片」找出來。就這樣,generator網絡不斷改進使其能夠騙過discriminator網絡,而discriminator網絡不斷改進使其能夠更準確找到「假圖片」,這種相互促進相互對抗的關係,就叫作對抗網絡。圖一中展現了GAN模型的結構。網絡

思路梳理app

將MNIST數據集中標籤爲0的圖片提取出來,而後訓練discriminator網絡,進行手寫數字0識別,接着讓generator產生一張隨機圖片,讓訓練好的discriminator去識別這張生成的圖片,不斷訓練discriminator,直到discriminator網絡將生成的圖片當作數字0爲止。dom

生成「假圖片函數

生成一張隨機像素的28*28的圖片,分別進行全鏈接,Leaky ReLU函數激活,dropout處理(隨機丟棄一些神經元,防止過擬合),全鏈接,tanh函數激活,最終生成一張「假圖片」,TensorFlow代碼以下:學習

 

1def get_generator(noise_img, n_units, out_dim, reuse=False, alpha=0.01):
2    with tf.variable_scope("generator", reuse=reuse):
3        hidden1 = tf.layers.dense(noise_img, n_units)  # 全鏈接層
4        hidden1 = tf.maximum(alpha * hidden1, hidden1)
5        hidden1 = tf.layers.dropout(hidden1, rate=0.2)
6        logits = tf.layers.dense(hidden1, out_dim)
7        outputs = tf.tanh(logits)
8        return logits, outputs

圖像判別url

將須要進行判別的圖片前後通過全鏈接,Leaky ReLU函數激活,全鏈接,sigmoid函數激活處理,最終輸出圖片的識別結果,TensorFlow代碼以下:spa

1def get_discriminator(img, n_units, reuse=False, alpha=0.01):
2    with tf.variable_scope("discriminator", reuse=reuse):
3        hidden1 = tf.layers.dense(img, n_units)
4        hidden1 = tf.maximum(alpha * hidden1, hidden1)
5        logits = tf.layers.dense(hidden1, 1)
6        outputs = tf.sigmoid(logits)
7        return logits, outputs

完整代碼3d

GAN手寫數字識別的完整代碼以下:

  1import tensorflow as tf
 2from tensorflow.examples.tutorials.mnist import input_data
 3import matplotlib.pyplot as plt
 4import numpy as np
 5
 6mnist = input_data.read_data_sets("E:/Tensor/MNIST_data/")
 7img = mnist.train.images[50]
 8
 9
10def get_inputs(real_size, noise_size):
11    real_img = tf.placeholder(tf.float32, [None, real_size], name="real_img")
12    noise_img = tf.placeholder(tf.float32, [None, noise_size], name="noise_img")
13    return real_img, noise_img
14
15
16# 生成圖像
17def get_generator(noise_img, n_units, out_dim, reuse=False, alpha=0.01):
18    with tf.variable_scope("generator", reuse=reuse):
19        hidden1 = tf.layers.dense(noise_img, n_units)  # 全鏈接層
20        hidden1 = tf.maximum(alpha * hidden1, hidden1)
21        hidden1 = tf.layers.dropout(hidden1, rate=0.2)
22        logits = tf.layers.dense(hidden1, out_dim)
23        outputs = tf.tanh(logits)
24        return logits, outputs
25
26
27# 圖像判別
28def get_discriminator(img, n_units, reuse=False, alpha=0.01):
29    with tf.variable_scope("discriminator", reuse=reuse):
30        hidden1 = tf.layers.dense(img, n_units)
31        hidden1 = tf.maximum(alpha * hidden1, hidden1)
32        logits = tf.layers.dense(hidden1, 1)
33        outputs = tf.sigmoid(logits)
34        return logits, outputs
35#真實圖像size
36img_size = mnist.train.images[0].shape[0]
37#傳入generator的噪聲size
38noise_size = 100
39#生成器隱層參數
40g_units = 128
41#判別器隱層參數
42d_units = 128
43#Leaky ReLU參數
44alpha = 0.01
45#學習率
46learning_rate = 0.001
47#label smoothing
48smooth = 0.1
49tf.reset_default_graph()
50real_img, noise_img = get_inputs(img_size, noise_size)
51g_logits, g_outputs = get_generator(noise_img, g_units, img_size)
52
53d_logits_real, d_outputs_real = get_discriminator(real_img, d_units)
54d_logits_fake, d_outputs_fake = get_discriminator(g_outputs, d_units, reuse=True)
55
56d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
57    logits=d_logits_real, labels=tf.ones_like(d_logits_real)
58) * (1 - smooth))
59d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
60    logits=d_logits_fake, labels=tf.zeros_like(d_logits_fake)
61))
62d_loss = tf.add(d_loss_real, d_loss_fake)
63g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
64    logits=d_logits_fake, labels=tf.ones_like(d_logits_fake)
65) * (1 - smooth))
66
67train_vars = tf.trainable_variables()
68g_vars = [var for var in train_vars if var.name.startswith("generator")]
69d_vars = [var for var in train_vars if var.name.startswith("discriminator")]
70
71d_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(d_loss, var_list=d_vars)
72g_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(g_loss, var_list=g_vars)
73
74
75epochs = 10000
76samples = []
77n_sample = 10
78losses = []
79
80i = j = 0
81while i<10000:
82    if mnist.train.labels[j] == 0:
83        samples.append(mnist.train.images[j])
84        i += 1
85    j += 1
86
87print(len(samples))
88size = samples[0].size
89
90with tf.Session() as sess:
91    tf.global_variables_initializer().run()
92    for e in range(epochs):
93        batch_images = samples[e] * -1
94        batch_noise = np.random.uniform(-1, 1, size=noise_size)
95
96        _ = sess.run(d_train_opt, feed_dict={real_img:[batch_images], noise_img:[batch_noise]})
97        _ = sess.run(g_train_opt, feed_dict={noise_img:[batch_noise]})
98
99    sample_noise = np.random.uniform(-1, 1, size=noise_size)
100    g_logit, g_output = sess.run(get_generator(noise_img, g_units, img_size,
101                                         reuse=True), feed_dict={
102        noise_img:[sample_noise]
103    })
104    print(g_logit.size)
105    g_output = (g_output+1)/2
106    plt.imshow(g_output.reshape([28, 28]), cmap='Greys_r')
107    plt.show()

 

訓練效果

在通過了10000次的迭代後,generator網絡生成的圖片已經接近手寫數字零的形狀。

  

  本文是對GAN模型的初次探索,在後續GAN模型的系列文章中,筆者將層層深刻的去講解GAN模型複雜的應用。

 

相關文章
相關標籤/搜索