數據集難找?GAN生成你想要的數據!!!


GAN生成對抗網絡學習筆記python



1.GAN誕生背後的故事nginx


GAN創始人 Ian Goodfellow 在酒吧微醉後與同事討論學術問題,當時靈光乍現提出了GAN初步的想法,不過當時並無獲得同事的承認,在從酒吧回去後發現女友已經睡了,因而本身熬夜寫了代碼,發現還真有效果,因而通過一番研究後,GAN就誕生了,一篇開山之做。論文《Generative Adversarial Nets》首次提出GAN。git

論文連接:https://arxiv.org/abs/1406.2661web



2.GAN的原理:微信


GAN的主要靈感來源於博弈論中零和博弈的思想,應用到深度學習神經網絡上來講,就是經過生成網絡G(Generator)和判別網絡D(Discriminator)不斷博弈,進而使G學習到數據的分佈,若是用到圖片生成上,則訓練完成後,G能夠從一段隨機數中生成逼真的圖像。G, D的主要功能是:
網絡

  •   G是一個生成式的網絡,它接收一個隨機的噪聲z(隨機數),經過這個噪聲生成圖像 app

  • D是一個判別網絡,判別一張圖片是否是「真實的」。它的輸入參數是x,x表明一張圖片,輸出D(x)表明x爲真實圖片的機率,若是爲1,就表明100%是真實的圖片,而輸出爲0,就表明不多是真實的圖片dom

訓練過程當中,生成網絡G的目標就是儘可能生成真實的圖片去欺騙判別網絡D。而D的目標就是儘可能辨別出G生成的假圖像和真實的圖像。這樣,G和D構成了一個動態的「博弈過程」,最終的平衡點即納什均衡點.函數



通俗意思就是在犯罪分子造假幣和警察識別假幣的過程當中            工具

     [1]生成模型G至關於製造假幣的一方,其目的是根據看到的錢幣狀況和警察的識別技術,去儘可能生成更加真實的、警察識別不出的假幣。           

     [2]判別模型D至關於識別假幣的一方,其目的是儘量的識別出犯罪分子製造的假幣。這樣經過造假者和識假者雙方的較量和朝目的的改進,使得最後能達到生成模型能儘量真的錢幣、識假者判斷不出真假的納什均衡效果(真假幣機率都爲0.5)。



如圖所示:



3.GAN的原理圖:



4.GAN的特色:

  1.  相比較傳統的模型,他存在兩個不一樣的網絡,而不是單一的網絡,而且訓練方式採用的是對抗訓練方式

  2. GAN中G的梯度更新信息來自判別器D,而不是來自數據樣本



5.GAN 的優勢:

  1. GAN是一種生成式模型,相比較其餘生成模型(玻爾茲曼機和GSNs)只用到了反向傳播,而不須要複雜的馬爾科夫鏈

  2. 相比其餘全部模型, GAN能夠產生更加清晰,真實的樣本

  3. GAN採用的是一種無監督的學習方式訓練,能夠被普遍用在無監督學習和半監督學習領域

  4. 相比於變分自編碼器, GANs沒有引入任何決定性偏置( deterministic bias),變分方法引入決定性偏置,由於他們優化對數似然的下界,而不是似然度自己,這看起來致使了VAEs生成的實例比GANs更模糊

  5. 相比VAE, GANs沒有變分下界,若是鑑別器訓練良好,那麼生成器能夠完美的學習到訓練樣本的分佈.換句話說,GANs是漸進一致的,可是VAE是有誤差的

  6. GAN應用到一些場景上,好比圖片風格遷移,超分辨率,圖像補全,去噪,避免了損失函數設計的困難,無論三七二十一,只要有一個的基準,直接上判別器,剩下的就交給對抗訓練了。



6.GAN的缺點:

  1. 訓練GAN須要達到納什均衡,有時候能夠用梯度降低法作到,有時候作不到.咱們尚未找到很好的達到納什均衡的方法,因此訓練GAN相比VAE或者PixelRNN是不穩定的,但我認爲在實踐中它仍是比訓練玻爾茲曼機穩定的多

  2. GAN不適合處理離散形式的數據,好比文本

  3. GAN存在訓練不穩定、梯度消失、模式崩潰的問題(目前已解決)



7.訓練GAN的一些技巧:

  1.  輸入規範化到(-1,1)之間,最後一層的激活函數使用tanh(BEGAN除外)

  2.  使用wassertein GAN的損失函數,

  3. 若是有標籤數據的話,儘可能使用標籤,也有人提出使用反轉標籤效果很好,另外使用標籤平滑,單邊標籤平滑或者雙邊標籤平滑

  4.  使用mini-batch norm, 若是不用batch norm 可使用instance norm 或者weight norm

  5. 避免使用RELU和pooling層,減小稀疏梯度的可能性,可使用leakrelu激活函數

  6. 優化器儘可能選擇ADAM,學習率不要設置太大,初始1e-4能夠參考,另外能夠隨着訓練進行不斷縮小學習率,

  7. 給D的網絡層增長高斯噪聲,至關因而一種正則。



8.GAN的延伸有哪些:

DCGANCGANACGANinfoGANWGANSSGANPix2Pix GANCycle  GAN


9.GAN能夠作什麼答案是生成數據

生成音頻生成圖片(動物:貓,狗等;人臉圖片,人臉圖轉動漫圖等).......


先來個美食圖緩一緩(學累就先吃一點東西,哈哈哈)

繼續!!!!!


10.GAN的經典案例:生成手寫數字圖片

  • 源碼和數據集獲取方式在下方

  • 有py格式和ipynb格式兩種(代碼是同樣的)

代碼以下:

