0902-用GAN生成動漫頭像

0902-用GAN生成動漫頭像

 

 

pytorch完整教程目錄:python

1、概述

本節將經過 GAN 實現一個生成動漫人物頭像的例子。git

在日本的技術博客網站上有個博主,利用 DCGAN 從 20 萬張動漫頭像中學習,最終可以利用程序自動生成動漫頭像。源程序是利用 Chainer 框架實現的,在這裏咱們將嘗試利用 Pytorch 實現。github

原始的圖片是從網站中採集的,並利用 OpenCV 截取頭像,處理起來很是麻煩。所以咱們在這裏經過之乎用戶 何之源 爬取並通過處理的 5 萬張圖片,想要圖片的百度網盤連接的能夠加我微信:chenyoudea。須要注意的是,這裏圖片的分辨率是 3×96×96,而不是論文中的 3×64×64,所以須要相應地調整網絡結構,使生成圖像的尺寸爲 96。shell

2、代碼結構

下面咱們首先來看下咱們將來的一個代碼結構。微信

checkpoints/  # 無代碼,用來保存模型
imgs/  # 無代碼,用來保存生成的圖片
data/  # 無代碼,用來保存訓練所須要的圖片
main.py  # 訓練和生成
model.py  # 模型定義
visualize.py  # 可視化工具 visdom 的開發
requirement.txt  # 程序中用到的第三方庫
README.MD  # 說明
3、model.py

model.py 主要是用來定義生成器和判別器的。網絡

3.1 生成器

#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Coding by 
# Filename:model.py
# Toolby: PyCharm
from torch import nn


