用MXNet實現mnist的生成對抗網絡(GAN)

用MXNet實現mnist的生成對抗網絡(GAN)

生成式對抗網絡(Generative Adversarial Network,簡稱GAN)由一個生成網絡與一個判別網絡組成。生成網絡從潛在空間(latent space)中隨機採樣做爲輸入,其輸出結果須要儘可能模仿訓練集中的真實樣本。判別網絡的輸入則爲真實樣本或生成網絡的輸出,其目的是將生成網絡的輸出從真實樣本中儘量分辨出來。而生成網絡則要儘量地欺騙判別網絡。兩個網絡相互對抗、不斷調整參數,最終目的是使判別網絡沒法判斷生成網絡的輸出結果是否真實。從數據的分佈來看就是使得生成的數據分佈\(P_z(z)\)與原來的數據\(P_{data}(x)\)十分接近,理想的狀況下爲\(P_z(z)=P_{data}(x)\)。本文給出了GAN的Loss函數、說明GAN的訓練原理,再結合最簡單的例子mnist,用MXNet來實現GAN。html

GAN的基本概念

在同樣樣本中加入一些精心編制的噪聲,會使得原來的分類器失效。圖1是一個廣爲流傳的示例,左邊的分類器獲得的是熊貓而右邊被分類爲了長臂猿。python

wrong

圖1 誤分類的示例

爲何會有這樣的結果?圖像分類器本質上是多維空間中的決策邊界,當訓練的樣本不足時,可能會使得分類器過擬合。當向原樣本中加入一些L2範數很小的噪聲時,人類的視覺是沒法分別這些細微的差異,因此依然會認爲和原樣本的分類沒什麼區別。但對過擬合的分類器來講,輸入樣本的小誤差可能使得最後的決策點越過了原來的決策邊界,進入到其它分類中了。這就致使了錯誤的分類。git

對於生成網絡設爲G,\(G(Z)\)爲生成的對抗樣本,理想條件下\(G(z)\)隨機生成的樣本分佈與真實樣本分佈是同樣。對於判別網絡設爲D,\(D(x)\)爲判別樣本是真實的機率,理想條件下對真實樣本有\(G(x)=1\),對生成樣本有\(D(G(z))=0\)。爲了達到效果,設計瞭如圖2所示的網絡結構:github

net

圖2 GAN的網絡結構

Loss函數以下:算法

\[ V(G,D)=E_{x-p_{data}(x)}[\log(D(x))] + E_{z-p_{z}(z)}[1-\log(D(G(z)))] \tag{1.1} \]apache

這個Loss函數的優化方法與EM算法的思想是類似的:在G是固定的狀況下,判別網絡D的精確率越高,那麼V就越大;在D固定的條件下,生成網絡G的生成的樣本越像實際樣本,那麼V就越小。全部V(G,D)進行了極小極大化博弈:bash

\[ \min_G \max_D V(G,D)=E_{x-p_{data}(x)}[\log(D(x))] + E_{z-p_{z}(z)}[1-\log(D(G(z)))] \tag{1.2} \]網絡

實現mnist的GAN

MXNet的源碼給出了mnsit的GAN實現(見dcgan.py),可是沒有給出詳細的說明,我在這裏詳細解釋下,源文件在裝了相關的python包以後是能正確運行的。DCGAN是指Deep Convolution Generative Adversarial Netword(深度卷積生成式對抗網格)。app

mnist的網絡相對來講比較簡單,如圖所示:ide

D_G

圖3 D是判別式網絡,G是生成式網絡,能夠看到兩個網絡輸出的數據大體成反向對稱

生成網絡G的結構與判別網絡D的結果是反向對稱的(雖然兩個網絡的開頭或者結尾有所不一樣,但這是爲了與結果相對應),這裏有一個很重要但被不少文章忽略的假設:判別網絡從潛在空間(latent space)是可逆的。不是說從最後的結果是可逆的,但從原始圖片映射到潛在空間這個過程(好比說從全鏈接層的n(n通常比較大)維向量)是可逆的,這裏說的可逆不是嚴格意義上的反函數,而是從視覺判別結果上區別不大,好比說在G與D理想的狀況下數字9經過判別網絡獲得一個100維的向量,再將這個100維向量經過生成網絡G獲得一張圖片,這張圖片在人類看來也是9。

