GMM算法的matlab程序

GMM算法的matlab程序

在「GMM算法的matlab程序(初步)」這篇文章中已經用matlab程序對iris數據庫進行簡單的實現,下面的程序最終的目的是求準確度。html

做者:凱魯嘎吉 - 博客園 http://www.cnblogs.com/kailugaji/git

1.採用iris數據庫

iris_data.txt

5.1    3.5    1.4    0.2
4.9    3    1.4    0.2
4.7    3.2    1.3    0.2
4.6    3.1    1.5    0.2
5    3.6    1.4    0.2
5.4    3.9    1.7    0.4
4.6    3.4    1.4    0.3
5    3.4    1.5    0.2
4.4    2.9    1.4    0.2
4.9    3.1    1.5    0.1
5.4    3.7    1.5    0.2
4.8    3.4    1.6    0.2
4.8    3    1.4    0.1
4.3    3    1.1    0.1
5.8    4    1.2    0.2
5.7    4.4    1.5    0.4
5.4    3.9    1.3    0.4
5.1    3.5    1.4    0.3
5.7    3.8    1.7    0.3
5.1    3.8    1.5    0.3
5.4    3.4    1.7    0.2
5.1    3.7    1.5    0.4
4.6    3.6    1    0.2
5.1    3.3    1.7    0.5
4.8    3.4    1.9    0.2
5    3    1.6    0.2
5    3.4    1.6    0.4
5.2    3.5    1.5    0.2
5.2    3.4    1.4    0.2
4.7    3.2    1.6    0.2
4.8    3.1    1.6    0.2
5.4    3.4    1.5    0.4
5.2    4.1    1.5    0.1
5.5    4.2    1.4    0.2
4.9    3.1    1.5    0.2
5    3.2    1.2    0.2
5.5    3.5    1.3    0.2
4.9    3.6    1.4    0.1
4.4    3    1.3    0.2
5.1    3.4    1.5    0.2
5    3.5    1.3    0.3
4.5    2.3    1.3    0.3
4.4    3.2    1.3    0.2
5    3.5    1.6    0.6
5.1    3.8    1.9    0.4
4.8    3    1.4    0.3
5.1    3.8    1.6    0.2
4.6    3.2    1.4    0.2
5.3    3.7    1.5    0.2
5    3.3    1.4    0.2
7    3.2    4.7    1.4
6.4    3.2    4.5    1.5
6.9    3.1    4.9    1.5
5.5    2.3    4    1.3
6.5    2.8    4.6    1.5
5.7    2.8    4.5    1.3
6.3    3.3    4.7    1.6
4.9    2.4    3.3    1
6.6    2.9    4.6    1.3
5.2    2.7    3.9    1.4
5    2    3.5    1
5.9    3    4.2    1.5
6    2.2    4    1
6.1    2.9    4.7    1.4
5.6    2.9    3.6    1.3
6.7    3.1    4.4    1.4
5.6    3    4.5    1.5
5.8    2.7    4.1    1
6.2    2.2    4.5    1.5
5.6    2.5    3.9    1.1
5.9    3.2    4.8    1.8
6.1    2.8    4    1.3
6.3    2.5    4.9    1.5
6.1    2.8    4.7    1.2
6.4    2.9    4.3    1.3
6.6    3    4.4    1.4
6.8    2.8    4.8    1.4
6.7    3    5    1.7
6    2.9    4.5    1.5
5.7    2.6    3.5    1
5.5    2.4    3.8    1.1
5.5    2.4    3.7    1
5.8    2.7    3.9    1.2
6    2.7    5.1    1.6
5.4    3    4.5    1.5
6    3.4    4.5    1.6
6.7    3.1    4.7    1.5
6.3    2.3    4.4    1.3
5.6    3    4.1    1.3
5.5    2.5    4    1.3
5.5    2.6    4.4    1.2
6.1    3    4.6    1.4
5.8    2.6    4    1.2
5    2.3    3.3    1
5.6    2.7    4.2    1.3
5.7    3    4.2    1.2
5.7    2.9    4.2    1.3
6.2    2.9    4.3    1.3
5.1    2.5    3    1.1
5.7    2.8    4.1    1.3
6.3    3.3    6    2.5
5.8    2.7    5.1    1.9
7.1    3    5.9    2.1
6.3    2.9    5.6    1.8
6.5    3    5.8    2.2
7.6    3    6.6    2.1
4.9    2.5    4.5    1.7
7.3    2.9    6.3    1.8
6.7    2.5    5.8    1.8
7.2    3.6    6.1    2.5
6.5    3.2    5.1    2
6.4    2.7    5.3    1.9
6.8    3    5.5    2.1
5.7    2.5    5    2
5.8    2.8    5.1    2.4
6.4    3.2    5.3    2.3
6.5    3    5.5    1.8
7.7    3.8    6.7    2.2
7.7    2.6    6.9    2.3
6    2.2    5    1.5
6.9    3.2    5.7    2.3
5.6    2.8    4.9    2
7.7    2.8    6.7    2
6.3    2.7    4.9    1.8
6.7    3.3    5.7    2.1
7.2    3.2    6    1.8
6.2    2.8    4.8    1.8
6.1    3    4.9    1.8
6.4    2.8    5.6    2.1
7.2    3    5.8    1.6
7.4    2.8    6.1    1.9
7.9    3.8    6.4    2
6.4    2.8    5.6    2.2
6.3    2.8    5.1    1.5
6.1    2.6    5.6    1.4
7.7    3    6.1    2.3
6.3    3.4    5.6    2.4
6.4    3.1    5.5    1.8
6    3    4.8    1.8
6.9    3.1    5.4    2.1
6.7    3.1    5.6    2.4
6.9    3.1    5.1    2.3
5.8    2.7    5.1    1.9
6.8    3.2    5.9    2.3
6.7    3.3    5.7    2.5
6.7    3    5.2    2.3
6.3    2.5    5    1.9
6.5    3    5.2    2
6.2    3.4    5.4    2.3
5.9    3    5.1    1.8
View Code

