變分自動編碼器大體概念已經理解了快一年多了,用Pytorch寫個模型也是手到擒來的事。但因爲其數學原理仍是沒有搞懂,在看到相關的變體時,總會被數學公式卡住,這對搞學術是致命的。下決心搞懂後,在此記錄下個人理解。函數
這篇文章提出一種擬合數據集分佈的方法,擬合分佈最多見的應用就是生成模型。該方法遵循極大似然策略,即對於數據集$X = \{x^{(i)}\}^N_{i=1}$,對生成模型$p_{\theta}(x)$(注意!這裏的$p_{\theta}(x)$既表明生成模型自己,又表明模型生成數據$x$的邊緣機率,下面相似)完成以下優化:優化
\begin{align}\displaystyle \max\limits_{\theta}L = \sum\limits_{i=1}^N\log p_{\theta}(x^{(i)})\end{align}編碼
可是,模型不可能憑空產生數據,必需要有輸入纔能有輸出,因此做者對數據的生成過程進行了假設,假設數據集$X = \{x^{(i)}\}^N_{i=1}$的生成過程由如下兩步組成:blog
一、經過某個先驗分佈$p_{\theta^*}(z)$,抽樣得到隱變量$z^{(i)}$ 。數學
二、再經過某個條件分佈$p_{\theta^*}(x|z)$,抽樣生成$x^{(i)}$。it
顯然,以上參數$\theta^*$的值、個數,甚至是計算推演的過程,都是未知的。爲了使優化得以進行,就要對參數的個數和計算流程進行約束,一般專家會根據經驗來給予特定數據集特定的計算過程。不失通常性,做者假設先驗分佈$p_{\theta^*}(z)$和似然函數$p_{\theta^*}(x|z)$來自於參數族$p_{\theta}(z)$和$p_{\theta}(x|z)$,而且它們的機率分佈函數(PDFs)幾乎到處可微。變量
儘管有如上假設,因爲$x^{(i)}$與$z$之間的關係未知,咱們不能直接使用梯度降低等方式對$p_{\theta}(x|z)$進行擬合。做者採用一種迂迴的方式,讓模型本身學會$z$與$x$之間的關係。做者使用自動編碼器的機制,與生成模型同時訓練一個後驗分佈模型$q_{\phi}(z|x)$,用來模擬其後驗分佈$p_{\theta}(z|x)$,稱爲編碼器,並稱$p_{\theta}(x|z)$爲解碼器。機率圖以下:原理
其中$\theta$表示生成模型$p_{\theta}(z)p_{\theta}(x|z)$的待優化參數,$\phi$表示用於估計$p_{\theta}(z|x)$的模型的參數。重構
有了$q_\phi(z|x)$做爲輔助後,針對每一數據集樣本,待優化式可轉換以下:方法
\begin{align} &\log p_{\theta}(x)\\ =& \text{E}_{q_{\phi}(z|x)}\log p_{\theta}(x)\\ =& \text{E}_{q_{\phi}(z|x)}\log \frac{p_{\theta}(x,z)}{p_{\theta}(z|x)}\\ =& \text{E}_{q_{\phi}(z|x)}\left[ \log p_{\theta}(x,z) - \log p_{\theta}(z|x)+\log q_{\phi}(z|x)-\log q_{\phi}(z|x) \right]\\ =& \text{E}_{q_{\phi}(z|x)}\left[ \log \frac{q_{\phi}(z|x)}{ p_{\theta}(z|x)} + \log p_{\theta}(x,z) -\log q_{\phi}(z|x) \right]\\ =& \text{KL} \left[ q_{\phi}(z|x) || p_{\theta}(z|x) \right]+ \text{E}_{q_{\phi}(z|x)}\left[ \log p_{\theta}(x,z) -\log q_{\phi}(z|x) \right] \\ =& \text{KL} \left[ q_{\phi}(z|x) || p_{\theta}(z|x) \right]+ \mathcal{L}(\theta,\phi) \\ \end{align}
容易看出,因爲$(8)$式第一項是相對熵非負,第二項即爲待優化式$(2)$的下界,稱之爲變分下界。所以,咱們只需對第二項進行優化,原式天然變大。將$\mathcal{L}(\theta,\phi)$進行變換以下:
\begin{align} \mathcal{L}(\theta,\phi) &= \text{E}_{q_{\phi}(z|x)}\left[ \log p_{\theta}(x,z) -\log q_{\phi}(z|x) + \log_{\theta}(z) - \log_{\theta}(z) \right] \\ &= \text{E}_{q_{\phi}(z|x)}\left[ -\log q_{\phi}(z|x) + \log_{\theta}(z) + \log p_{\theta}(x,z) - \log_{\theta}(z) \right] \\ &= \text{E}_{q_{\phi}(z|x)}\left[ -\log\frac{q_{\phi}(z|x)}{p_{\theta}(z)} + \log \frac{p_{\theta}(x,z)}{p_{\theta}(z)} \right] \\ &= -\text{KL}\left[q_{\phi}(z|x)||p_{\theta}(z)\right]+ \text{E}_{q_{\phi}(z|x)}\left[ \log p_{\theta}(x|z) \right] \\ \end{align}
對於以上兩項,咱們能夠把第一項理解爲正則項,也就是說擬合的後驗分佈應該和生成模型先驗分佈比較接近纔好;第二項理解爲重構損失,就是自編碼器的損失。同時對這兩項進行優化,就可使生成模型向目標靠近。
在以上優化式中包含隨機採樣過程,做者提出使用重參數化來創建可反向傳播的採樣。實際上就是給模型額外添加一個已知的隨機變量做爲輸入,從而使模型的抽樣過程可微。將做者於文中對以上推導的舉例——變分自動編碼器(VAE),拿來與$(12)$式作對比,推導的式子以及重參數化的意義就一目瞭然了。其中VAE的$p_{\theta}(z)$定義爲相互獨立的多維高斯分佈。