代碼實現以下:

def make_dcgan_sym(ngf, ndf, nc, no_bias=True, fix_gamma=True, eps=1e-5 + 1e-12):
    BatchNorm = mx.sym.BatchNorm
    # 生成網絡G
    # 輸入生成網絡G的變量,這個是潛在空間
    rand = mx.sym.Variable('rand')

    g1 = mx.sym.Deconvolution(rand, name='g1', kernel=(4,4), num_filter=ngf*8, no_bias=no_bias)
    gbn1 = BatchNorm(g1, name='gbn1', fix_gamma=fix_gamma, eps=eps)
    gact1 = mx.sym.Activation(gbn1, name='gact1', act_type='relu')

    g2 = mx.sym.Deconvolution(gact1, name='g2', kernel=(4,4), stride=(2,2), pad=(1,1), num_filter=ngf*4, no_bias=no_bias)
    gbn2 = BatchNorm(g2, name='gbn2', fix_gamma=fix_gamma, eps=eps)
    gact2 = mx.sym.Activation(gbn2, name='gact2', act_type='relu')

    g3 = mx.sym.Deconvolution(gact2, name='g3', kernel=(4,4), stride=(2,2), pad=(1,1), num_filter=ngf*2, no_bias=no_bias)
    gbn3 = BatchNorm(g3, name='gbn3', fix_gamma=fix_gamma, eps=eps)
    gact3 = mx.sym.Activation(gbn3, name='gact3', act_type='relu')

    g4 = mx.sym.Deconvolution(gact3, name='g4', kernel=(4,4), stride=(2,2), pad=(1,1), num_filter=ngf, no_bias=no_bias)
    gbn4 = BatchNorm(g4, name='gbn4', fix_gamma=fix_gamma, eps=eps)
    gact4 = mx.sym.Activation(gbn4, name='gact4', act_type='relu')

    g5 = mx.sym.Deconvolution(gact4, name='g5', kernel=(4,4), stride=(2,2), pad=(1,1), num_filter=nc, no_bias=no_bias)
    # 生成網絡G最後獲得一張相片
    gout = mx.sym.Activation(g5, name='gact5', act_type='tanh')

    # 判別網絡D,這裏裏的結構與通常的分類網絡區別不大
    data = mx.sym.Variable('data')
    label = mx.sym.Variable('label')

    d1 = mx.sym.Convolution(data, name='d1', kernel=(4,4), stride=(2,2), pad=(1,1), num_filter=ndf, no_bias=no_bias)
    dact1 = mx.sym.LeakyReLU(d1, name='dact1', act_type='leaky', slope=0.2)

    d2 = mx.sym.Convolution(dact1, name='d2', kernel=(4,4), stride=(2,2), pad=(1,1), num_filter=ndf*2, no_bias=no_bias)
    dbn2 = BatchNorm(d2, name='dbn2', fix_gamma=fix_gamma, eps=eps)
    dact2 = mx.sym.LeakyReLU(dbn2, name='dact2', act_type='leaky', slope=0.2)

    d3 = mx.sym.Convolution(dact2, name='d3', kernel=(4,4), stride=(2,2), pad=(1,1), num_filter=ndf*4, no_bias=no_bias)
    dbn3 = BatchNorm(d3, name='dbn3', fix_gamma=fix_gamma, eps=eps)
    dact3 = mx.sym.LeakyReLU(dbn3, name='dact3', act_type='leaky', slope=0.2)

    d4 = mx.sym.Convolution(dact3, name='d4', kernel=(4,4), stride=(2,2), pad=(1,1), num_filter=ndf*8, no_bias=no_bias)
    dbn4 = BatchNorm(d4, name='dbn4', fix_gamma=fix_gamma, eps=eps)
    dact4 = mx.sym.LeakyReLU(dbn4, name='dact4', act_type='leaky', slope=0.2)

    d5 = mx.sym.Convolution(dact4, name='d5', kernel=(4,4), num_filter=1, no_bias=no_bias)
    d5 = mx.sym.Flatten(d5)
    # 用邏輯迴歸計算最後的loss
    dloss = mx.sym.LogisticRegressionOutput(data=d5, label=label, name='dloss')
    # 返回這G與D這兩個網絡
    return gout, dloss

