EM算法說起來很簡單,給定一個要估計的參數的初值,計算隱含變量分佈,再根據隱含變量的分佈更新要估計的參數值,之後在這兩個步驟之間進行迭代。但是其中的數學原理,GMM的推導等等其實並不簡單,難想更難算。這篇博客主要基於翻譯我看過的好材料,對其中做出些許的解釋。以下便從最簡單的例子說起
投硬幣的例子
出自http://www.cmi.ac.in/~madhavan/courses/datamining12/reading/em-tutorial.pdf
EM算法實現的是在數據不完全的情況下的參數預測。我們用一個投硬幣的例子來解釋EM算法的流程。假設我們有A,B兩枚硬幣,其正面朝上的概率分別爲
θA,θB
,這兩個參數即爲需要估計的參數。我們設計5組實驗,每次實驗投擲10次硬幣(但不知道用哪一枚硬幣進行這次實驗),投擲結束後會得到一個數組
x=(x1,x2,...,x5)
,來表示每組實驗有幾次硬幣是正面朝上的,因此
0≤xi≤10
。
如果我們知道每一組實驗中的
xi
是A硬幣投擲的結果還是B硬幣的結果,我們就很容易估計出
θA,θB
,只需要統計在所有的試驗中兩個硬幣分別有幾次是正面朝上的,除以他們各自投擲的總次數。數據不完全的意思在於,我們並不知道每一個數據是哪一個硬幣產生的。EM算法就是適用於這種問題。
雖然我們不知道每組實驗用的是哪一枚硬幣,但如果我們用某種方法猜測每組實驗是哪個硬幣投擲的,我們就可以將數據缺失的估計問題轉化成一個最大似然問題+完整參數估計問題。
我們將逐步講解投硬幣的例子。假設5次試驗的結果如下(H是正面,T是反面):
試驗序號 |
結果 |
1 |
H T T T H H T H T H |
2 |
H H H H T H H H H H |
3 |
H T H H H H H T H H |
4 |
H T H T T T H H T T |
5 |
T H H H T H H H T H |
首先,隨機選取初值
θA,θB
,比如
θA=0.6,θB=0.5
。EM算法的E步驟,是計算在當前的預估參數下,隱含變量(是A硬幣還是B硬幣)的每個值出現的概率。也就是給定
θA,θB
和觀測數據,計算這組數據出自A硬幣的概率和這組數據出自B硬幣的概率。對於第一組實驗,5正面5背面。
A硬幣得到這個結果的概率爲
0.65×0.45=0.000796
B硬幣得到這個結果的概率爲
0.55×0.55=0.000977
因此,第一組實驗是A硬幣得到的概率爲
0.000796/(0.000796+0.000977)=0.45
,第一組實驗是B硬幣得到的概率爲
0.000977/(0.000796+0.000977)=0.55
。整個5組實驗的A,B投擲概率如下:
試驗序號 |
是A硬幣概率 |
是B硬幣概率 |
1 |
0.45 |
0.55 |
2 |
0.80 |
0.20 |
3 |
0.73 |
0.27 |
4 |
0.35 |
0.65 |
5 |
0.65 |
0.35 |
根據隱含變量的概率,可以計算出兩組訓練值的期望。依然以第一組實驗來舉例子:5正5反中,A硬幣投擲出了
0.45×5=2.2
個正面和
0.45×5=2.2
個反面;B硬幣投擲出了
0.55×5=2.8
個正面和
0.55×5=2.8
個反面。整個5組實驗的期望如下表:
試驗序號 |
A硬幣 |
B硬幣 |
1 |
2.2H, 2.2T |
2.8H, 2.8T |
2 |
7.2H, 0.8T |
1.8H, 0.2T |
3 |
5.9H, 1.5T |
2.1H, 0.5T |
4 |
1.4H, 2.1T |
2.6H, 3.9T |
5 |
4.5H, 1.9T |
2.5H, 1.1T |
SUM |
21.3H, 8.6T |
11.7H, 8.4T |
通過計算期望,我們把一個有隱含變量的問題變化成了一個沒有隱含變量的問題,由上表的數據,估計
θA,θB
變得非常簡單。
θA=21.3/(21.3+8.6)=0.71
θB=11.7/(11.7+8.4)=0.58
下圖是原文中以上描述的示意圖
當我們有了新的估計,便可以基於這個估計進行下一次迭代了。綜上所述,EM算法的步驟是:
1. E步驟:根據觀測值計算隱含變量的分佈情況
2. M步驟:根據隱含變量的分佈來估計新的模型參數
GMM的參數推導
總體思想來自PRML chapter 9.2
高斯混合模型是什麼這裏不再贅述。書上的公式相當簡潔,當然多元高斯函數對於均值和方差求導你可以不會,然而這是一個練習矩陣求導的好機會,畢竟好久沒有推過這麼複雜的公式了;再者,關於這部分的求導細節網絡上的資料很少。以下就分享一下我的推導過程。
根據極大似然的思想,在已知GMM模型產生的一系列數據點
x1,x2,...xn
(假定它們是列向量)時,我們需要知道一組最佳的參數
μ1,μ2,...μk
,
Σ1,Σ2,...Σk
,和
π1,π2,...πk
,在這種參數下生成這組數據點的可能性最大。求解GMM模型的參數,就是求以下的極大似然函數的極值點。
lnp(X|π,μ,Σ)=∑n=1Nln∑k=1KπkN(xn|μk,Σk)(1.1)
其中,多元高斯函數的公式爲
N(xn|μk,Σk)=12πD/2|Σk|1/2exp(−12(xn−μk)TΣ−1k(xn−μk))(1.2)
我們的最終目的是對公式
(1.1)
進行對
μk,Σk,πk
求導,並求導數爲零時它們分別對應的值。在對這個終極公式求導之前,爲了描述的更清楚,我們先計算公式
(1.2)
對
μk,Σk
的導數。
ddμkN(xn|μk,Σk)=12πD/2|Σk|1/2exp(−12(xn−μk)TΣ−1k(xn−μk))ddμk(−12(xn−μk)TΣ−1k(xn−μk))=N(xn|μk,Σk)ddμk(−12(xn−μk)TΣ−1k(xn−μk))=N(xn|μk,Σk)dd(xn−μk)(−12(xn−μk)TΣ−1k(xn−μk))ddμk(xn−μk)=N(xn|μk,Σk)(−Σ−1k(xn−μk))(−1)=N(xn|μk,Σk)Σ−1k(xn−μk)
這裏,
−12(xn−μk)TΣ−1k(xn−μk)
對於
xn−μk
的求導原理如下(包括一個簡單的變量代換):
ddxxTAx=2Ax,當A爲對稱矩陣
公式來源是https://en.wikipedia.org/wiki/Matrix_calculus
再計算
N(xn|μk,Σk)
對協方差的求導
ddΣkN(xn|μk,Σk)=12πD/2{d|Σk|−1/2dΣkexp(−12(xn−μk)TΣ−1k(xn−μk))+dexp(−12(xn−μk)TΣ−1(xn−μk))dΣk|Σ|−1/2}=12πD/2{−12|Σk|−32|Σk|(Σ−1k)Texp(−12(xn−μk)TΣ−1k(xn−μk))+12Σ−Tk(xn−μk)(xn−μk)TΣ−Tk|Σk|−1/2}=12πD/2|Σk|−1/2exp(−12(xn−μk)TΣ−1k(xn−μk)){−12(Σ−1k)T+12Σ−Tk(xn−μk)(xn−μk)TΣ−Tk}=N(xn|μk,Σk){−12(Σ−1k)T+12Σ−Tk(xn−μk)(xn−μk)TΣ−Tk}
這裏求導的重點有兩個,對行列式的求導公式和對逆矩陣trace的求導公式
首先,對行列式的求導公式爲
d|X|dX=|X|(X−1)T
這個公式同樣出自
https://en.wikipedia.org/wiki/Matrix_calculus
因此,
d|Σk|−1/2dΣk=−12|Σk|−32|Σk|(Σ−1k)T
接下來,對矩陣的trace的求導公式
ddXTr(AX−1B)=−X−TATBTX−T
這個公式出自
http://www2.imm.dtu.dk/pubdb/views/edoc_download.php/3274/pdf/imm3274.pdf
又因爲
12(xn−μk)TΣ−1(xn−μk)
其實是一個實數,因此它等於它的trace,因此
d(−12(xn−μk)TΣ−1(xn−μk))dΣk=dtr(−12(xn−μk)TΣ−1k(xn−μk))dΣk=12Σ−Tk(xn−μk)(xn−μk)TΣ−Tk
推完了一個高斯函數對其均值和方差的求導,我們開始進入主題:對極大似然函數對均值和方差求導
首先,對均值求導:
ddμklnp(X|π,μ,Σ)=∑n=1N1∑Kj=1πjN(xn|μj,Σj)ddμkπkN(xn|μk,Σk)=∑n=1N1∑Kj=1πjN(xn|μj,Σj)πkN(xn|μk,Σk)Σ−1k(xn−μk)=∑n=1NπkN(xn|μk,Σk)∑Kj=1πjN(xn|μj,Σj)Σ−1k(xn−μk)
爲了表達的方便,我們令
γ(znk)=πkN(xn|μk,Σk)∑Kj=1πjN(xn|μj,Σj)
,
Nk=∑Nn=1γ(znk)
則有:
ddμklnp(X|π,μ,Σ)=∑n=1Nγ(znk)Σ−1k(xn−μk)
我們讓這個式子等於0,即
∑n=1Nγ(znk)Σ−1k(xn−μk)=0
可以得到
μk=1Nk∑n=1Nγ(znk)xn
終於我們看到書上的結果了!觀察一下,這個結果其實很容易想象。
γ(znk)
的實際含義是第n個觀測數據分別屬於第1,2,…,k個高斯函數的概率。每一個高斯函數的均值,將會是觀測數據在用各個高斯函數上的概率加權後的計算。
現在我們再對方差求導。
ddΣklnp(X|π,μ,Σ)=∑
ddΣklnp(X|π,μ,Σ)=∑n=1N1∑Kj=1