# -*- coding: utf-8 -*-"""Created on 2020-10-31
@author: 李運辰"""#導入數據包import tensorflow as tffrom tensorflow import kerasfrom tensorflow.keras import layersimport matplotlib.pyplot as plt#get_ipython().run_line_magic('matplotlib', 'inline')import numpy as npimport globimport os
# # 輸入(train_images,train_labels),(_,_)=tf.keras.datasets.mnist.load_data()
train_images = train_images.astype('float32')
# # 數據預處理train_images=train_images.reshape(train_images.shape[0],28,28,1).astype('float32')
#歸一化 到【-1,1】train_images = (train_images -127.5)/127.5
BTATH_SIZE=256BUFFER_SIZE=60000
#輸入管道datasets = tf.data.Dataset.from_tensor_slices(train_images)
#打亂亂序,並取btath_sizedatasets = datasets.shuffle(BUFFER_SIZE).batch(BTATH_SIZE)
# # 生成器模型def generator_model(): model = tf.keras.Sequential() model.add(layers.Dense(256,input_shape=(100,),use_bias=False)) #Dense全鏈接層,input_shape=(100,)長度100的隨機向量,use_bias=False,由於後面有BN層 model.add(layers.BatchNormalization()) model.add(layers.LeakyReLU())#激活 #第二層 model.add(layers.Dense(512,use_bias=False)) model.add(layers.BatchNormalization()) model.add(layers.LeakyReLU())#激活 #輸出層 model.add(layers.Dense(28*28*1,use_bias=False,activation='tanh')) model.add(layers.BatchNormalization()) model.add(layers.Reshape((28,28,1)))#變成圖片 要以元組形式傳入 return model # # 辨別器模型def discriminator_model(): model = keras.Sequential() model.add(layers.Flatten()) model.add(layers.Dense(512,use_bias=False)) model.add(layers.BatchNormalization()) model.add(layers.LeakyReLU())#激活 model.add(layers.Dense(256,use_bias=False)) model.add(layers.BatchNormalization()) model.add(layers.LeakyReLU())#激活 model.add(layers.Dense(1))#輸出數字,>0.5真實圖片 return model # # loss函數cross_entropy=tf.keras.losses.BinaryCrossentropy(from_logits=True)#from_logits=True由於最後的輸出沒有激活
# # 生成器損失函數def generator_loss(fake_out):#但願fakeimage的判別輸出fake_out判別爲真 return cross_entropy(tf.ones_like(fake_out),fake_out)

# # 判別器損失函數def discriminator_loss(real_out,fake_out):#辨別器的輸出 真實圖片判1,假的圖片判0 real_loss=cross_entropy(tf.ones_like(real_out),real_out) fake_loss=cross_entropy(tf.zeros_like(fake_out),fake_out) return real_loss+fake_loss

# # 優化器
generator_opt=tf.keras.optimizers.Adam(1e-4)#學習速率discriminator_opt=tf.keras.optimizers.Adam(1e-4)
EPOCHS=500noise_dim=100 #長度爲100的隨機向量生成手寫數據集num_exp_to_generate=16 #每步生成16個樣本seed=tf.random.normal([num_exp_to_generate,noise_dim]) #生成隨機向量觀察變化狀況
# # 訓練generator=generator_model()discriminator=discriminator_model()

# # 定義批次訓練函數def train_step(images): noise = tf.random.normal([num_exp_to_generate,noise_dim]) with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape: #判別真圖片 real_out = discriminator(images,training=True) #生成圖片 gen_image = generator(noise,training=True) #判別生成圖片 fake_out = discriminator(gen_image,training=True) #損失函數判別 gen_loss = generator_loss(fake_out) disc_loss = discriminator_loss(real_out,fake_out) #訓練過程 #生成器與生成器可訓練參數的梯度 gradient_gen = gen_tape.gradient(gen_loss,generator.trainable_variables) gradient_disc = disc_tape.gradient(disc_loss,discriminator.trainable_variables) #優化器優化梯度 generator_opt.apply_gradients(zip(gradient_gen,generator.trainable_variables)) discriminator_opt.apply_gradients(zip(gradient_disc,discriminator.trainable_variables)) # # 可視化def generator_plot_image(gen_model,test_noise): pre_images = gen_model(test_noise,training=False) #繪圖16張圖片在一張4x4 fig = plt.figure(figsize=(4,4)) for i in range(pre_images.shape[0]): plt.subplot(4,4,i+1) #從1開始排 plt.imshow((pre_images[i,:,:,0]+1)/2,cmap='gray') #歸一化,灰色度 plt.axis('off') #不顯示座標軸 plt.show()
def train(dataset,epochs): for epoch in range(epochs): for image_batch in dataset: train_step(image_batch) #print('第'+str(epoch+1)+'次訓練結果') if epoch%10==0: print('第'+str(epoch+1)+'次訓練結果') generator_plot_image(generator,seed)
train(datasets,EPOCHS)


訓練結果:

  • 第1次訓練結果

  • 第100次訓練結果

結論:

在100次訓練後,能夠明顯看到數字的內容,到訓練了300次以後就能夠很清楚看到生成的數字效果,但300次以後,400,500次效果逐漸降低。圖片內容變模糊。


正文結束!!!!


源碼和數據集獲取方法

公衆號回覆【GAN】免費獲取

歡迎關注公衆號:Python爬蟲數據分析挖掘

記錄學習python的點點滴滴;

回覆【開源源碼】免費獲取更多開源項目源碼;

公衆號每日更新python知識和【免費】工具;

本文已同步到【開源中國】和【騰訊雲社區】;

本文分享自微信公衆號 - Python爬蟲數據分析挖掘(zyzx3344)。
若有侵權,請聯繫 support@oschina.cn 刪除。
本文參與「OSC源創計劃」,歡迎正在閱讀的你也加入,一塊兒分享。

相關文章
相關標籤/搜索