iris_id.txt

0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
View Code

2.matlab程序

My_GMM.m

function label_2=My_GMM(K)
%輸入K:聚類數,K個單高斯模型
%輸出label_2:聚的類,para_pi:單高斯權重,para_miu_new:高斯分佈參數μ,para_sigma:高斯分佈參數sigma
format long
eps=1e-15;  %定義迭代終止條件的eps
data=dlmread('E:\www.cnblogs.comkailugaji\data\iris\iris_data.txt');
%----------------------------------------------------------------------------------------------------
%對data作最大-最小歸一化處理
[data_num,~]=size(data);
X=(data-ones(data_num,1)*min(data))./(ones(data_num,1)*(max(data)-min(data)));
[X_num,X_dim]=size(X);
para_sigma=zeros(X_dim,X_dim,K);
%----------------------------------------------------------------------------------------------------
%隨機初始化K個聚類中心
rand_array=randperm(X_num);  %產生1~X_num之間整數的隨機排列
center=X(rand_array(1:K),:);  %隨機排列取前K個數,在X矩陣中取這K行做爲初始聚類中心
%根據上述聚類中心初始化參數
para_miu_new=center;  %初始化參數miu
para_pi=ones(1,K)./K;  %K類單高斯模型的權重
for k=1:K
    para_sigma(:,:,k)=eye(X_dim);  %K類單高斯模型的協方差矩陣,初始化爲單位陣
