谷歌開源的 GAN 庫--TFGAN

本文大約 8000 字,閱讀大約須要 12 分鐘python

第一次翻譯,限於英語水平,可能很多地方翻譯不許確,請見諒!git

最近谷歌開源了一個基於 TensorFlow 的庫--TFGAN,方便開發者快速上手 GAN 的訓練,其 Github 地址以下:github

github.com/tensorflow/…算法

原文網址:Generative Adversarial Networks: Google open sources TensorFlow-GAN (TFGAN)bash


若是你玩過波斯王子,那你應該知道你須要保護本身不被」影子「所殺掉,但這也是一個矛盾:若是你殺死「影子」,那遊戲就結束了;但你不作任何事情,那麼遊戲也會輸掉。微信

儘管生成對抗網絡(GAN)有很多優勢,但它也面臨着類似的區分問題。大部分支持 GAN 的深度學習專業也是很是謹慎的支持它,並指出它確實存在穩定性的問題。網絡

GAN 的這個問題也能夠稱作總體收斂性問題。儘管判別器 D 和 生成器 D 相互競爭博弈,但同時也相互依賴對方來達到有效的訓練。若是其中一方訓練得不好,那整個系統也會不好(這也是以前提到的梯度消失或者模式奔潰問題)。而且你也須要確保他們不會訓練太過分,形成另外一方沒法訓練了。所以,波斯王子是一個頗有趣的概念。框架

首先,神經網絡的提出就是爲了模仿人類的大腦(儘管是人爲的)。它們也已經在物體識別和天然語言處理方面取得成功。可是,想要在思考和行爲上與人類一致,這還有很是大的差距。dom

那麼是什麼讓 GANs 成爲機器學習領域一個熱門話題呢?由於它不只只是一個相對新的結構,它更加是一個比以前其餘模型都能更加準確的對真實數據建模,能夠說是深度學習的一個革命性的變化。機器學習

最後,它是一個同時訓練兩個獨立的網絡的新模型,這兩個網絡分別是判別器和生成器。這樣一個非監督神經網絡卻能比其餘傳統網絡獲得更好性能的結果。

但目前事實是咱們對 GANs 的研究還只是很是淺層,仍然有着不少挑戰須要解決。GANs 目前也存在很多問題,好比沒法區分在某個位置應該有多少特定的物體,不能應用到 3D 物體,以及也不能理解真實世界的總體結構。固然如今有大量研究正在研究如何解決上述問題,新的模型也取得更好的性能。

而最近谷歌爲了讓 GANs 更容易實現,設計開發並開源了一個基於 TensorFlow 的輕量級庫--TFGAN。

根據谷歌的介紹,TFGAN 提供了一個基礎結構來減小訓練一個 GAN 模型的難度,同時提供很是好測試的損失函數和評估標準,以及給出容易上手的例子,這些例子強調了 TFGAN 的靈活性和易於表現的優勢。

此外,還提供了一個教程,包含一個高級的 API 能夠快速使用本身的數據集訓練一個模型。

來源: research.googleblog.com

上圖是展現了對抗損失在圖像壓縮方面的效果。最上方第一行圖片是來自 ImageNet 數據集的圖片,也是原始輸入圖片,中間第二行展現了採用傳統損失函數訓練獲得的圖像壓縮神經網絡的壓縮和解壓縮效果,最底下一行則是結合傳統損失函數和對抗損失函數訓練的網絡的結果,能夠看到儘管基於對抗損失的圖片並不像原始圖片,可是它比第二行的網絡獲得更加清晰和細節更好的圖片。

TFGAN 既提供了幾行代碼就能夠實現的簡答函數來調用大部分 GAN 的使用例子,也是創建在包含複雜 GAN 設計的模式化方式。這就是說,咱們能夠採用本身須要的模塊,好比損失函數、評估策略、特徵以及訓練等等,這些都是獨立的模塊。TFGAN 這樣的設計方式其實就知足了不一樣使用者的需求,對於入門新手能夠快速訓練一個模型來看看效果,對於須要修改其中任何一個模塊的使用者也能修改對應模塊,而不會牽一髮而動全身。

