如何簡單易懂地理解變分推斷(variational inference)?

  正在學,把網上優質文章整理了一下。

  我們經常利用貝葉斯公式求posterior distribution P ( Z X ) P(Z | X)

P ( Z X ) = p ( X , Z ) z p ( X , Z = z ) d z P(Z | X)=\frac{p(X, Z)}{\int_{z} p(X, Z=z) d z}

  但posterior distribution P ( Z X ) P(Z | X) 求解用貝葉斯的方法是比較困難的,因爲我們需要去計算 z p ( X = x , Z = z ) d z \int_{z} p(X=x, Z=z) d z ,而 Z Z 通常會是一個高維的隨機變量,這個積分計算起來就非常困難。在貝葉斯統計中,所有的對於未知量的推斷(inference)問題可以看做是對後驗概率(posterior)的計算。因此提出了Variational Inference來計算posterior distribution

  那Variational Inference怎麼做的呢?其核心思想主要包括兩步:

  1. 假設一個分佈 q ( z ; λ ) q(z ; \lambda) (這個分佈是我們搞得定的,搞不定的就沒意義了)
  2. 通過改變分佈的參數 λ \lambda ,使 q ( z ; λ ) q(z ; \lambda) 靠近 p ( z x ) p(z|x)

  總結稱一句話就是,爲真實的後驗分佈引入了一個參數話的模型。 即:用一個簡單的分佈 q ( z ; λ ) q(z ; \lambda) 擬合複雜的分佈 p ( z x ) p(z|x)

  這種策略將計算 p ( z x ) p(z|x) 的問題轉化成優化問題了

λ = arg min λ divergence ( p ( z x ) , q ( z ; λ ) ) \lambda^{*}=\arg \min _{\lambda} \operatorname{divergence}(p(z | x), q(z ; \lambda))

  收斂後,就可以用 q ( z ; λ ) q(z;\lambda) 來代替 p ( z x ) p(z|x) 了。

KL散度

  而用一個分佈去擬合另一個分佈通常需要衡量這兩個分佈之間的相似性,通常採用KL散度,當然還有其他的一些方法,像JS散度這種。下面介紹KL散度

  機器學習中比較重要的一個概念—相對熵(relative entropy)。相對熵又被稱爲KL散度(Kullback–Leibler divergence) 或信息散度 (information divergence),是兩個概率分佈間差異的非對稱性度量 。在信息論中,相對熵等價於兩個概率分佈的信息熵的差值,若其中一個概率分佈爲真實分佈,另一個爲理論(擬合)分佈,則此時相對熵等於交叉熵與真實分佈的信息熵之差,表示使用理論分佈擬合真實分佈時產生的信息損耗 。其公式如下:

D K L ( p q ) = i = 1 N [ p ( x i ) log p ( x i ) p ( x i ) log q ( x i ) ] D_{K L}(p \| q)=\sum_{i=1}^{N}\left[p\left(x_{i}\right) \log p\left(x_{i}\right)-p\left(x_{i}\right) \log q\left(x_{i}\right)\right]

  合併之後表示爲:

D K L ( p q ) = i = 1 N p ( x i ) log ( p ( x i ) q ( x i ) ) D_{K L}(p \| q)=\sum_{i=1}^{N} p\left(x_{i}\right) \log \left(\frac{p\left(x_{i}\right)}{q\left(x_{i}\right)}\right)

  假設理論擬合出來的事件概率分佈 q ( x ) q(x) 跟真實的分佈 p ( x ) p(x) 一模一樣,即 p ( x ) = q ( x ) p(x)=q(x) ,那麼 p ( x i ) log q ( x i ) p\left(x_{i}\right) \log q\left(x_{i}\right) 就等於真實事件的信息熵,這一點顯而易見。在理論擬合出來的事件概率分佈跟真實的一模一樣的時候,相對熵等於0。而擬合出來不太一樣的時候,相對熵大於0。其證明如下:

i = 1 N p ( x i ) log q ( x i ) p ( x i ) i = 1 N p ( x i ) ( q ( x i ) p ( x i ) 1 ) = i = 1 N [ p ( x i ) q ( x i ) ] = 0 \sum_{i=1}^{N} p\left(x_{i}\right) \log \frac{q\left(x_{i}\right)}{p\left(x_{i}\right)} \leq \sum_{i=1}^{N} p\left(x_{i}\right)\left(\frac{q\left(x_{i}\right)}{p\left(x_{i}\right)}-1\right)=\sum_{i=1}^{N}\left[p\left(x_{i}\right)-q\left(x_{i}\right)\right]=0

  其中第一個不等式是由 l n ( x ) x 1 ln(x) \leq x -1 推導出來的,只在 p ( x i ) = q ( x i ) p(x_{i})=q(x_{i}) 時取到等號。

  這個性質很關鍵,因爲它正是深度學習梯度下降法需要的特性。假設神經網絡擬合完美了,那麼它就不再梯度下降,而不完美則因爲它大於0而繼續下降。

  但它有不好的地方,就是它是不對稱的。也就是用 P P 來擬合 Q Q 和用 Q Q 來擬合 P P 的相對熵居然不一樣,而他們的距離是一樣的。這也就是說,相對熵的大小並不跟距離有一一對應的關係。

