【機器學習PAI實戰】—— 玩轉人工智能之利用GAN自動生成二次元頭像

前言

深度學習做爲人工智能的重要手段,迎來了爆發,在NLP、CV、物聯網、無人機等多個領域都發揮了很是重要的做用。最近幾年,各類深度學習算法層出不窮, Generative Adverarial Network(GAN)自2014年提出以來,引發普遍關注,身爲深度學習三巨頭之一的Yan Lecun對GAN的評價頗高,認爲GAN是近年來在深度學習上最大的突破,是近十年來機器學習上最有意思的工做。圍繞GAN的論文數量也迅速增多,各類版本的GAN出現,主要在CV領域帶來了一些貢獻,以下圖所示。python

咱們能夠利用GAN生成一些咱們須要的圖像或者文本,好比二次元頭像。git

GAN簡介

GAN主要的應用是自動生成一些東西,包括圖像和文本等,好比隨機給一個向量做爲輸入,經過GAN的Generator生成一張圖片,或者生成一串語句。Conditional GAN的應用更多一些,好比數據集是一段文字和圖像的數據對,經過訓練,GAN能夠經過給定一段文字生成對應的圖像。github

GAN主要能夠分爲Generator(生成器)和Discriminator(判別器)兩個部分,其中Generator其實就是一個神經網絡,輸入一個向量,能夠輸出一張圖像(即一個高維的向量表示),以下圖示。算法

Discriminator也是一個神經網絡,輸入爲一張圖像,輸出爲一個數值,輸出的數值用於判斷輸入的圖像是不是真的,數值越大,說明圖像是真的,數值越小,說明圖像爲假的,以下圖示。網絡

Generator負責生成圖像,Discriminator負責對Generator生成的圖像和真實圖像去進行對比,區別出真假,Generator須要不斷優化來欺騙Discriminator,以假亂真;而Discriminator也不斷優化,來提升識別能力,可以識別出Generator的把戲。兩者的這種關係能夠形象地經過下圖展現。框架

Generator和Discriminator鏈接起來,造成一個比較大的深層網絡,即爲GAN網絡。機器學習

場景描述

深度學習的各類算法在PAI上能夠經過PAI-DSW進行實現,在PAI-DSW上進行訓練數據,利用GAN自動生成二次元頭像。學習

數據準備

首先須要準備真實的二次元頭像做爲數據集,這裏從網上找到一些共享的資源,存儲在了釘釘釘盤中,釘盤地址 ,提取密碼: c2pz,數據集以下圖示,約5萬多張:大數據

算法實踐

利用PAI-DSW進行GAN算法實踐,首先須要安裝準備好環境。優化

首先進入到Notebook建模,建立新實例,以後打開實例,進入Terminal,在Terminal下用戶能夠像在本身本地同樣安裝相應的依賴包,進行操做。

準備好環境以後,咱們能夠經過以下圖示方法,將基於Tensorflow的DCGAN代碼和數據集上傳上去。

用於訓練的DCGAN代碼地址:https://github.com/carpedm20/DCGAN-tensorflow,關於DCGAN的網絡框架圖以下,詳細介紹能夠參考論文:https://arxiv.org/abs/1511.06434,這裏咱們不作詳述。

數據集和代碼上傳成功,以下圖示。

其中,data目錄下的faces即爲數據集,該文件夾下爲對應的5萬多張真實二次元頭像。DCGAN-tensorflow爲整個代碼路徑,其中最主要的兩個代碼文件是main.py和model.py,其中最主要的核心代碼以下。

 