最重要的是,谷歌也保證了這個代碼是通過測試的,不須要擔憂通常的 GAN 庫形成的數字或者統計失誤。

開始使用

首先添加如下代碼來導入 tensorflow 和 聲明一個 TFGAN 的實例:

import tensorflow as tf
tfgan = tf.contrib.gan
複製代碼

爲什麼使用 TFGAN

  • 採用良好測試而且很靈活的調用接口實現快速訓練生成器和判別器網絡,此外,還能夠混合 TFGAN、原生 TensorFlow以及其餘自定義框架代碼;
  • 使用實現好的GAN 的損失函數和懲罰策略 (好比 Wasserstein loss、梯度懲罰等)
  • 訓練階段對 GAN 進行監控和可視化操做,以及評估生成結果
  • 使用實現好的技巧來穩定和提升性能
  • 基於常規的 GAN 訓練例子來開發
  • 採用GANEstimator接口裏快速訓練一個 GAN 模型
  • TFGAN 的結構改進也會自動提高你的 TFGAN 項目的性能
  • TFGAN 會不斷添加最新研究的算法成果

TFGAN 的部件有哪些呢?

TFGAN 是由多個設計爲獨立的部件組成的,分別是:

  • core:提供了一個主要的訓練 GAN 模型的結構。訓練過程分爲四個階段,每一個階段均可以採用自定義代碼或者 調用 TFGAN 庫接口來完成;
  • features:包含許多常見的 GAN 運算和正則化技術,好比實例正則化(instance normalization)
  • losses:包含常見的 GAN 的損失函數和懲罰機制,好比 Wasserstein loss、梯度懲罰、相互信息懲罰等
  • evaulation:使用一個預訓練好的 Inception 網絡來利用Inception Score或者Frechet Distance評估標準來評估非條件生成模型。固然也支持利用本身訓練的分類器或者其餘方法對有條件生成模型的評估
  • examples and tutorial:使用 TFGAN 訓練 GAN 模型的例子和教程。包含了使用非條件和條件式的 GANs 模型,好比 InfoGANs 等。

訓練一個 GAN 模型

典型的 GAN 模型訓練步驟以下:

  1. 爲你的網絡指定輸入,好比隨機噪聲,或者是輸入圖片(通常是應用在圖片轉換的應用,好比 pix2pixGAN 模型)
  2. 採用GANModel接口定義生成器和判別器網絡
  3. 採用GANLoss指定使用的損失函數
  4. 採用GANTrainOps設置訓練運算操做,即優化器
  5. 開始訓練

固然,GAN 的設置有多種形式。好比,你能夠在非條件下訓練生成器生成圖片,或者能夠給定一些條件,好比類別標籤等輸入到生成器中來訓練。不管是哪一種設置,TFGAN 都有相應的實現。下面將結合代碼例子來進一步介紹。

實例

非條件 MNIST 圖片生成

第一個例子是訓練一個生成器來生成手寫數字圖片,即 MNIST 數據集。生成器的輸入是從多變量均勻分佈採樣獲得的隨機噪聲,目標輸出是 MNIST 的數字圖片。具體查看論文「Generative Adversarial Networks」。代碼以下:

# 配置輸入
# 真實數據來自 MNIST 數據集
images = mnist_data_provider.provide_data(FLAGS.batch_size)
# 生成器的輸入,從多變量均勻分佈採樣獲得的隨機噪聲
noise = tf.random_normal([FLAGS.batch_size, FLAGS.noise_dims])

# 調用 tfgan.gan_model() 函數定義生成器和判別器網絡模型
gan_model = tfgan.gan_model(
    generator_fn=mnist.unconditional_generator,  
    discriminator_fn=mnist.unconditional_discriminator,  
    real_data=images,
    generator_inputs=noise)