求解

  中間引入了KL散度,但是我們本文的目的還是來求這個變分推理,不要走偏了。下面涉及一些公式等價轉換:

log P ( x ) = log P ( x , z ) log P ( z x ) = log P ( x , z ) Q ( z ; λ ) log P ( z x ) Q ( z ; λ ) \begin{aligned} \log P(x) &=\log P(x, z)-\log P(z | x) \\ &=\log \frac{P(x, z)}{Q(z ; \lambda)}-\log \frac{P(z | x)}{Q(z ; \lambda)} \end{aligned}

  等式兩邊同時對 Q ( z ) Q(z) 求期望,得:

E q ( z ; λ ) log P ( x ) = E q ( z ; λ ) log P ( x , z ) E q ( z ; λ ) log P ( z x ) log P ( x ) = E q ( z ; λ ) log p ( x , z ) q ( z ; λ ) E q ( z ; λ ) log p ( z x ) q ( z ; λ ) = K L ( q ( z ; λ ) p ( z x ) ) + E q ( z ; λ ) log p ( x , z ) q ( z ; λ ) log P ( x ) = K L ( q ( z ; λ ) p ( z x ) ) + E q ( z ; λ ) log p ( x , z ) q ( z ; λ ) \begin{aligned} \mathbb{E}_{q(z ; \lambda)} \log P(x) &=\mathbb{E}_{q(z ; \lambda)} \log P(x, z)-\mathbb{E}_{q(z ; \lambda)} \log P(z | x) \\ \log P(x) &=\mathbb{E}_{q(z ; \lambda)} \log \frac{p(x, z)}{q(z ; \lambda)}-\mathbb{E}_{q(z ; \lambda)} \log \frac{p(z | x)}{q(z ; \lambda)} \\ &=K L(q(z ; \lambda) \| p(z | x))+\mathbb{E}_{q(z ; \lambda)} \log \frac{p(x, z)}{q(z ; \lambda)} \\ \log P(x) &=K L(q(z ; \lambda) \| p(z | x))+\mathbb{E}_{q(z ; \lambda)} \log \frac{p(x, z)}{q(z ; \lambda)} \end{aligned}

  到這裏我們需要回顧一下我們的問題,我們的目標是使 q ( z ; λ ) q(z;\lambda) 靠近 p ( z x ) p(z|x) ,就是求解:

min λ K L ( q ( z ; λ ) p ( z x ) ) \min_\lambda KL(q(z;\lambda)||p(z|x))

  而由於 K L ( q ( z ; λ ) p ( z x ) ) KL(q(z;\lambda)||p(z|x)) 中包含 p ( z x ) p(z|x) ,這項非常難求。藉助上述公示的推導變形得到的結論:

log P ( x ) = K L ( q ( z ; λ ) p ( z x ) ) + E q ( z ; λ ) log p ( x , z ) q ( z ; λ ) \log P(x) =K L(q(z ; \lambda) \| p(z | x))+\mathbb{E}_{q(z ; \lambda)} \log \frac{p(x, z)}{q(z ; \lambda)}

  將 λ \lambda 看做變量時, log P ( x ) \text{log}P(x) 爲常量,所以, min λ K L ( q ( z ; λ ) p ( z x ) ) \min_\lambda KL(q(z;\lambda)||p(z|x)) 等價於 :

max λ E q ( z ; λ ) log p ( x , z ) q ( z ; λ ) \max_\lambda \mathbb E_{q(z;\lambda)}\text{log}\frac{p(x,z)}{q(z;\lambda)}

  現在,variational inference的目標變成:

max λ E q ( z ; λ ) [ log p ( x , z ) log q ( z ; λ ) ] \max_{\lambda}\mathbb E_{q(z;\lambda)}[\text{log}p(x,z)-\text{log}q(z;\lambda)]

   E q ( z ; λ ) [ log p ( x , z ) log q ( z ; λ ) ] \mathbb E_{q(z;\lambda)}[\text{log}p(x,z)-\text{log}q(z;\lambda)] 稱爲Evidence Lower Bound(ELBO) p ( x ) p(x) 一般被稱之爲evidence,又因爲 K L ( q p ) > = 0 KL(q||p)>=0 , 所以 p ( x ) > = E q ( z ; λ ) [ log p ( x , z ) log q ( z ; λ ) ] p(x)>=E_{q(z;\lambda)}[\text{log}p(x,z)-\text{log}q(z;\lambda)] , 這就是爲什麼被稱爲ELBO

