GANs that generate this-or-that kind of image have probably been written about to death on every major blog. This post is just my own understanding and my own steps; if you want the full theory, find a good write-up and work through it.
The complete code is at the end of the post. Ideally, read the code side by side with the explanation.
A generative adversarial network (GAN), in its classic formulation, has two parts: a generator and a discriminator. The generator produces new images resembling the dataset, and the discriminator tries to tell whether a given image is real or generated. The two are locked in opposition: as the discriminator gets better at telling them apart, it pushes the generator to produce images closer to real ones; and as the generator produces more "real"-looking images, it forces the discriminator to sharpen its judgment.
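The two objectives can be written down concretely with binary cross-entropy, the same BCELoss used in the code later. A minimal pure-Python sketch; the 0.9 and 0.1 discriminator scores here are made-up numbers purely for illustration:

```python
import math

def bce(pred, target):
    # binary cross-entropy for a single prediction in (0, 1)
    return -(target * math.log(pred) + (1 - target) * math.log(1 - pred))

# The discriminator wants D(real) -> 1 and D(fake) -> 0:
d_loss = bce(0.9, 1.0) + bce(0.1, 0.0)   # small when D classifies well

# The generator wants D to label its fakes as real, i.e. D(fake) -> 1:
g_loss = bce(0.1, 1.0)                   # large while D still rejects the fakes
```

Training alternates between lowering `d_loss` (update the discriminator) and lowering `g_loss` (update the generator), which is exactly the loop implemented below.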
Of course, what we usually want in the end is the generator, to produce new images (or other kinds of data).
A few things to note when training this yourself:
First the imports. image_size is 28×28, which is used later; I treat each image directly as a [1, 28×28] vector.
Then a small utility for inspecting what's inside the dataloader. It's not strictly necessary; feel free to write your own version.
```python
import torch
import torch.utils.data
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import torchvision

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
BATCH_SIZE = 128
IMAGE_SIZE = 28 * 28  # it denotes vector length as well

def ShowDataLoader(dataloader, num):
    i = 0
    for imgs, labs in dataloader:
        print("imgs", imgs.shape)
        print("labs", labs.shape)
        i += 1
        if i == num:
            break
    return
```
Loading the images naturally starts with preprocessing, and torchvision provides plenty of helper functions. The transform first converts each image to a tensor, then normalizes it. Since I didn't want to use the full dataset, I split it and took only 32×1500 samples.
If you don't have the MNIST dataset yet, it can be downloaded automatically; just set download=True.
Then load it with a PyTorch DataLoader; this is standard PyTorch.
```python
if __name__ == '__main__':
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5])
    ])
    # load mnist (download=True fetches it on first run)
    mnist = torchvision.datasets.MNIST(root="./data-source", train=True,
                                       transform=transform, download=True)
    # split the data set
    whole_length = len(mnist)
    # sub_length should be divisible by batch_size
    sub_length = 32 * 1500
    sub_minist1, sub_minist2 = torch.utils.data.random_split(
        mnist, [sub_length, whole_length - sub_length])
    # wrap the subset in a dataloader
    dataloader = torch.utils.data.DataLoader(dataset=sub_minist1,
                                             batch_size=BATCH_SIZE, shuffle=True)
    # plt.imshow(next(iter(dataloader))[0][0][0])
    # plt.show()
```
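As a sanity check on the Normalize(mean=[0.5], std=[0.5]) step: ToTensor() yields pixel values in [0, 1], and the normalization maps that range onto [-1, 1], which matches the Tanh output range of the generator. The same arithmetic in plain Python:

```python
def normalize(x, mean=0.5, std=0.5):
    # the per-pixel arithmetic transforms.Normalize(mean=[0.5], std=[0.5]) applies
    return (x - mean) / std

print(normalize(0.0))  # -1.0 (black pixel)
print(normalize(1.0))  # 1.0  (white pixel)
print(normalize(0.5))  # 0.0  (mid gray)
```

This is why the generator ends in Tanh rather than Sigmoid: its fake images must live in the same (-1, 1) range as the normalized real ones, or the discriminator could separate them trivially.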
Using nn.Sequential keeps things simple; two networks are declared.
BCELoss serves as the loss function. Other optimizers would work too; each network gets its own optimizer.
The discriminator's input is straightforward: a batch of images, real or fake.
For the generator's input, the classic paper recommends a latent vector, which is really just a randomly generated vector; the generator transforms it into a batch of image data.
The learning rate is also tunable; adjust it based on how training goes.
```python
    Discriminitor = nn.Sequential(
        nn.Linear(IMAGE_SIZE, 300),
        nn.LeakyReLU(0.2),
        nn.Linear(300, 150),
        nn.LeakyReLU(0.2),
        nn.Linear(150, 1),
        nn.Sigmoid()
    )
    Discriminitor = Discriminitor.to(DEVICE)

    latent_size = 64
    Generator = nn.Sequential(
        nn.Linear(latent_size, 150),
        nn.ReLU(True),
        nn.Linear(150, 300),
        nn.ReLU(True),
        nn.Linear(300, IMAGE_SIZE),
        nn.Tanh()  # squash the output into the range (-1, 1)
    )
    Generator = Generator.to(DEVICE)

    loss_fn = nn.BCELoss()
    d_optimizer = torch.optim.SGD(Discriminitor.parameters(), lr=0.002)
    g_optimizer = torch.optim.Adam(Generator.parameters(), lr=0.002)
```
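A quick standalone shape check of the data flow through the two networks, using the same layer sizes as above (an untrained copy and a batch of 4, purely for illustration):

```python
import torch
import torch.nn as nn

latent_size, IMAGE_SIZE, BATCH = 64, 28 * 28, 4

G = nn.Sequential(
    nn.Linear(latent_size, 150), nn.ReLU(True),
    nn.Linear(150, 300), nn.ReLU(True),
    nn.Linear(300, IMAGE_SIZE), nn.Tanh(),
)
D = nn.Sequential(
    nn.Linear(IMAGE_SIZE, 300), nn.LeakyReLU(0.2),
    nn.Linear(300, 150), nn.LeakyReLU(0.2),
    nn.Linear(150, 1), nn.Sigmoid(),
)

z = torch.randn(BATCH, latent_size)  # latent vectors
fake = G(z)                          # [BATCH, 784], values in (-1, 1)
score = D(fake)                      # [BATCH, 1], values in (0, 1)
print(fake.shape, score.shape)
```

So the generator maps a 64-dimensional latent vector up to a 784-dimensional image vector, and the discriminator maps any 784-dimensional vector down to a single probability.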
The training code follows the classic PyTorch pattern, so I won't over-explain it.
My implementation follows this line of thinking:
```python
    loader_len = len(dataloader)
    EPOCH = 30
    G_EPOCH = 1
    for epoch in range(EPOCH):
        for i, (images, _) in enumerate(dataloader):
            images = images.reshape(images.shape[0], IMAGE_SIZE).to(DEVICE)
            # optional noise distraction
            # noise = torch.randn(images.shape[0], IMAGE_SIZE)
            # images = noise + images

            # make labels for training
            # (sub_length is divisible by BATCH_SIZE, so every batch is full)
            label_real_pic = torch.ones(BATCH_SIZE, 1).to(DEVICE)
            label_fake_pic = torch.zeros(BATCH_SIZE, 1).to(DEVICE)

            # have a glance at a real image
            if i % 100 == 0:
                plt.title('real')
                data = images.view(BATCH_SIZE, 28, 28).data.cpu().numpy()
                plt.imshow(data[0])
                plt.pause(1)

            # calculate the loss on the "real part"
            res_real = Discriminitor(images)
            d_loss_real = loss_fn(res_real, label_real_pic)

            # calculate the loss on the "fake part": generate fake images
            z = torch.randn(BATCH_SIZE, latent_size).to(DEVICE)
            fake_imgs = Generator(z)
            # detach so this update does not backpropagate into the generator
            res_fake = Discriminitor(fake_imgs.detach())
            d_loss_fake = loss_fn(res_fake, label_fake_pic)
            d_loss = d_loss_fake + d_loss_real

            # update the discriminator
            d_optimizer.zero_grad()
            d_loss.backward()
            d_optimizer.step()

            # raise G_EPOCH to train the generator more per discriminator step
            for dummy in range(G_EPOCH):
                tt = torch.randn(BATCH_SIZE, latent_size).to(DEVICE)
                fake1 = Generator(tt)
                res_fake2 = Discriminitor(fake1)
                # the generator wants its fakes judged as real
                g_loss = loss_fn(res_fake2, label_real_pic)
                g_optimizer.zero_grad()
                g_loss.backward()
                g_optimizer.step()

            if i % 50 == 0:
                print("Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}"
                      .format(epoch, EPOCH, i, loader_len,
                              d_loss.item(), g_loss.item()))
                # take a look at how the generator is doing at the moment
                temp = torch.randn(BATCH_SIZE, latent_size).to(DEVICE)
                fake_temp = Generator(temp)
                ff = fake_temp.view(BATCH_SIZE, 28, 28).data.cpu().numpy()
                plt.title('generated')
                plt.imshow(ff[0])
                plt.pause(1)
```
This program periodically displays one image from the current training batch, both a real one and a generated one, to show how things stand.
The generated image shows how far the generator has come. You can also watch both losses in the console.
The complete code:

```python
import torch
import torch.utils.data
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import torchvision

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
BATCH_SIZE = 128
IMAGE_SIZE = 28 * 28  # it denotes vector length as well

def ShowDataLoader(dataloader, num):
    i = 0
    for imgs, labs in dataloader:
        print("imgs", imgs.shape)
        print("labs", labs.shape)
        i += 1
        if i == num:
            break
    return

if __name__ == '__main__':
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5])
    ])
    # load mnist (download=True fetches it on first run)
    mnist = torchvision.datasets.MNIST(root="./data-source", train=True,
                                       transform=transform, download=True)
    # split the data set
    whole_length = len(mnist)
    # sub_length should be divisible by batch_size
    sub_length = 32 * 1500
    sub_minist1, sub_minist2 = torch.utils.data.random_split(
        mnist, [sub_length, whole_length - sub_length])
    # wrap the subset in a dataloader
    dataloader = torch.utils.data.DataLoader(dataset=sub_minist1,
                                             batch_size=BATCH_SIZE, shuffle=True)
    # plt.imshow(next(iter(dataloader))[0][0][0])
    # plt.show()

    Discriminitor = nn.Sequential(
        nn.Linear(IMAGE_SIZE, 300),
        nn.LeakyReLU(0.2),
        nn.Linear(300, 150),
        nn.LeakyReLU(0.2),
        nn.Linear(150, 1),
        nn.Sigmoid()
    )
    Discriminitor = Discriminitor.to(DEVICE)

    latent_size = 64
    Generator = nn.Sequential(
        nn.Linear(latent_size, 150),
        nn.ReLU(True),
        nn.Linear(150, 300),
        nn.ReLU(True),
        nn.Linear(300, IMAGE_SIZE),
        nn.Tanh()  # squash the output into the range (-1, 1)
    )
    Generator = Generator.to(DEVICE)

    loss_fn = nn.BCELoss()
    d_optimizer = torch.optim.SGD(Discriminitor.parameters(), lr=0.002)
    g_optimizer = torch.optim.Adam(Generator.parameters(), lr=0.002)

    loader_len = len(dataloader)
    EPOCH = 30
    G_EPOCH = 1
    for epoch in range(EPOCH):
        for i, (images, _) in enumerate(dataloader):
            images = images.reshape(images.shape[0], IMAGE_SIZE).to(DEVICE)
            # noise = torch.randn(images.shape[0], IMAGE_SIZE)
            # images = noise + images
            label_real_pic = torch.ones(BATCH_SIZE, 1).to(DEVICE)
            label_fake_pic = torch.zeros(BATCH_SIZE, 1).to(DEVICE)

            if i % 100 == 0:
                plt.title('real')
                data = images.view(BATCH_SIZE, 28, 28).data.cpu().numpy()
                plt.imshow(data[0])
                plt.pause(1)

            res_real = Discriminitor(images)
            d_loss_real = loss_fn(res_real, label_real_pic)

            # generate fake images
            z = torch.randn(BATCH_SIZE, latent_size).to(DEVICE)
            fake_imgs = Generator(z)
            # detach so this update does not backpropagate into the generator
            res_fake = Discriminitor(fake_imgs.detach())
            d_loss_fake = loss_fn(res_fake, label_fake_pic)
            d_loss = d_loss_fake + d_loss_real

            d_optimizer.zero_grad()
            d_loss.backward()
            d_optimizer.step()

            for j in range(G_EPOCH):
                tt = torch.randn(BATCH_SIZE, latent_size).to(DEVICE)
                fake1 = Generator(tt)
                res_fake2 = Discriminitor(fake1)
                g_loss = loss_fn(res_fake2, label_real_pic)
                g_optimizer.zero_grad()
                g_loss.backward()
                g_optimizer.step()

            if i % 50 == 0:
                print("Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}"
                      .format(epoch, EPOCH, i, loader_len,
                              d_loss.item(), g_loss.item()))
                temp = torch.randn(BATCH_SIZE, latent_size).to(DEVICE)
                fake_temp = Generator(temp)
                ff = fake_temp.view(BATCH_SIZE, 28, 28).data.cpu().numpy()
                plt.title('generated')
                plt.imshow(ff[0])
                plt.pause(1)
```