# 調用 tfgan.gan_loss() 定義損失函數
gan_loss = tfgan.gan_loss(
    gan_model,
    generator_loss_fn=tfgan_losses.wasserstein_generator_loss,
    discriminator_loss_fn=tfgan_losses.wasserstein_discriminator_loss)

# 調用 tfgan.gan_train_ops() 指定生成器和判別器的優化器
train_ops = tfgan.gan_train_ops(
    gan_model,
    gan_loss,
    generator_optimizer=tf.train.AdamOptimizer(gen_lr, 0.5),
    discriminator_optimizer=tf.train.AdamOptimizer(dis_lr, 0.5))

# tfgan.gan_train() 開始訓練,並指定訓練迭代次數 num_steps
tfgan.gan_train(
    train_ops,
    hooks=[tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps)],
    logdir=FLAGS.train_log_dir)
複製代碼
條件式 MNIST 圖片生成

第二個例子一樣仍是生成 MNIST 圖片,可是此次輸入到生成器的不只僅是隨機噪聲,還會給類別標籤,這種 GAN 模型也被稱做條件 GAN,其目的也是爲了讓 GAN 訓練不會太過自由。具體能夠看論文「Conditional Generative Adversarial Nets」

代碼方面,僅僅須要修改輸入和創建生成器與判別器模型部分,以下所示:

# 配置輸入
# 真實數據來自 MNIST 數據集,這裏增長了類別標籤--one_hot_labels
images, one_hot_labels = mnist_data_provider.provide_data(FLAGS.batch_size)
# 生成器的輸入,從多變量均勻分佈採樣獲得的隨機噪聲
noise = tf.random_normal([FLAGS.batch_size, FLAGS.noise_dims])

# 調用 tfgan.gan_model() 函數定義生成器和判別器網絡模型
gan_model = tfgan.gan_model(
    generator_fn=mnist.conditional_generator,  
    discriminator_fn=mnist.conditional_discriminator,  
    real_data=images,
    generator_inputs=(noise, one_hot_labels)) # 生成器的輸入增長了類別標籤
    
# 剩餘的代碼保持一致
...
複製代碼
對抗損失

第三個例子結合了 L1 pixel loss 和對抗損失來學習自動編碼圖片。瓶頸層能夠用來傳輸圖片的壓縮表示。若是僅僅使用 pixel-wise loss,網絡只回傾向於生成模糊的圖片,但 GAN 能夠用來讓這個圖片重建過程更加逼真。具體能夠看論文「Full Resolution Image Compression with Recurrent Neural Networks」來了解如何用 GAN 來實現圖像壓縮,以及論文「Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network」瞭解如何用 GANs 來加強生成的圖片質量。

代碼以下:

# 配置輸入
images = image_provider.provide_data(FLAGS.batch_size)

# 配置生成器和判別器網絡
gan_model = tfgan.gan_model(
    generator_fn=nets.autoencoder,  # 自定義的 autoencoder
    discriminator_fn=nets.discriminator,  # 自定義的 discriminator
    real_data=images,
    generator_inputs=images)

# 創建 GAN loss 和 pixel loss
gan_loss = tfgan.gan_loss(
    gan_model,
    generator_loss_fn=tfgan_losses.wasserstein_generator_loss,
    discriminator_loss_fn=tfgan_losses.wasserstein_discriminator_loss,
    gradient_penalty=1.0)
l1_pixel_loss = tf.norm(gan_model.real_data - gan_model.generated_data, ord=1)

# 結合兩個 loss
gan_loss = tfgan.losses.combine_adversarial_loss(
    gan_loss, gan_model, l1_pixel_loss, weight_factor=FLAGS.weight_factor)

# 剩下代碼保持一致
...
複製代碼
圖像轉換

第四個例子是圖像轉換,它是將一個領域的圖片轉變成另外一個領域的一樣大小的圖片。好比將語義分割圖變成街景圖,或者是灰度圖變成彩色圖。具體細節看論文「Image-to-Image Translation with Conditional Adversarial Networks」