ELBO

  ELBO公式表達爲:

E q ( z ; λ ) [ log p ( x , z ) log q ( z ; λ ) ] \mathbb E_{q(z;\lambda)}[\text{log}p(x,z)-\text{log}q(z;\lambda)]

  原公式可表示爲:

log P ( x ) = K L ( q ( z ; λ ) p ( z x ) ) + E q ( z ; λ ) log p ( x , z ) q ( z ; λ ) \log P(x) =K L(q(z ; \lambda) \| p(z | x))+\mathbb{E}_{q(z ; \lambda)} \log \frac{p(x, z)}{q(z ; \lambda)}

  引入ELBO表示爲:

log P ( x ) = K L ( q ( z ; λ ) p ( z x ) ) + E L B O \log P(x) =K L(q(z ; \lambda) \| p(z | x))+ELBO

  實際上EM算法(Expectation-Maximization)就是利用了這一特徵,它分爲交替進行的兩步:E step假設模型參數不變, q ( z ) = p ( z x ) q(z)=p(z|x) ,計算對數似然率,在M step再做ELBO相對於模型參數的優化。與變分法比較,EM算法假設了當模型參數固定時, p ( z x ) p(z|x) 是易計算的形式,而變分法並無這一限制,對於條件概率難於計算的情況,變分法仍然有效。

  那如何來求解上述公式呢?下面介紹平均場(mean-field)、蒙特卡洛、和黑盒變分推斷 (Black Box Variational Inference) 的方法。

平均場變分族(mean-field variational family)

  之前我們說我們選擇一族合適的近似概率分佈 q ( Z ; λ ) q(Z;\lambda) ,那麼實際問題中,我們可以選擇什麼形式的 q ( Z ; λ ) q(Z;\lambda) 呢?

  一個簡單而有效的變分族爲平均場變分族(mean-field variational family)。它假設了隱藏變量間是相互獨立的:

q ( Z ; λ ) = k = 1 K q k ( Z k ; λ k ) q(Z;\lambda) = \prod_{k=1}^{K}q_k(Z_k;\lambda_k)

  這個假設看起來似乎比較強,但實際應用範圍還是比較廣泛,我們可以將其延展爲將有實際相互關聯的隱藏變量分組,而化爲各組聯合分佈的乘積形式即可。

  利用ELBO和平均場假設,我們就可以利用coordinate ascent variational inference(簡稱CAVI)方法來處理:

  • 利用條件概率分佈的鏈式法則有

p ( z 1 : m , x 1 : n ) = p ( x 1 : n ) j = 1 m p ( z j z 1 : ( j 1 ) , x 1 : n ) p\left(z_{1: m}, x_{1: n}\right)=p\left(x_{1: n}\right) \prod_{j=1}^{m} p\left(z_{j} | z_{1:(j-1)}, x_{1: n}\right)

  • 變分分佈的期望爲

E [ log q ( z 1 : m ) ] = j = 1 m E j [ log q ( z j ) ] E\left[\log q\left(z_{1: m}\right)\right]=\sum_{j=1}^{m} E_{j}\left[\log q\left(z_{j}\right)\right]

  將其代入ELBO的定義得到:

E L B O = logp ( x 1 : n ) + j = 1 m E [ log p ( z j z 1 : ( j 1 ) , x 1 : n ) ] E j [ log q ( z j ) ] E L B O=\operatorname{logp}\left(x_{1: n}\right)+\sum_{j=1}^{m} E\left[\log p\left(z_{j} | z_{1:(j-1)}, x_{1: n}\right)\right]-E_{j}\left[\log q\left(z_{j}\right)\right]

  將其對 z k z_{k} 求導並令導數爲零有:

d E L B O d q ( z k ) = E k [ log p ( z k z k , x ) ] log q ( z k ) 1 = 0 \frac{d E L B O}{d q\left(z_{k}\right)}=E_{-k}\left[\log p\left(z_{k} | z_{-k}, x\right)\right]-\log q\left(z_{k}\right)-1=0

  由此得到coordinate ascent 的更新法則爲:

q ( z k ) exp E k [ log p ( z k , z k , x ) ] q^{*}\left(z_{k}\right) \propto \exp E_{-k}\left[\log p\left(z_{k}, z_{-k}, x\right)\right]

  我們可以利用這一法則不斷的固定其他的 z z 的座標來更新當前的座標對應的 z z 值,這與Gibbs Sampling過程類似,不過Gibbs Sampling是不斷的從條件概率中採樣,而CAVI算法中是不斷的用如下形式更新:

