PyTorch GAN: Generating MNIST-like Data

A Note Up Front

Generating image data with a GAN has probably been done to death on every major blog. I'm just posting my own understanding and my own steps; good luck finding a good blog and working through it.
The complete code is at the end of the post. It reads best with the code at hand alongside the explanation.

GAN

Generative adversarial network. The classic formulation has two parts, a generator and a discriminator: the generator produces new images resembling the dataset, and the discriminator tries to tell whether an image is real or generated. The two are locked in opposition: as the discriminator gets better at telling them apart, it pushes the generator to produce more realistic images; as the generator produces more "real"-looking images, it pushes the discriminator to sharpen its judgment.
Of course, what we usually want in the end is the generator, to produce new images (or other data).
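In the notation of the original GAN paper, this two-player game is the minimax objective below (stated here for reference; the code in this post uses the equivalent BCE formulation):

```latex
\min_G \max_D V(D, G) =
\mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)]
+ \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]
```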
Things to watch when training it yourself:

  1. The generator and discriminator should stay "evenly matched" for you to end up with a good generator.
  2. You can tune how fast each side learns, e.g. add noise to the real images to hinder the discriminator, adjust the two learning rates, or change the training ratio (say, 5 generator steps per discriminator step), etc.
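The noise trick in point 2 can be sketched like this. This is a minimal illustration; the helper name `add_instance_noise` and the noise scale `sigma` are my own choices, not part of the code in this post:

```python
import torch

def add_instance_noise(images, sigma=0.1):
    # Add Gaussian noise to real images so the discriminator cannot
    # separate real from fake too easily; in practice sigma is often
    # decayed toward 0 as training progresses.
    return images + sigma * torch.randn_like(images)

real = torch.zeros(4, 28 * 28)   # stand-in for a batch of real images
noisy = add_instance_noise(real, sigma=0.1)
print(noisy.shape)  # torch.Size([4, 784])
```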

Implementation

Setup

Import the packages. image_size is 28×28 and is used later; I treat each image directly as a [1, 28×28] vector.
Then a small utility to inspect the data inside the dataloader. It's not essential, and you can write your own version.

import torch
import torch.utils.data
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import torchvision

DEVICE= 'cuda' if torch.cuda.is_available() else 'cpu'
BATCH_SIZE=128
IMAGE_SIZE= 28 * 28 # length of the flattened image vector

def ShowDataLoader(dataloader,num):
    i=0
    for imgs,labs in dataloader:
        print("imgs",imgs.shape)
        print("labs", labs.shape)
        i += 1
        if i==num: break
    return

Loading the dataset

Loading images naturally starts with preprocessing, and torchvision provides plenty of helper functions. The transform first converts each image to a tensor and then normalizes it. Since I didn't want to use that much data, I split the dataset and took only 32*1500 samples.
If you don't have the MNIST dataset locally, it can be downloaded automatically; just set download=True.
Then load it with a PyTorch DataLoader, standard PyTorch stuff.
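A quick check of what this transform does: ToTensor scales pixels to [0, 1], and Normalize(mean=[0.5], std=[0.5]) then applies x -> (x - 0.5) / 0.5 per pixel, so the inputs land in [-1, 1], the same range the generator's Tanh output will use:

```python
def normalize(x, mean=0.5, std=0.5):
    # same per-pixel arithmetic as transforms.Normalize
    return (x - mean) / std

print(normalize(0.0))  # -1.0 (black pixel)
print(normalize(1.0))  # 1.0 (white pixel)
```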

if __name__ == '__main__':

    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5],std=[0.5])
    ])

    #load mnist
    mnist=torchvision.datasets.MNIST(root="./data-source",train=True,transform=transform)
    #split data set
    whole_length=len(mnist)
    # data length should be divisible by BATCH_SIZE
    sub_length=32*1500
    sub_minist1,sub_minist2=torch.utils.data.random_split(mnist, [sub_length, whole_length - sub_length])


    #load dataset
    dataloader=torch.utils.data.DataLoader(dataset=sub_minist1, batch_size=BATCH_SIZE,shuffle=True)
    # plt.imshow(next(iter(dataloader))[0][0][0])
    # plt.show()

Creating the discriminator and generator

Using nn.Sequential keeps things simple; two networks are declared.
BCELoss is used as the loss function. Other optimizers would also work, and each network gets its own optimizer.

The discriminator's input is straightforward: a batch of images, real or fake.
For the generator's input, the classic paper recommends a latent vector, really just a randomly generated vector, which the generator transforms into a batch of image data.

The learning rates are tunable too; watch how training goes.

    Discriminitor=nn.Sequential(
        nn.Linear(IMAGE_SIZE, 300),
        nn.LeakyReLU(0.2),
        nn.Linear(300,150),
        nn.LeakyReLU(0.2),
        nn.Linear(150,1),
        nn.Sigmoid()
    )
    Discriminitor = Discriminitor.to(DEVICE)


    latent_size=64
    Generator=nn.Sequential(
        nn.Linear(latent_size,150),
        nn.ReLU(True),
        nn.Linear(150,300),
        nn.ReLU(True),
        nn.Linear(300, IMAGE_SIZE),
        nn.Tanh()  # squash output into the range (-1, 1)
    )

    Generator=Generator.to(DEVICE)

    loss_fn=nn.BCELoss()

    d_optimizer=torch.optim.SGD(Discriminitor.parameters(), lr=0.002)
    g_optimizer=torch.optim.Adam(Generator.parameters(), lr=0.002)

Training

The training code follows the standard PyTorch pattern, so I won't explain it at length.
My implementation follows these ideas:

  1. Reshape each image into a single vector of shape [1, 28*28].
  2. Compute the discriminator's loss first; it comes from two parts, real data and data produced by the generator.
  3. Our labels are not MNIST's class labels but real/fake labels, so we need to make them ourselves.
  4. After backpropagating both loss parts, the discriminator's training step is done.
  5. A generated image judged fake by the discriminator counts as a failure for the generator, so the generator wants its images labeled "real". Its loss is therefore the BCE loss between the discriminator's output and an all-ones ("real") vector.
    loader_len=len(dataloader)
    EPOCH =30
    G_EPOCH=1
    for epoch in range(EPOCH):
        for i,(images, _) in enumerate(dataloader):
            images=images.reshape(images.shape[0], IMAGE_SIZE).to(DEVICE)

            # noise distraction
            # noise=torch.randn(images.shape[0], IMAGE_SIZE)
            # images=noise+images
            
            #make labels for training
            label_real_pic = torch.ones(BATCH_SIZE, 1).to(DEVICE)
            label_fake_pic = torch.zeros(BATCH_SIZE, 1).to(DEVICE)
            
            #have a glance at real image
            if i%100==0:
                plt.title('real')
                data=images.view(BATCH_SIZE, 28, 28).data.cpu().numpy()
                plt.imshow(data[0])
                plt.pause(1)

            #calculate loss of the "real part"
            res_real=Discriminitor(images)
            d_loss_real=loss_fn(res_real, label_real_pic)
            
            #calculate loss of the "fake part"
            #generate fake image
            z=torch.randn(BATCH_SIZE,latent_size).to(DEVICE)
            fake_imgs=Generator(z)

            res_fake=Discriminitor(fake_imgs.detach())  # detach so no gradients flow back into the generator here
            d_loss_fake=loss_fn(res_fake,label_fake_pic)

            d_loss=d_loss_fake+d_loss_real
            
            #update discriminator model
            d_optimizer.zero_grad()
            d_loss.backward()
            d_optimizer.step()

            # increase G_EPOCH to train the generator several times per discriminator step
            for dummy in range(G_EPOCH):
                tt=torch.randn(BATCH_SIZE,latent_size).to(DEVICE)
                fake1=Generator(tt)
                res_fake2=Discriminitor(fake1)
                g_loss=loss_fn(res_fake2,label_real_pic)
                g_optimizer.zero_grad()
                g_loss.backward()
                g_optimizer.step()

            if i %50==0:
                print("Epoch [{}/{}], Step [ {}/{} ], d_loss: {:.4f}, g_loss: {:.4f}, "
                      .format(epoch,EPOCH,i,loader_len,d_loss.item(),g_loss.item()))
                #take a look at how generator is at the moment
                temp = torch.randn(BATCH_SIZE, latent_size).to(DEVICE)
                fake_temp = Generator(temp)
                ff = fake_temp.view(BATCH_SIZE, 28, 28).data.cpu().numpy()
                plt.title('generated')
                plt.imshow(ff[0])
                plt.pause(1)

Results

Every so often the program displays one image from the current training batch, both a real one and a generated one, to show where things stand.
The generated image tells you how good the generator is at the moment. You can also watch both losses, printed to the console.

Complete code

import torch
import torch.utils.data
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import torchvision


DEVICE= 'cuda' if torch.cuda.is_available() else 'cpu'
BATCH_SIZE=128
IMAGE_SIZE= 28 * 28 # length of the flattened image vector

def ShowDataLoader(dataloader,num):
    i=0
    for imgs,labs in dataloader:
        print("imgs",imgs.shape)
        print("labs", labs.shape)
        i += 1
        if i==num: break
    return


if __name__ == '__main__':

    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5],std=[0.5])
    ])

    #load mnist
    mnist=torchvision.datasets.MNIST(root="./data-source",train=True,transform=transform)
    #split data set
    whole_length=len(mnist)
    # data length should be divisible by BATCH_SIZE
    sub_length=32*1500
    sub_minist1,sub_minist2=torch.utils.data.random_split(mnist, [sub_length, whole_length - sub_length])


    #load dataset
    dataloader=torch.utils.data.DataLoader(dataset=sub_minist1, batch_size=BATCH_SIZE,shuffle=True)
    # plt.imshow(next(iter(dataloader))[0][0][0])
    # plt.show()



    Discriminitor=nn.Sequential(
        nn.Linear(IMAGE_SIZE, 300),
        nn.LeakyReLU(0.2),
        nn.Linear(300,150),
        nn.LeakyReLU(0.2),
        nn.Linear(150,1),
        nn.Sigmoid()
    )
    Discriminitor = Discriminitor.to(DEVICE)


    latent_size=64
    Generator=nn.Sequential(
        nn.Linear(latent_size,150),
        nn.ReLU(True),
        nn.Linear(150,300),
        nn.ReLU(True),
        nn.Linear(300, IMAGE_SIZE),
        nn.Tanh()  # squash output into the range (-1, 1)
    )

    Generator=Generator.to(DEVICE)

    loss_fn=nn.BCELoss()

    d_optimizer=torch.optim.SGD(Discriminitor.parameters(), lr=0.002)
    g_optimizer=torch.optim.Adam(Generator.parameters(), lr=0.002)


    loader_len=len(dataloader)
    EPOCH =30
    G_EPOCH=1
    for epoch in range(EPOCH):
        for i,(images, _) in enumerate(dataloader):
            images=images.reshape(images.shape[0], IMAGE_SIZE).to(DEVICE)

            # noise=torch.randn(images.shape[0], IMAGE_SIZE)
            # images=noise+images
            
            label_real_pic = torch.ones(BATCH_SIZE, 1).to(DEVICE)
            label_fake_pic = torch.zeros(BATCH_SIZE, 1).to(DEVICE)

            if i%100==0:
                plt.title('real')
                data=images.view(BATCH_SIZE, 28, 28).data.cpu().numpy()
                plt.imshow(data[0])
                plt.pause(1)

            res_real=Discriminitor(images)

            d_loss_real=loss_fn(res_real, label_real_pic)

            #generate fake image
            z=torch.randn(BATCH_SIZE,latent_size).to(DEVICE)
            fake_imgs=Generator(z)

            res_fake=Discriminitor(fake_imgs.detach())  # detach so no gradients flow back into the generator here

            d_loss_fake=loss_fn(res_fake,label_fake_pic)
            d_loss=d_loss_fake+d_loss_real

            d_optimizer.zero_grad()
            d_loss.backward()
            d_optimizer.step()


            for j in range(G_EPOCH):
                tt=torch.randn(BATCH_SIZE,latent_size).to(DEVICE)
                fake1=Generator(tt)
                res_fake2=Discriminitor(fake1)
                g_loss=loss_fn(res_fake2,label_real_pic)
                g_optimizer.zero_grad()
                g_loss.backward()
                g_optimizer.step()

            if i %50==0:
                print("Epoch [{}/{}], Step [ {}/{} ], d_loss: {:.4f}, g_loss: {:.4f}, "
                      .format(epoch, EPOCH, i, loader_len, d_loss.item(), g_loss.item()))

                temp = torch.randn(BATCH_SIZE, latent_size).to(DEVICE)
                fake_temp = Generator(temp)
                ff = fake_temp.view(BATCH_SIZE, 28, 28).data.cpu().numpy()
                plt.title('generated')
                plt.imshow(ff[0])
                plt.pause(1)