在訓練的過程當中,全部的原樣本的label爲1,生成網絡G生成的樣本的label爲0,用這樣來區別原樣本與生成的對抗樣本。生成網絡輸入的潛在空間樣本是100維的,訓練過程以下:

  • 用生成網絡G生成對抗樣本gout
  • 對抗樣本的label設爲0,由於要先用這個訓練判別網絡D
  • 用gout來訓練判別網絡D,獲得梯度,但不更新
  • 對原樣本的label設爲1,再用之來訓練判別網絡D
  • 獲得梯度後合入gout獲得的梯度,更新D的參數
  • 下面的過程是爲了獲得生成網絡G的loss
    • 設gout的label爲1,由於生成網絡G的目標就是要生成label爲1的樣本,因此訓練G的label爲1。反之,若是訓練D,爲了區別原樣本與生成樣本因此label爲0。
    • 用判別網絡D來得輸入的梯度dgout,這個梯度就是生成網絡G的loss。
  • 用這個loss反向傳播生成網絡G,並更新參數。

這裏面的關鍵就是用判別網絡D來獲得生成網絡G的loss,之因此能夠這樣,是由於這兩個網絡是可逆的。訓練的代碼以下:

if __name__ == '__main__':
    logging.basicConfig(level=logging.DEBUG)

    # =============setting============
    dataset = 'mnist'
    imgnet_path = './train.rec'
    ndf = 64
    ngf = 64
    nc = 3
    batch_size = 64
    Z = 100
    lr = 0.0002
    beta1 = 0.5
    ctx = mx.gpu(0)
    check_point = False

    symG, symD = make_dcgan_sym(ngf, ndf, nc)
    #mx.viz.plot_network(symG, shape={'rand': (batch_size, 100, 1, 1)}).view()
    #mx.viz.plot_network(symD, shape={'data': (batch_size, nc, 64, 64)}).view()

    # ==============data==============
    if dataset == 'mnist':
        X_train, X_test = get_mnist()
        train_iter = mx.io.NDArrayIter(X_train, batch_size=batch_size)
    elif dataset == 'imagenet':
        train_iter = ImagenetIter(imgnet_path, batch_size, (3, 64, 64))
    rand_iter = RandIter(batch_size, Z)
    label = mx.nd.zeros((batch_size,), ctx=ctx)

    # =============module G=============
    modG = mx.mod.Module(symbol=symG, data_names=('rand',), label_names=None, context=ctx)
    modG.bind(data_shapes=rand_iter.provide_data)
    modG.init_params(initializer=mx.init.Normal(0.02))
    modG.init_optimizer(
        optimizer='adam',
        optimizer_params={
            'learning_rate': lr,
            'wd': 0.,
            'beta1': beta1,
        })
    mods = [modG]

    # =============module D=============
    modD = mx.mod.Module(symbol=symD, data_names=('data',), label_names=('label',), context=ctx)
    modD.bind(data_shapes=train_iter.provide_data,
              label_shapes=[('label', (batch_size,))],
              inputs_need_grad=True)
    modD.init_params(initializer=mx.init.Normal(0.02))
    modD.init_optimizer(
        optimizer='adam',
        optimizer_params={
            'learning_rate': lr,
            'wd': 0.,
            'beta1': beta1,
        })
    mods.append(modD)


    # ============printing==============
    def norm_stat(d):
        return mx.nd.norm(d)/np.sqrt(d.size)
    mon = mx.mon.Monitor(10, norm_stat, pattern=".*output|d1_backward_data", sort=True)
    mon = None
    if mon is not None:
        for mod in mods:
            pass

    def facc(label, pred):
        pred = pred.ravel()
        label = label.ravel()
        return ((pred > 0.5) == label).mean()

    def fentropy(label, pred):
        pred = pred.ravel()
        label = label.ravel()
        return -(label*np.log(pred+1e-12) + (1.-label)*np.log(1.-pred+1e-12)).mean()

    mG = mx.metric.CustomMetric(fentropy)
    mD = mx.metric.CustomMetric(fentropy)
    mACC = mx.metric.CustomMetric(facc)

    print('Training...')
    stamp =  datetime.now().strftime('%Y_%m_%d-%H_%M')

    # =============train===============
    for epoch in range(100):
        train_iter.reset()
        for t, batch in enumerate(train_iter):
            rbatch = rand_iter.next()

            if mon is not None:
                mon.tic()

            # 首先生成對抗樣本
            modG.forward(rbatch, is_train=True)
            outG = modG.get_outputs()

            # update discriminator on fake
            # 這裏的負樣本label爲0,正樣本label爲1,不像廣泛的mnist同樣。那麼modG就想生成樣本label爲1的,modD要將modG生成的數據斷定爲0
            # train_iter(真實樣本)中的數據斷定爲1。
            label[:] = 0
            modD.forward(mx.io.DataBatch(outG, [label]), is_train=True)
            modD.backward()
            #modD.update()
            # 先Copy獲得的對抗樣本的梯度,要注意是複製不是引用。
            gradD = [[grad.copyto(grad.context) for grad in grads] for grads in modD._exec_group.grad_arrays]

            modD.update_metric(mD, [label])
            modD.update_metric(mACC, [label])

            # update discriminator on real
            # 對真實樣本的數據訓練
            label[:] = 1
            batch.label = [label]
            modD.forward(batch, is_train=True)
            modD.backward()
            # 對抗樣本與真實樣本的梯度合到一塊兒建行梯度更新
            for gradsr, gradsf in zip(modD._exec_group.grad_arrays, gradD):
                for gradr, gradf in zip(gradsr, gradsf):
                    gradr += gradf
            modD.update()

            modD.update_metric(mD, [label])
            modD.update_metric(mACC, [label])

            # update generator
            # 更新modG的參數,這裏要注意的是,modG想要生成的樣本label是1的,因此在modD中用了這個label,就是想生成的樣本向label=1靠近。
            # 前向和向後生成輸入數據的梯度diffD
            label[:] = 1
            modD.forward(mx.io.DataBatch(outG, [label]), is_train=True)
            modD.backward()
            diffD = modD.get_input_grads()
            # diffD就是modG的loss產生的梯度,用它來向後傳播並更新參數。
            modG.backward(diffD)
            modG.update()

            mG.update([label], modD.get_outputs())


            if mon is not None:
                mon.toc_print()

            t += 1
            if t % 10 == 0:
                print('epoch:', epoch, 'iter:', t, 'metric:', mACC.get(), mG.get(), mD.get())
                mACC.reset()
                mG.reset()
                mD.reset()

                visual('gout', outG[0].asnumpy())
                diff = diffD[0].asnumpy()
                diff = (diff - diff.mean())/diff.std()
                visual('diff', diff)
                visual('data', batch.data[0].asnumpy())

        if check_point:
            print('Saving...')
            modG.save_params('%s_G_%s-%04d.params'%(dataset, stamp, epoch))
            modD.save_params('%s_D_%s-%04d.params'%(dataset, stamp, epoch))

訓練的結果部分結果以下,gout是生成的樣本,data是原樣本,diff是它們的差。能夠從後面生成的gout中看到,結果缺乏一些數字,好比二、3等,這是由於咱們沒有對各個數字的潛在空間進行生成樣本而是用統一的空間,這個統一的空間中對應的數字可能沒有二、3等或者說它們點的比例相對來講比較小,樣例用到的空間只是保證生成樣本是數字,但並不保證每一個數字都會有,若是我保證生成每一個數字的樣本,那麼得從新設計程序,但原理和例程相差不大。

data_gout_diff

圖4 輸出的圖像結果:data是原始數據,gout是G生成的對搞樣本,diff是二者的差。

過程打印的輸出以下:

epoch: 99 iter: 930 metric: ('facc', 1.0) ('fentropy', 8.3449375152587884) ('fentropy', 0.00077932097192388026)

【防止爬蟲轉載而致使的格式問題——連接】:
http://www.cnblogs.com/heguanyou/p/7642608.html

相關文章
相關標籤/搜索