q ( z k ) exp E [ log ( conditional ) ] q^{*}\left(z_{k}\right) \propto \exp E[\log (\text {conditional})]

  其完整算法如下所示:

CAVI算法流程

MCMC

  MCMC方法是利用馬爾科夫鏈取樣來近似後驗概率,變分法是利用優化結果來近似後驗概率,那麼我們什麼時候用MCMC,什麼時候用變分法呢?

  首先,MCMC相較於變分法計算上消耗更大,但是它可以保證取得與目標分佈相同的樣本,而變分法沒有這個保證:它只能尋找到近似於目標分佈一個密度分佈,但同時變分法計算上更快,由於我們將其轉化爲了優化問題,所以可以利用諸如隨機優化(stochastic optimization)或分佈優化(distributed optimization)等方法快速的得到結果。所以當數據量較小時,我們可以用MCMC方法消耗更多的計算力但得到更精確的樣本。當數據量較大時,我們用變分法處理比較合適。

  另一方面,後驗概率的分佈形式也影響着我們的選擇。比如對於有多個峯值的混合模型,MCMC可能只注重其中的一個峯而不能很好的描述其他峯值,而變分法對於此類問題即使樣本量較小也可能優於MCMC方法。

黑盒變分推斷(BBVI)

  ELBO公式表達爲:

E q ( z ; λ ) [ log p ( x , z ) log q ( z ; λ ) ] \mathbb E_{q(z;\lambda)}[\text{log}p(x,z)-\text{log}q(z;\lambda)]

  對用參數 θ \theta 替代 λ \lambda ,並對其求導:

θ ELBO ( θ ) = θ E q ( log p ( x , z ) log q θ ( z ) ) \nabla_{\theta} \operatorname{ELBO}(\theta)=\nabla_{\theta} \mathbb{E}_{q}\left(\log p(x, z)-\log q_{\theta}(z)\right)

  直接展開計算如下:

θ q θ ( z ) ( log p ( x , z ) log q θ ( z ) ) d z = θ [ q θ ( z ) ( log p ( x , z ) log q θ ( z ) ) ] d z = θ ( q θ ( z ) log p ( x , z ) ) θ ( q θ ( z ) log q θ ( z ) ) d z = q θ ( z ) θ log p ( x , z ) q θ ( z ) θ log q θ ( z ) q θ ( z ) θ d z \begin{aligned} & \frac{\partial}{\partial \theta} \int q_{\theta}(z)\left(\log p(x, z)-\log q_{\theta}(z)\right) d z \\ =& \int \frac{\partial}{\partial \theta}\left[q_{\theta}(z)\left(\log p(x, z)-\log q_{\theta}(z)\right)\right] d z \\ =& \int \frac{\partial}{\partial \theta}\left(q_{\theta}(z) \log p(x, z)\right)-\frac{\partial}{\partial \theta}\left(q_{\theta}(z) \log q_{\theta}(z)\right) d z \\ =& \int \frac{\partial q_{\theta}(z)}{\partial \theta} \log p(x, z)-\frac{\partial q_{\theta}(z)}{\partial \theta} \log q_{\theta}(z)-\frac{\partial q_{\theta}(z)}{\partial \theta} d z \end{aligned}

  由於:

q θ ( z ) θ d z = θ q θ ( z ) d z = θ 1 = 0 \int \frac{\partial q_{\theta}(z)}{\partial \theta} d z=\frac{\partial}{\partial \theta} \int q_{\theta}(z) d z=\frac{\partial}{\partial \theta} 1=0

  因此:

θ ELBO ( θ ) = q θ ( z ) θ ( log p ( x , z ) log q θ ( z ) ) d z = q θ ( z ) log q θ ( z ) θ ( log p ( x , z ) log q θ ( z ) ) d z = q θ ( z ) θ log q θ ( z ) ( log p ( x , z ) log q θ ( z ) ) d z = E q [ θ log q θ ( z ) ( log p ( x , z ) log q θ ( z ) ) ] \begin{aligned} \nabla_{\theta} \operatorname{ELBO}(\theta) &=\int \frac{\partial q_{\theta}(z)}{\partial \theta}\left(\log p(x, z)-\log q_{\theta}(z)\right) d z \\ &=\int q_{\theta}(z) \frac{\partial \log q_{\theta}(z)}{\partial \theta}\left(\log p(x, z)-\log q_{\theta}(z)\right) d z \\ &=\int q_{\theta}(z) \nabla_{\theta} \log q_{\theta}(z)\left(\log p(x, z)-\log q_{\theta}(z)\right) d z \\ &=\mathbb{E}_{q}\left[\nabla_{\theta} \log q_{\theta}(z)\left(\log p(x, z)-\log q_{\theta}(z)\right)\right] \end{aligned}

相關文章
相關標籤/搜索