EM算法的學習筆記

EM算法說起來很簡單,給定一個要估計的參數的初值,計算隱含變量分佈,再根據隱含變量的分佈更新要估計的參數值,之後在這兩個步驟之間進行迭代。但是其中的數學原理,GMM的推導等等其實並不簡單,難想更難算。這篇博客主要基於翻譯我看過的好材料,對其中做出些許的解釋。以下便從最簡單的例子說起

投硬幣的例子

出自http://www.cmi.ac.in/~madhavan/courses/datamining12/reading/em-tutorial.pdf

EM算法實現的是在數據不完全的情況下的參數預測。我們用一個投硬幣的例子來解釋EM算法的流程。假設我們有A,B兩枚硬幣,其正面朝上的概率分別爲 θA,θB ,這兩個參數即爲需要估計的參數。我們設計5組實驗,每次實驗投擲10次硬幣(但不知道用哪一枚硬幣進行這次實驗),投擲結束後會得到一個數組 x=(x1,x2,...,x5) ,來表示每組實驗有幾次硬幣是正面朝上的,因此 0xi10
如果我們知道每一組實驗中的 xi 是A硬幣投擲的結果還是B硬幣的結果,我們就很容易估計出 θA,θB ,只需要統計在所有的試驗中兩個硬幣分別有幾次是正面朝上的,除以他們各自投擲的總次數。數據不完全的意思在於,我們並不知道每一個數據是哪一個硬幣產生的。EM算法就是適用於這種問題。
雖然我們不知道每組實驗用的是哪一枚硬幣,但如果我們用某種方法猜測每組實驗是哪個硬幣投擲的,我們就可以將數據缺失的估計問題轉化成一個最大似然問題+完整參數估計問題
我們將逐步講解投硬幣的例子。假設5次試驗的結果如下(H是正面,T是反面):

試驗序號 結果
1 H T T T H H T H T H
2 H H H H T H H H H H
3 H T H H H H H T H H
4 H T H T T T H H T T
5 T H H H T H H H T H

首先,隨機選取初值 θA,θB ,比如 θA=0.6,θB=0.5 。EM算法的E步驟,是計算在當前的預估參數下,隱含變量(是A硬幣還是B硬幣)的每個值出現的概率。也就是給定 θA,θB 和觀測數據,計算這組數據出自A硬幣的概率和這組數據出自B硬幣的概率。對於第一組實驗,5正面5背面。

A硬幣得到這個結果的概率爲 0.65×0.45=0.000796
B硬幣得到這個結果的概率爲 0.55×0.55=0.000977

因此,第一組實驗是A硬幣得到的概率爲 0.000796/(0.000796+0.000977)=0.45 ,第一組實驗是B硬幣得到的概率爲 0.000977/(0.000796+0.000977)=0.55 。整個5組實驗的A,B投擲概率如下:

試驗序號 是A硬幣概率 是B硬幣概率
1 0.45 0.55
2 0.80 0.20
3 0.73 0.27
4 0.35 0.65
5 0.65 0.35

根據隱含變量的概率,可以計算出兩組訓練值的期望。依然以第一組實驗來舉例子:5正5反中,A硬幣投擲出了 0.45×5=2.2 個正面和 0.45×5=2.2 個反面;B硬幣投擲出了 0.55×5=2.8 個正面和 0.55×5=2.8 個反面。整個5組實驗的期望如下表:

試驗序號 A硬幣 B硬幣
1 2.2H, 2.2T 2.8H, 2.8T
2 7.2H, 0.8T 1.8H, 0.2T
3 5.9H, 1.5T 2.1H, 0.5T
4 1.4H, 2.1T 2.6H, 3.9T
5 4.5H, 1.9T 2.5H, 1.1T
SUM 21.3H, 8.6T 11.7H, 8.4T

通過計算期望,我們把一個有隱含變量的問題變化成了一個沒有隱含變量的問題,由上表的數據,估計 θA,θB 變得非常簡單。

θA=21.3/(21.3+8.6)=0.71

θB=11.7/(11.7+8.4)=0.58

下圖是原文中以上描述的示意圖
原文中的示意圖

當我們有了新的估計,便可以基於這個估計進行下一次迭代了。綜上所述,EM算法的步驟是:
1. E步驟:根據觀測值計算隱含變量的分佈情況
2. M步驟:根據隱含變量的分佈來估計新的模型參數

GMM的參數推導

總體思想來自PRML chapter 9.2

高斯混合模型是什麼這裏不再贅述。書上的公式相當簡潔,當然多元高斯函數對於均值和方差求導你可以不會,然而這是一個練習矩陣求導的好機會,畢竟好久沒有推過這麼複雜的公式了;再者,關於這部分的求導細節網絡上的資料很少。以下就分享一下我的推導過程。