代碼以下:

# 配置輸入,注意增長了 target_image
input_image, target_image = data_provider.provide_data(FLAGS.batch_size)

# 配置生成器和判別器網絡
gan_model = tfgan.gan_model(
    generator_fn=nets.generator,  
    discriminator_fn=nets.discriminator,  
    real_data=target_image,
    generator_inputs=input_image)

# 創建 GAN loss 和 pixel loss
gan_loss = tfgan.gan_loss(
    gan_model,
    generator_loss_fn=tfgan_losses.least_squares_generator_loss,
    discriminator_loss_fn=tfgan_losses.least_squares_discriminator_loss)
l1_pixel_loss = tf.norm(gan_model.real_data - gan_model.generated_data, ord=1)

# 結合兩個 loss
gan_loss = tfgan.losses.combine_adversarial_loss(
    gan_loss, gan_model, l1_pixel_loss, weight_factor=FLAGS.weight_factor)

# 剩下代碼保持一致
...
複製代碼
InfoGAN

最後一個例子是採用 InfoGAN 模型來生成 MNIST 圖片,可是能夠不須要任何標籤來控制生成的數字類型。具體細節能夠看論文「InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets」

代碼以下:

# 配置輸入
images = mnist_data_provider.provide_data(FLAGS.batch_size)

# 配置生成器和判別器網絡
gan_model = tfgan.infogan_model(
    generator_fn=mnist.infogan_generator,  
    discriminator_fn=mnist.infogran_discriminator,  
    real_data=images,
    unstructured_generator_inputs=unstructured_inputs,  # 自定義輸入
    structured_generator_inputs=structured_inputs)  # 自定義

# 配置 GAN loss 以及相互信息懲罰
gan_loss = tfgan.gan_loss(
    gan_model,
    generator_loss_fn=tfgan_losses.wasserstein_generator_loss,
    discriminator_loss_fn=tfgan_losses.wasserstein_discriminator_loss,
    gradient_penalty=1.0,
    mutual_information_penalty_weight=1.0)

# 剩下代碼保持一致
...
複製代碼
自定義模型的建立

最後一樣是非條件 GAN 生成 MNIST 圖片,但利用GANModel函數來配置更多參數從而更加精確控制模型的建立。

代碼以下:

# 配置輸入
images = mnist_data_provider.provide_data(FLAGS.batch_size)
noise = tf.random_normal([FLAGS.batch_size, FLAGS.noise_dims])

# 手動定義生成器和判別器模型
with tf.variable_scope('Generator') as gen_scope:
  generated_images = generator_fn(noise)
with tf.variable_scope('Discriminator') as dis_scope:
  discriminator_gen_outputs = discriminator_fn(generated_images)
with variable_scope.variable_scope(dis_scope, reuse=True):
  discriminator_real_outputs = discriminator_fn(images)
generator_variables = variables_lib.get_trainable_variables(gen_scope)
discriminator_variables = variables_lib.get_trainable_variables(dis_scope)

# 依賴於你須要使用的 TFGAN 特徵,你並不須要指定 `GANModel`函數的每一個參數,不過
# 最少也須要指定判別器的輸出和變量
gan_model = tfgan.GANModel(
    generator_inputs,
    generated_data,
    generator_variables,
    gen_scope,
    generator_fn,
    real_data,
    discriminator_real_outputs,
    discriminator_gen_outputs,
    discriminator_variables,
    dis_scope,
    discriminator_fn)

# 剩下代碼和第一個例子同樣
...
複製代碼

最後,再次給出 TFGAN 的 Github 地址以下:

github.com/tensorflow/…


若是有翻譯不當的地方或者有任何建議和見解,歡迎留言交流;也歡迎關注個人微信公衆號--機器學習與計算機視覺或者掃描下方的二維碼,和我分享你的建議和見解,指正文章中可能存在的錯誤,你們一塊兒交流,學習和進步!

相關文章
相關標籤/搜索