class NetG(nn.Module):
    """
    生成器定義
    """

    def __init__(self, opt):
        super(NetG, self).__init__()
        ngf = opt.ngf  # 生成器 feature map 數
        self.main = nn.Sequential(
            # 輸入是 nz 維度的噪聲,能夠認識它是一個 nz*1*1 的 feature map
            # H_{out} = (H_{in}-1)*stride - 2*padding + kernel_size
            # 如下面一行代碼的ConvTranspose2d舉例(初始 H_{in}=1):H_{out} = (1-1)*1-2*0+4 = 4
            nn.ConvTranspose2d(opt.nz, ngf * 8, (4, 4), (1, 1), (0, 0), bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # 上一步的輸出形狀:(ngf*8)*4*4,其中(ngf*8)是輸出通道數,4 爲 H_{out} 是經過上述公式計算出來的

            # 如下面一行代碼的ConvTranspose2d舉例(初始 H_{in}=4):H_{out} = (4-1)*2-2*1+4 =8
            nn.ConvTranspose2d(ngf * 8, ngf * 4, (4, 4), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # 上一步的輸出形狀:(ngf*4)*8*8

            nn.ConvTranspose2d(ngf * 4, ngf * 2, (4, 4), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # 上一步的輸出形狀是:(ngf*2)*16*16

            nn.ConvTranspose2d(ngf * 2, ngf, (4, 4), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # 上一步的輸出形狀:(ngf)*32*32

            nn.ConvTranspose2d(ngf, 3, (5, 5), (3, 3), (1, 1), bias=False),
            nn.Tanh()
            # 輸出形狀:3*96*96
        )

    def forward(self, inp):
        return self.main(inp)

從上述生成器的代碼能夠看出生成器的構建比較簡單,直接用 nn.Sequential 把上卷積、激活等操做拼接起來就好了。這裏稍微注意下 ConvTranspose2d 的使用,當 kernel size 爲 四、stride 爲 二、padding 爲 1 時,根據公式 \(H_{out} = (H_{in}-1)*stride - 2*padding + kernel_size\),輸出尺寸恰好變成輸入的兩倍。app

最後一層咱們使用了 tanh 把輸出圖片的像素歸一化至 -1~1,若是但願歸一化到 0~1,可使用 sigimoid 方法。框架

3.2 判別器

class NetD(nn.Module):
    """
    判別器定義
    """

    def __init__(self, opt):
        super(NetD, self).__init__()
        ndf = opt.ndf
        self.main = nn.Sequential(
            # 輸入 3*96*96
            nn.Conv2d(3, ndf, (5, 5), (3, 3), (1, 1), bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 輸出 (ndf)*32*32

            nn.Conv2d(ndf, ndf * 2, (4, 4), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # 輸出 (ndf*2)*16*16

            nn.Conv2d(ndf * 2, ndf * 4, (4, 4), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # 輸出 (ndf*4)*8*8

            nn.Conv2d(ndf * 4, ndf * 8, (4, 4), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # 輸出 (ndf*8)*4*4

            nn.Conv2d(ndf * 8, 1, (4, 4), (1, 1), (0, 0), bias=False),
            nn.Sigmoid()  # 輸出一個數:機率
        )

    def forward(self, inp):
        return self.main(inp).view(-1)

從上述代碼能夠看到判別器和生成器的網絡結構幾乎是對稱的,從卷積核大小到 padding、stride 等設置,幾乎如出一轍。dom

須要注意的是,生成器的激活函數用的是 ReLU,而判別器使用的是 LeakyReLU,二者其實沒有太大的區別,這種選擇更多的是經驗的總結。ide

判別器的最終輸出是一個 0~1 的數,表示這個樣本是真圖片的機率。

4、參數配置

在開始寫訓練函數前,咱們能夠先配置模型參數

#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Coding by 
# Datatime:2021/5/11 15:14
# Filename:config.py
# Toolby: PyCharm
class Config(object):
    data_path = 'data/'  # 數據集存放路徑
    num_workers = 4  # 多進程加載數據所用的進程數
    image_size = 96  # 圖片尺寸
    batch_size = 256
    max_epoch = 200
    lr1 = 2e-4  # 生成器的學習率
    lr2 = 2e-4  # 判別器的學習率
    beta1 = 0.5  # Adam 優化器的 beta1 參數
    use_gpu = False  # 是否使用 GPU
    nz = 100  # 噪聲維度
    ngf = 64  # 生成器的 feature map 數
    ndf = 64  # 判別器的 feature map 數

    save_path = 'imgs/'  # 生成圖片保存路徑

    vis = True  # 是否使用 visdom 可視化
    env = 'GAN'  # visdom 的 env
    plot_every = 20  # 每隔 20 個 batch,visdom 畫圖一次

    debug_file = '/tmp/debuggan'  # 存在該文件則進入 debug 模式
    d_every = 1  # 每 1 個 batch 訓練一次判別器
    g_every = 5  # 每 5 個 batch 訓練一次生成器
    decay_everty = 10  # 每 10 個 epoch 保存一次模型
    save_every = 10  # 每 10個epoch保存一次模型
    netd_path = 'checkpoints/netd_211.pth'  # 預訓練模型
    netg_path = 'checkpoints/netg_211.pth'

    # 測試時用的參數
    gen_img = 'result.png'
    # 從 512 張生成的圖片路徑中保存最好的 64 張
    gen_num = 64
    gen_search_num = 512
    gen_mean = 0  # 噪聲的均值
    gen_std = 1  # 噪聲的方差
    
opt = Config()

上述這些都只是模型的默認參數,還能夠利用 Fire 等工具經過命令行傳入,覆蓋默認值。

除此以外,還可使用 opt.atrr,還能夠利用 IDE/Python 提供的自動補全功能,十分方便。

上述的超參數大可能是照搬 DCGAN 論文的默認值,這些默認值都是坐着通過大量的實驗,發現這些參數可以更快地去訓練出一個不錯的模型。

5、數據處理

當咱們下載完數據以後,須要把全部圖片放在一文件夾,而後把文件夾移動到 data 目錄下(而且要確保 data 下沒有其餘的文件夾)。使用這種方法是爲了可以直接使用 pytorchvision 自帶的 ImageFolder 讀取圖片,而沒有必要本身寫一個 Dataset。

數據讀取和加載的代碼以下所示。

#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Coding by 
# Datatime:2021/5/12 09:43
# Filename:dataset.py
# Toolby: PyCharm
import torch as t
import torchvision as tv
from torch.utils.data import DataLoader

from config import opt

# 數據處理,輸出規模爲 -1~1
transforms = tv.transforms.Compose([
    tv.transforms.Scale(opt.image_size),
    tv.transforms.CenterCrop(opt.image_size),
    tv.transforms.ToTensor(),
    tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# 加載數據集
dataset = tv.datasets.ImageFolder(opt.data_path, transform=transforms)
dataloader = DataLoader(
    dataset,
    batch_size=opt.batch_size,
    shuffle=True,
    num_workers=opt.num_workers,
    drop_last=True
)

從上述代碼中能夠發現,用 ImageFolder 配合 DataLoader 加載圖片十分方便。

6、訓練

在訓練以前,咱們還須要定義幾個變量:模型、優化器、噪聲等。

#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Coding by 
# Datatime:2021/5/10 10:37
# Filename:main.py
# Toolby: PyCharm
import os
import ipdb
import tqdm
import fire
import torch as t
import torchvision as tv
from visualize import Visualizer
from torch.autograd import Variable
from torchnet.meter import AverageValueMeter

from config import opt
from dataset import dataloader
from model import NetD, NetG



def train(**kwargs):
    # 定義模型
    netd = NetD()
    netg = NetG()
    # 定義網絡
    map_location = lambda storage, loc: storage
    if opt.netd_path:
        netd.load_state_dict(t.load(opt.netd_path, map_location=map_location))
    if opt.netg_path:
        netg.load_state_dict(t.load(opt.netg_path, map_location=map_location))

    # 定義優化器和損失
    optimizer_g = t.optim.Adam(netg.parameters(), opt.lr1, betas=(opt.beta1, 0.999))
    optimizer_d = t.optim.Adam(netd.parameters(), opt.lr2, betas=(opt.beta1, 0.999))
    criterion = t.nn.BCELoss()

    # 真圖片 label 爲 1,假圖片 label 爲 0,noises 爲生成網絡的輸入噪聲
    true_labels = Variable(t.ones(opt.batch_size))
    fake_labels = Variable(t.zeros(opt.batch_size))
    fix_noises = Variable(t.randn(opt.batch_size, opt.nz, 1, 1))
    noises = vars(t.randn(opt.batch_size, opt.nz, 1, 1))

    # 若是使用 GPU 訓練,把數據轉移到 GPU 上
    if opt.use_gpu:
        netd.cuda()
        netg.cuda()
        criterion.cuda()
        true_labels, fake_labels = true_labels.cuda(), fake_labels.cuda()
        fix_noises, noises = fix_noises.cuda(), noises.cuda()

在加載預訓練模型的時候,最好指定 map_location。由於若是程序以前在 GPU 上運行,那麼模型就會被存成 torch.cuda.Tensor,這樣加載的時候會默認把數據加載到顯存上。若是運行該程序的計算機中沒有 GPU,則會報錯,所以指定 map_location 把 Tensor 默認加載到內存上,等有須要的時候再加載到顯存中。

下面開始訓練網絡,訓練的步驟以下所示:

  1. 訓練判別器:
    • 固定生成器
    • 對於真圖片,判別器的輸出機率值儘量接近 1
    • 對於生成器生成的圖片,判別器儘量輸出 0
  2. 訓練生成器
    • 固定判別器
    • 生成器生成圖片,儘量讓判別器輸出 1
  3. 返回第一步,循環交替訓練
epochs = range(opt.max_epoch)
    for epoch in iter(epochs):

        for ii, (img, _) in tqdm.tqdm(enumerate(dataloader)):
            real_img = Variable(img)
            if opt.use_gpu:
                real_img = real_img.cuda()

            # 訓練判別器
            if (ii + 1) % opt.d_every == 0:
                optimizer_d.zero_grad()
                # 儘量把真圖片判別爲 1
                output = netd(real_img)
                error_d_real = criterion(output, true_labels)
                error_d_real.backward()

                # 儘量把假圖片判別爲 0
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises).detach()  # 根據照片生成假圖片
                fake_ouput = netd(fake_img)
                error_d_fake = criterion(fake_ouput, fake_labels)
                error_d_fake.backward()
                optimizer_d.step()

            # 訓練生成器
            if (ii + 1) % opt.g_every == 0:
                optimizer_g.zero_grad()
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises)
                fake_output = netd(fake_img)
                # 儘量讓判別器把假圖片也判別爲 1
                error_g = criterion(fake_output, true_labels)
                error_g.backward()
                optimizer_g.step()

            # 可視化

            if opt.vis and ii % opt.plot_every == opt.plot_every - 1:
                # 定義可視化窗口
                vis = Visualizer(opt.env)

                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()
                global fix_fake_imgs
                fix_fake_imgs = netg(fix_noises)
                vis.images(fix_fake_imgs.detach().cpu().numpy()[:64] * 0.5 + 0.5, win='fixfake')
                vis.images(real_img.data.cpu().numpy()[:64] * 0.5 + 0.5, win='real')
                vis.plot('errord', errord_meter.value()[0])
                vis.plot('errorg', errorg_meter.value()[0])

        if (epoch + 1) % opt.save_every == 0:
            # 保存模型、圖片
            tv.utils.save_image(fix_fake_imgs.data[:64], '%s/%s.png' % (opt.save_path, epoch), normalize=True,
                                range=(-1, 1))
            t.save(netd.state_dict(), 'checkpoints/netd_%s.pth' % epoch)
            t.save(netg.state_dict(), 'checkpoints/netg_%s.pth' % epoch)
            errord_meter.reset()
            errorg_meter.reset()

在上述訓練代碼中,須要注意如下幾點:

  • 訓練生成器的時候,不須要調整判別器的參數;訓練判別器的時候,也不須要調整生成器的參數
  • 在訓練判別器的時候,須要對生成器生成的圖片用 detach 操做進行計算圖截斷,避免反向傳播把梯度傳到生成器中。由於在訓練判別器的時候咱們不須要訓練生成器,也就不須要生成器的梯度。
  • 在訓練分類器的時候,須要反向傳播兩次,一次是但願把真圖片判爲 1,一次是但願把假圖片判爲 0.也能夠把這個二者的數據放到一個 batch 中,進行一次前向傳播和一次反向傳播便可。可是人們發現,在一個 batch 中只包含真圖片或者只包含假圖片的作法最好。
  • 對於假圖片,在訓練判別器的時候,咱們但願它輸出爲 0;而在訓練生成器的時候,咱們但願它輸出爲 1.所以能夠看到一堆相互矛盾的代碼:error_d_fake = criterion(fake_output,fake_labels)error_g = criterion(fake_output, true_labels)。其實這也很好理解,判別器但願可以把假圖片判別爲 fake_label,而生成器但願能把它判別爲 true_label,判別器和生成器相互對抗提高。
  • 其中的 Visualize 模塊相似於上一章本身的寫的模塊,能夠直接複製粘貼源碼中的代碼。
7、隨機生成圖片

除了上述所示的代碼外,還提供了一個函數,能加載預訓練好的模型,而且利用噪聲隨機生成圖片。

@t.no_grad()
def generate():
    # 定義噪聲和網絡
    netg, netd = NetG(opt), NetD(opt)
    noises = t.randn(opt.gen_search_num, opt.nz, 1, 1).normal_(opt.gen_mean, opt.gen_std)
    noises = Variable(noises)

    # 加載預訓練的模型
    netd.load_state_dict(t.load(opt.netd_path))
    netg.load_state_dict(t.load(opt.netg_path))

    # 是否使用 GPU
    if opt.use_gpu:
        netd.cuda()
        netg.cuda()
        noises = noises.cuda()

    # 生成圖片,並計算圖片在判別器的分數
    fake_img = netg(noises)
    scores = netd(fake_img).data

    # 挑選最好的某幾張
    indexs = scores.topk(opt.gen_num)[1]
    result = []
    for ii in indexs:
        result.append(fake_img.data[ii])

    # 保存圖片
    tv.utils.save_image(t.stack(result), opt.gen_num, normalize=True, range=(-1, 1))
8、訓練模型並測試

完整的代碼能夠添加我微信:chenyoudea,其實上述代碼已經很完整了,或者去github https://github.com/chenyuntc/pytorch-book/tree/master/chapter07-AnimeGAN下載。

這裏假設你是擁有完整的代碼,那麼準備好數據後,能夠用下面的命令開始訓練:

python main.py train --gpu=True --vis=True --batch-size=256 --max-epoch=200

若是使用了 visdom,此時打開 http://localhost:8097 就能看到生成的圖像。

訓練完成後,咱們就能夠利用生成網絡隨機生成動漫頭像,輸入命令以下:

python main.py generate --gen-img='result.5w.png' --gen-search-num=15000

下圖是 10 個 epoch 的展現:
watermark,size_16,text_QDUxQ1RP5Y2a5a6i,color_FFFFFF,t_100,g_se,x_10,y_10,shadow_90,type_ZmFuZ3poZW5naGVpdGk=

相關文章
相關標籤/搜索