end
%歐氏距離,計算(X-para_miu)^2=X^2+para_miu^2-2*X*para_miu',矩陣大小爲X_num*K
distant=repmat(sum(X.*X,2),1,K)+repmat(sum(para_miu_new.*para_miu_new,2)',X_num,1)-2*X*para_miu_new';
%返回distant每行最小值所在的下標
[~,label_1]=min(distant,[],2);
for k=1:K
    X_k=X(label_1==k,:);  %X_k是一個(X_num/K, X_dim)的矩陣,把X矩陣分爲K類
    para_pi(k)=size(X_k,1)/X_num;  %將(每一類數據的個數/X_num)做爲para_pi的初始值
    para_sigma(:,:,k)=cov(X_k);  %para_sigma是一個(X_dim, X_dim)的矩陣,cov(矩陣)求的是每一列之間的協方差
end
%----------------------------------------------------------------------------------------------------
%EM算法
N_pdf=zeros(X_num,K);
while true
    para_miu=para_miu_new;
    %----------------------------------------------------------------------------------------------------
    %E步
    %單高斯分佈的機率密度函數N_pdf
    for k=1:K
        X_miu=X-repmat(para_miu(k,:),X_num,1);  %X-miu,(X_num, X_dim)的矩陣
        sigma_inv=inv(para_sigma(:,:,k));  %sigma的逆矩陣,(X_dim, X_dim)的矩陣//極可能出現奇異矩陣
        exp_up=sum((X_miu*sigma_inv).*X_miu,2);  %指數的冪,(X-miu)'*sigma^(-1)*(X-miu)
        coefficient=(2*pi)^(-X_dim/2)*sqrt(det(sigma_inv));  %高斯分佈的機率密度函數e左邊的係數
        N_pdf(:,k)=coefficient*exp(-0.5*exp_up);
    end
%    N_pdf=guass_pdf(X,K,para_miu,para_sigma);
    responsivity=N_pdf.*repmat(para_pi,X_num,1);  %響應度responsivity的分子,(X_num,K)的矩陣
    responsivity=responsivity./repmat(sum(responsivity,2),1,K);  %responsivity:在當前模型下第n個觀測數據來自第k個分模型的機率,即分模型k對觀測數據Xn的響應度
    %----------------------------------------------------------------------------------------------------
    %M步
    R_k=sum(responsivity,1);  %(1,K)的矩陣,把responsivity每一列求和
    %更新參數miu
    para_miu_new=diag(1./R_k)*responsivity'*X;
    %更新k個參數sigma
    for i=1:K
        X_miu=X-repmat(para_miu_new(i,:),X_num,1);
        para_sigma(:,:,i)=(X_miu'*(diag(responsivity(:,i))*X_miu))/R_k(i);
    end
    %更新參數pi
    para_pi=R_k/sum(R_k);
    %----------------------------------------------------------------------------------------------------
    %迭代終止條件
    if norm(para_miu_new-para_miu)<=eps
        break;
    end
end
%----------------------------------------------------------------------------------------------------
%聚類
[~,label_2]=max(responsivity,[],2);

succeed.m

function accuracy=succeed(K,id)
%輸入K:聚的類,id:訓練後的聚類結果,N*1的矩陣
N=size(id,1);   %樣本個數
p=perms(1:K);   %全排列矩陣
p_col=size(p,1);   %全排列的行數
new_label=zeros(N,p_col);   %聚類結果的全部可能取值,N*p_col
num=zeros(1,p_col);  %與真實聚類結果同樣的個數
real_label=dlmread('E:\www.cnblogs.comkailugaji\data\iris\iris_id.txt');
%將訓練結果全排列爲N*p_col的矩陣,每一列爲一種可能性
for i=1:N
    for j=1:p_col
        for k=1:K
            if id(i)==k
                new_label(i,j)=p(j,k)-1;
            end
        end
    end
end
%與真實結果比對,計算精確度
for j=1:p_col
    for i=1:N
        if new_label(i,j)==real_label(i)
                num(j)=num(j)+1;
        end
    end
end
accuracy=max(num)/N;

3.結果

>> label_1=My_GMM(3);
>> accuracy=succeed(3,label_1)

accuracy =

   0.966666666666667

4.注意

    GMM算法我只進行了一次計算準確度,由於有可能會出現奇異矩陣的狀況,致使算法出錯,如今我尚未想出如何解決奇異矩陣的問題,所以只給出了一次循環。望指正。github

補充:GMM的Python代碼:upload/GMM.py at master · wl-lei/upload · GitHub算法

相關文章
相關標籤/搜索