實戰生成對抗網絡[3]:DCGAN

在上一篇文章《實戰生成對抗網絡[2]:生成手寫數字》中,咱們使用了簡單的神經網絡來生成手寫數字,能夠看出手寫數字字形,但不夠完美,生成的手寫數字有些毛糙,邊緣不夠平滑。git

生成對抗網絡中,生成器和判別器是一對冤家。要提升生成器的水平,就要提升判別器的識別能力。在《一步步提升手寫數字的識別率(3)》系列文章中,咱們探討了如何提升手寫數字的識別率,發現卷積神經網絡在圖像處理方面優點巨大,最後採用卷積神經網絡模型,達到一個不錯的識別率。天然的,爲了提升生成對抗網絡的手寫數字生成質量,咱們是否也能夠採用卷積神經網絡呢?github

答案是確定的,不過和《一步步提升手寫數字的識別率(3)》中隨便採用一個卷積神經網絡結構是不夠的,由於生成對抗網絡中,有兩個神經網絡模型互相對抗,隨便選擇網絡結構,容易在迭代過程當中引發振盪,難以收斂。web

好在有專家學者進行了這方面的研究,下面就介紹一篇由Alec Radford、Luke Metz和Soumith Chintala合做完成的論文 arXiv: 1511.06434, 《利用深度卷積生成對抗網絡進行無監督表徵學習(Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks)》。bash

論文給出了生成器的模型結構,以下圖所示:網絡

從圖中能夠看,該網絡採用100x1噪聲向量(隨機輸入),表示爲z,並將其映射到G(Z)輸出,即64x64x3,其變換過程爲:架構

100x1 → 1024x4x4 → 512x8x8 → 256x16x16 → 128x32x32 → 64x64x3post

若是採用keras實現上述模型,很是簡單。不過須要注意的是,在本文中探討的手寫數字生成,其最終輸出是28 x 28 x 1的灰度圖片,因此咱們沿襲上面的模型架構,但在具體實現上作一些調整:學習

100x1 → 1024x1 → 128x7x7 → 128x14x14 → 14x14x64 → 28x28x64 → 8x28x1ui

代碼以下:spa

def generator_model():
  model = Sequential()
  model.add(Dense(input_dim=100, output_dim=1024))
  model.add(Activation('tanh'))
  model.add(Dense(128 * 7 * 7))
  model.add(BatchNormalization())
  model.add(Activation('tanh'))
  model.add(Reshape((7, 7, 128), input_shape=(128 * 7 * 7,)))
  model.add(UpSampling2D(size=(2, 2)))
  model.add(Conv2D(64, (5, 5), padding='same'))
  model.add(Activation('tanh'))
  model.add(UpSampling2D(size=(2, 2)))
  model.add(Conv2D(1, (5, 5), padding='same'))
  model.add(Activation('tanh'))
  return model
複製代碼

代碼中引入了批量規則化(BatchNormalization),在實踐中被證明能夠在許多場合提高訓練速度,減小初始化不佳帶來的問題而且一般能產生準確的結果。上採樣則是用來擴大維度。

判別器的實現差很少是將上述生成器模型倒過來實現,但使用最大池化代替了上採樣,代碼以下:

def discriminator_model():
  model = Sequential()
  model.add(
      Conv2D(64, (5, 5),
             padding='same',
             input_shape=(28, 28, 1))
  )
  model.add(Activation('tanh'))
  model.add(MaxPooling2D(pool_size=(2, 2)))
  model.add(Conv2D(128, (5, 5)))
  model.add(Activation('tanh'))
  model.add(MaxPooling2D(pool_size=(2, 2)))
  model.add(Flatten())
  model.add(Dense(1024))
  model.add(Activation('tanh'))
  model.add(Dense(1))
  model.add(Activation('sigmoid'))
  return model
複製代碼

在論文中,做者建議經過下面一些架構性的約束來固化網絡:

  • 在判別器中使用跨步卷積取代池化層,在生成器中使用反捲積取代池化層。
  • 在生成器和判別器中使用批量規則化。
  • 消除架構中較深的全鏈接層。
  • 在生成器的輸出層使用Tanh,在其餘層均使用ReLU激活。
  • 在判別器的全部層中都使用LeakyReLU激活。

上述代碼並無徹底遵照做者的建議,可見在面對不一樣的場景,開發者能夠有本身的發揮。事實上,在GANs in Action這本書中,做者也給出了手寫數字生成的另一種DCGAN模型,代碼可參考:github.com/GANs-in-Act…

通過100個epoch的迭代,咱們的代碼生成的手寫數字以下圖所示,雖然有些數字生成得不太準確,不過相對於上一篇文章的輸出,邊緣仍是要平滑一些,效果也有所改進:

本文所演示內容的完整代碼,請參考:github.com/mogoweb/aie…

image
相關文章
相關標籤/搜索