GAN半監督學習

概述

GAN的發明者Ian Goodfellow2016年在Open AI任職期間發表了這篇論文,其中提到了GAN用於半監督學習(semi supervised)的方法。稱爲SSGAN。 
作者給出了Theano+Lasagne實現。本文結合源碼對這種方法的推導和實現進行講解。1

半監督學習

考慮一個分類問題。 
如果訓練集中大部分樣本沒有標記類別,只有少部分樣本有標記。則需要用半監督學習(semi-supervised)方法來訓練一個分類器。

wiki上的這張圖很好地說明了無標記樣本在半監督學習中發揮作用: 
這裏寫圖片描述

如果只考慮有標記樣本(黑白點),純粹使用監督學習。則得到垂直的分類面。 
考慮了無標記樣本(灰色點)之後,我們對樣本的整體分佈有了進一步認識,能夠得到新的、更準確的分類面。

核心理念

在半監督學習中運用GAN的邏輯如下。

  • 無標記樣本沒有類別信息,無法訓練分類器;
  • 引入GAN後,其中生成器(Generator)可以從隨機信號生成僞樣本;
  • 相比之下,原有的無標記樣本擁有了人造類別:真。可以和僞樣本一起訓練分類器。 
    這裏寫圖片描述

舉個通俗的例子:就算沒人教認字,多練練分辨「是不是字」也對認字有好處。有粗糙的反饋,也比沒有反饋強。

原理

框架

GAN中的兩個核心模塊是生成器(Generator)和鑑別器(Discriminator)。這裏用分類器(Classifier)代替了鑑別器。 
這裏寫圖片描述

訓練集中包含有標籤樣本xlxl和無標籤樣本xuxu。 
生成器從隨機噪聲生成僞樣本IfIf。 
分類器接受樣本II,對於KK類分類問題,輸出K+1K+1維估計ll,再經過softmax函數得到概率pp:其前KK維對應原有KK個類,最後一維對應「僞樣本」類。 
pp的最大值位置對應爲估計標籤yy

softmax(xi)=exp(xi)jexp(xj)softmax(xi)=exp⁡(xi)∑jexp⁡(xj)

三種誤差

整個系統涉及三種誤差。

對於訓練集中的有標籤樣本,考察估計的標籤是否正確。即,計算分類爲相應的概率: 

Llabel=E[lnp(y|x)]Llabel=−E[ln⁡p(y|x)]

對於訓練集中的無標籤樣本,考察是否估計爲「真」。即,計算不估計爲K+1K+1類的概率: 

Lunlabel=E[ln(1p(K+1|x))]Lunlabel=−E[ln⁡(1−p(K+1|x))]

對於生成器產生的僞樣本,考察是否估計爲「僞」。即,計算估計爲K+1K+1類的概率: 

Lfake=E[lnp(K+1|x)]Lfake=−E[ln⁡p(K+1|x)]

推導

考慮softmax函數的一個特性: 

softmax(xic)=exp(xic)jexp(xjc)=exp(xi)/exp(c)jexp(xj)/exp(c)=softmax(xi)softmax(xi−c)=exp⁡(xi−c)∑jexp⁡(xj−c)=exp⁡(xi)/exp(c)∑jexp⁡(xj)/exp⁡(c)=softmax(xi)

即,如果輸入各維減去同一個數,softmax結果不變。 
於是,可以令 lllK+1l→l−lK+1 ,有 lK+1=0lK+1=0 p=softmax(l)p=softmax(l) 保持不變。

期望號略去不寫,利用explK+1=1exp⁡lK+1=1,後兩種代價變爲: 

Lunlabel=ln[1p(K+1|x)]=ln[Kj=1expljKj=1explj+explK+1]=ln[j=1Kexplj]+ln[1+j=1Kexplj]Lunlabel=−ln⁡[1−p(K+1|x)]=−ln⁡[∑j=1Kexp⁡lj∑j=1Kexp⁡lj+exp⁡lK+1]=−ln⁡[∑j=1Kexp⁡lj]+ln⁡[1+∑j=1Kexp⁡lj]

Lfake=ln[p(K+1|x)]=ln[1+j=1Kexplj]Lfake=−ln⁡[p(K+1|x)]=ln⁡[1+∑j=1Kexp⁡lj]

上述推導可以讓我們省去lK+1lK+1讓分類器仍然輸出K維的估計ll

對於第一個代價,由於分類器輸入必定來自前K類,所以可以直接使用ll的前K維: 

Llabel=ln[p(y|x,y<K+1)]=ln[explyKj=1explj]=ly+ln[j=1Kexplj]Llabel=−ln⁡[p(y|x,y<K+1)]=−ln⁡[exp⁡ly∑j=1Kexp⁡lj]=−ly+ln⁡[∑j=1Kexp⁡lj]

引入兩個函數,使得書寫更爲簡潔:

LSE(x)=ln[j=1expxj]LSE(x)=ln⁡[∑j=1exp⁡xj]

softplus(x)=ln(1+expx)softplus(x)=ln⁡(1+exp⁡x)

三個誤差: 

Llabel=ly+LSE(l)Llabel=−ly+LSE(l)

Lunlabel=LSE(l)+softplus(LSE(l))Lunlabel=−LSE(l)+softplus(LSE(l))

Lfake=softplus(LSE(l))Lfake=softplus(LSE(l))

優化目標

對於分類器來說,希望上述誤差儘量小。引入權重ww,得到分類器優化目標: 

w w,得到 分類器優化目標: 

LD=Llabel+w2(Lunlabe2(Lunlabel+Lfake)LD=Llabel+w2(Lunlabel+Lfake)

對於生成器來說,希望其輸出的僞樣本能夠騙過分類器。生成器優化目標與分類器的第三項相反: 

L l abel
相關文章
相關標籤/搜索