根據極大似然的思想,在已知GMM模型產生的一系列數據點 x1,x2,...xn (假定它們是列向量)時,我們需要知道一組最佳的參數 μ1,μ2,...μk Σ1,Σ2,...Σk ,和 π1,π2,...πk ,在這種參數下生成這組數據點的可能性最大。求解GMM模型的參數,就是求以下的極大似然函數的極值點。

lnp(X|π,μ,Σ)=n=1Nlnk=1KπkN(xn|μk,Σk)(1.1)

其中,多元高斯函數的公式爲

N(xn|μk,Σk)=12πD/2|Σk|1/2exp(12(xnμk)TΣ1k(xnμk))(1.2)

我們的最終目的是對公式 (1.1) 進行對 μk,Σk,πk 求導,並求導數爲零時它們分別對應的值。在對這個終極公式求導之前,爲了描述的更清楚,我們先計算公式 (1.2) μk,Σk 的導數。

ddμkN(xn|μk,Σk)=12πD/2|Σk|1/2exp(12(xnμk)TΣ1k(xnμk))ddμk(12(xnμk)TΣ1k(xnμk))=N(xn|μk,Σk)ddμk(12(xnμk)TΣ1k(xnμk))=N(xn|μk,Σk)dd(xnμk)(12(xnμk)TΣ1k(xnμk))ddμk(xnμk)=N(xn|μk,Σk)(Σ1k(xnμk))(1)=N(xn|μk,Σk)Σ1k(xnμk)

這裏, 12(xnμk)TΣ1k(xnμk) 對於 xnμk 的求導原理如下(包括一個簡單的變量代換):

ddxxTAx=2Ax,A
公式來源是https://en.wikipedia.org/wiki/Matrix_calculus

再計算 N(xn|μk,Σk) 對協方差的求導

ddΣkN(xn|μk,Σk)=12πD/2{d|Σk|1/2dΣkexp(12(xnμk)TΣ1k(xnμk))+dexp(12(xnμk)TΣ1(xnμk))dΣk|Σ|1/2}=12πD/2{12|Σk|32|Σk|(Σ1k)Texp(12(xnμk)TΣ1k(xnμk))+12ΣTk(xnμk)(xnμk)TΣTk|Σk|1/2}=12πD/2|Σk|1/2exp(12(xnμk)TΣ1k(xnμk)){12(Σ1k)T+12ΣTk(xnμk)(xnμk)TΣTk}=N(xn|μk,Σk){12(Σ1k)T+12ΣTk(xnμk)(xnμk)TΣTk}

這裏求導的重點有兩個,對行列式的求導公式和對逆矩陣trace的求導公式
首先,對行列式的求導公式爲

d|X|dX=|X|(X1)T
這個公式同樣出自 https://en.wikipedia.org/wiki/Matrix_calculus
因此,
d|Σk|1/2dΣk=12|Σk|32|Σk|(Σ1k)T

接下來,對矩陣的trace的求導公式
ddXTr(AX1B)=XTATBTXT

這個公式出自 http://www2.imm.dtu.dk/pubdb/views/edoc_download.php/3274/pdf/imm3274.pdf
又因爲 12(xnμk)TΣ1(xnμk) 其實是一個實數,因此它等於它的trace,因此
d(12(xnμk)TΣ1(xnμk))dΣk=dtr(12(xnμk)TΣ1k(xnμk))dΣk=12ΣTk(xnμk)(xnμk)TΣTk

推完了一個高斯函數對其均值和方差的求導,我們開始進入主題:對極大似然函數對均值和方差求導

首先,對均值求導:

ddμklnp(X|π,μ,Σ)=n=1N1Kj=1πjN(xn|μj,Σj)ddμkπkN(xn|μk,Σk)=n=1N1Kj=1πjN(xn|μj,Σj)πkN(xn|μk,Σk)Σ1k(xnμk)=n=1NπkN(xn|μk,Σk)Kj=1πjN(xn|μj,Σj)Σ1k(xnμk)

爲了表達的方便,我們令 γ(znk)=πkN(xn|μk,Σk)Kj=1πjN(xn|μj,Σj) , Nk=Nn=1γ(znk) 則有:
ddμklnp(X|π,μ,Σ)=n=1Nγ(znk)Σ1k(xnμk)

我們讓這個式子等於0,即
n=1Nγ(znk)Σ1k(xnμk)=0

可以得到
μk=1Nkn=1Nγ(znk)xn

終於我們看到書上的結果了!觀察一下,這個結果其實很容易想象。 γ(znk) 的實際含義是第n個觀測數據分別屬於第1,2,…,k個高斯函數的概率。每一個高斯函數的均值,將會是觀測數據在用各個高斯函數上的概率加權後的計算。

現在我們再對方差求導。

ddΣklnp(X|π,μ,Σ)= ddΣklnp(X|π,μ,Σ)=n=1N1Kj=1
相關文章
相關標籤/搜索