def main(_):
  pp.pprint(flags.FLAGS.__flags)

  if FLAGS.input_width is None:
    FLAGS.input_width = FLAGS.input_height
  if FLAGS.output_width is None:
    FLAGS.output_width = FLAGS.output_height

  if not os.path.exists(FLAGS.checkpoint_dir):
    os.makedirs(FLAGS.checkpoint_dir)
  if not os.path.exists(FLAGS.sample_dir):
    os.makedirs(FLAGS.sample_dir)

  #gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333)
  run_config = tf.ConfigProto()
  run_config.gpu_options.allow_growth=True

  with tf.Session(config=run_config) as sess:
    if FLAGS.dataset == 'mnist':
      dcgan = DCGAN(
          sess,
          input_width=FLAGS.input_width,
          input_height=FLAGS.input_height,
          output_width=FLAGS.output_width,
          output_height=FLAGS.output_height,
          batch_size=FLAGS.batch_size,
          sample_num=FLAGS.batch_size,
          y_dim=10,
          z_dim=FLAGS.generate_test_images,
          dataset_name=FLAGS.dataset,
          input_fname_pattern=FLAGS.input_fname_pattern,
          crop=FLAGS.crop,
          checkpoint_dir=FLAGS.checkpoint_dir,
          sample_dir=FLAGS.sample_dir,
          data_dir=FLAGS.data_dir)
    else:
      dcgan = DCGAN(
          sess,
          input_width=FLAGS.input_width,
          input_height=FLAGS.input_height,
          output_width=FLAGS.output_width,
          output_height=FLAGS.output_height,
          batch_size=FLAGS.batch_size,
          sample_num=FLAGS.batch_size,
          z_dim=FLAGS.generate_test_images,
          dataset_name=FLAGS.dataset,
          input_fname_pattern=FLAGS.input_fname_pattern,
          crop=FLAGS.crop,
          checkpoint_dir=FLAGS.checkpoint_dir,
          sample_dir=FLAGS.sample_dir,
          data_dir=FLAGS.data_dir)

    show_all_variables()

    if FLAGS.train:
      dcgan.train(FLAGS)

 

else:
          # Update D network
          _, summary_str = self.sess.run([d_optim, self.d_sum],
            feed_dict={ self.inputs: batch_images, self.z: batch_z })
          self.writer.add_summary(summary_str, counter)

          # Update G network
          _, summary_str = self.sess.run([g_optim, self.g_sum],
            feed_dict={ self.z: batch_z })
          self.writer.add_summary(summary_str, counter)

          # Run g_optim twice to make sure that d_loss does not go to zero (different from paper)
          _, summary_str = self.sess.run([g_optim, self.g_sum],
            feed_dict={ self.z: batch_z })
          self.writer.add_summary(summary_str, counter)
          
          errD_fake = self.d_loss_fake.eval({ self.z: batch_z })
          errD_real = self.d_loss_real.eval({ self.inputs: batch_images })
          errG = self.g_loss.eval({self.z: batch_z})

一切就緒以後,咱們執行命令進行訓練,調用命令以下:

​python main.py --input_height 96 --input_width 96 --output_height 48 --output_width 48 --dataset faces --crop --train --epoch 300 --input_fname_pattern "*.jpg"

其中,參數dateset指定數據集的目錄,epoch指定循環迭代的次數,input_height、input_width用於指定輸入文件的大小,輸出文件的大小一樣也須要參數設定,代碼執行過程以下圖示:​

                            

咱們來看下執行結果,分別看一下epoch爲1,30,100的時候生成的二次元頭像效果圖。

epoch=1

epoch=30

epoch=100​

咱們發現,隨着不斷迭代,生成的二次元頭像也愈來愈逼真。

總結

經過上面的實踐,咱們領略到了GAN的魅力,GAN的變種有不少,除此以外咱們還能夠利用GAN作很是多的有意思的事情,好比經過文字生成圖像,經過簡單文字生成宣傳海報等。PAI-DSW像是一個練武場,爲咱們準備好了深度學習所須要的環境和條件,讓咱們能夠盡情享受大數據和深度學習的樂趣,除了GAN,像比較火熱的Bert等模型,咱們也均可以試一試。

原文連接 本文爲雲棲社區原創內容,未經容許不得轉載。

相關文章
相關標籤/搜索