Using the Confusion Matrix in MATLAB with the PRTools Pattern Recognition Toolbox

Note: the toolbox code used in this article comes from the PRTools pattern recognition toolbox (http://www.prtools.org), and all experiments were run in MATLAB.

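The confusion matrix is a commonly used tool in pattern recognition, and PRTools provides a ready-made function, confmat, for computing it. Its usage is as follows: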

  [C,NE,LABLIST] = CONFMAT(LAB1,LAB2,METHOD,FID)
 
  INPUT
   LAB1        Set of labels
   LAB2        Set of labels
   METHOD      'count' (default) to count number of co-occurrences in
               LAB1 and LAB2, 'disagreement' to count relative
               non-co-occurrence.
   FID         Write text result to file
 
  OUTPUT
   C           Confusion matrix
   NE          Total number of errors (empty labels are neglected)
   LABLIST     Unique labels in LAB1 and LAB2
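For readers who want to check the behavior outside MATLAB, the default 'count' mode can be approximated in Python. This is a minimal sketch under stated assumptions, not PRTools code: `confmat_count` is a hypothetical name, rows are assumed to follow LAB1 and columns LAB2, labels must be sortable, and the special handling of empty labels mentioned for NE is not reproduced.

```python
def confmat_count(lab1, lab2):
    """Hypothetical Python analogue of confmat's 'count' mode.

    Returns (C, NE, lablist): lablist is the sorted set of labels found in
    either input, C[i][j] counts positions where lab1 holds lablist[i] and
    lab2 holds lablist[j] (rows follow lab1 by assumption), and NE is the
    number of positions where the two labels disagree.
    """
    if len(lab1) != len(lab2):
        raise ValueError("label sets must have the same length")
    lablist = sorted(set(lab1) | set(lab2))
    index = {lab: k for k, lab in enumerate(lablist)}
    C = [[0] * len(lablist) for _ in lablist]
    for a, b in zip(lab1, lab2):
        C[index[a]][index[b]] += 1
    # every disagreement lands off the main diagonal
    m = len(lablist)
    NE = sum(C[i][j] for i in range(m) for j in range(m) if i != j)
    return C, NE, lablist


C, NE, lablist = confmat_count(["a", "b", "b", "a"], ["a", "b", "a", "a"])
print(C, NE, lablist)  # [[2, 0], [1, 1]] 1 ['a', 'b']
```

NE is simply the sum of the off-diagonal entries, which is why it equals the number of positions where LAB1 and LAB2 disagree.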

   

First, a few basic terms:

      

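TP (True Positive): positive tuples that the classifier labels correctly.

TN (True Negative): negative tuples that the classifier labels correctly.

FP (False Positive): negative tuples incorrectly labeled as positive.

FN (False Negative): positive tuples incorrectly labeled as negative.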

TP and TN tell us when the classifier is right; FP and FN tell us when it is wrong.
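To make the four definitions concrete, the Python sketch below counts them for a two-class problem. The function `binary_counts` and its `positive` parameter are illustrative assumptions; every label other than `positive` is treated as negative.

```python
def binary_counts(actual, predicted, positive=1):
    """Count (TP, TN, FP, FN) for a two-class problem.

    `positive` selects the label treated as the positive class
    (an assumption of this sketch); any other label counts as negative.
    """
    if len(actual) != len(predicted):
        raise ValueError("actual and predicted must have the same length")
    tp = tn = fp = fn = 0
    for a, p in zip(actual, predicted):
        if a == positive and p == positive:
            tp += 1  # positive tuple labeled positive
        elif a != positive and p != positive:
            tn += 1  # negative tuple labeled negative
        elif a != positive:
            fp += 1  # negative tuple wrongly labeled positive
        else:
            fn += 1  # positive tuple wrongly labeled negative
    return tp, tn, fp, fn


print(binary_counts([1, 1, 0, 0, 1, 0], [1, 0, 0, 1, 1, 0]))  # (2, 2, 1, 1)
```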

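For a dataset with M classes, the confusion matrix is a table of at least M×M, in which the entry in row i, column j is the number of tuples of class i that were labeled as class j.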
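With more than two classes, TP, TN, FP and FN are defined per class by treating one class as positive and all others as negative. The following Python sketch reads them off an M×M matrix laid out as just described (rows = actual class, columns = assigned class); `one_vs_rest` is a hypothetical helper and the 3-class counts are only illustrative.

```python
def one_vs_rest(cm, k):
    """Return (TP, TN, FP, FN) for class index k of an M x M confusion matrix.

    Assumes cm[i][j] counts tuples of class i labeled as class j.
    """
    m = len(cm)
    total = sum(sum(row) for row in cm)
    tp = cm[k][k]
    fn = sum(cm[k]) - tp                       # class-k tuples labeled as another class
    fp = sum(cm[i][k] for i in range(m)) - tp  # other tuples labeled as class k
    tn = total - tp - fn - fp                  # everything else
    return tp, tn, fp, fn


# illustrative 3-class counts (rows = actual, columns = assigned)
cm = [[5, 1, 0],
      [2, 3, 0],
      [0, 1, 3]]
print(one_vs_rest(cm, 0))  # (5, 7, 2, 1)
```

For class k the diagonal entry is TP, the rest of row k is FN, the rest of column k is FP, and all remaining entries are TN.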

As an example, take the Ionosphere dataset from the UCI repository and call the confusion-matrix function from PRTools:

(1) First, load and normalize the Ionosphere data:

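data = load('ionosphere.txt');       % numeric matrix; last column is the class label
[m,k] = size(data);
data1 = zeros(m,k-1);
for i = 1:k-1
    % min-max scale each feature to [0,1]; a constant column would give 0/0 = NaN,
    % so such columns are left at 0
    colRange = max(data(:,i)) - min(data(:,i));
    if colRange > 0
        data1(:,i) = (data(:,i) - min(data(:,i))) / colRange;
    end
end
label = data(:,k);
if min(label) == 0
    label = label + 1;               % shift 0-based labels so they start at 1
end
a = dataset(data1,label);            % build a PRTools dataset from features and labels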

(2) Then split the data, train a classifier, and call confmat.m:

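[train,test]=gendat(a,0.5);   % randomly split a: 50% for training, 50% for testing
w=treec(train);               % train a decision tree classifier
conf=confmat(test*w)          % classify the test set and compare true vs. assigned labels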

In the output, conf is the confusion matrix; its entries have the meaning given by the M×M definition above.

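If you would rather not use the PRTools function, you can write your own confusion-matrix code, such as the function below, and call it directly. Note that it arranges the matrix with predicted classes as rows and actual classes as columns, the transpose of the convention defined above.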

function [confmatrix] = cfmatrix(actual, predict, classlist, per)
% CFMATRIX calculates the confusion matrix for any prediction 
% algorithm that generates a list of classes to which the test 
% feature vectors are assigned
%
% Outputs: confusion matrix
%
%                 Actual Classes
%                   p       n
%              ___|_____|______| 
%    Predicted  p'|     |      |
%      Classes  n'|     |      |
% 
% Inputs: 
% 1. actual / 2. predict
% The inputs provided are the 'actual' classes vector
% and the 'predict'ed classes vector. The actual classes are the classes
% to which the input feature vectors belong. The predicted classes are the 
% class to which the input feature vectors are predicted to belong to, 
% based on a prediction algorithm. 
% The length of actual class vector and the predicted class vector need to 
% be the same. If they are not the same, an error message is displayed. 
% 3. classlist
% The third input provides the list of all the classes {p,n,...} for which 
% the classification is being done. All classes are numbers.
% 4. per = 1/0 (default = 0)
% This parameter when set to 1 provides the values in the confusion matrix 
% as percentages. The default provides the values in numbers.
%
% Example:
% >> a = [ 1 2 3 1 2 3 1 1 2 3 2 1 1 2 3];
% >> b = [ 1 2 3 1 2 3 1 1 1 2 2 1 2 1 3];
% >> Cf = cfmatrix(a, b);
%
% [Avinash Uppuluri: avinash_uv@yahoo.com: Last modified: 08/21/08]

% If classlist not entered: make classlist equal to all 
% unique elements of actual
if (nargin < 2)
   error('Not enough input arguments.');
elseif (nargin == 2)
    classlist = unique(actual); % default values from actual
    per = 0; % default is numbers and input 1 for percentage
elseif (nargin == 3)
    per = 0; % default is numbers and input 1 for percentage
end

if (length(actual) ~= length(predict))
    error('First two inputs need to be vectors with equal size.');
elseif ((size(actual,1) ~= 1) && (size(actual,2) ~= 1))
    error('First input needs to be a vector and not a matrix');
elseif ((size(predict,1) ~= 1) && (size(predict,2) ~= 1))
    error('Second input needs to be a vector and not a matrix');
end
n_class = length(classlist);
confmatrix = zeros(n_class);        % rows: predicted class, columns: actual class
line_two = '----------';
line_three = '_________|';
for i = 1:n_class
    for j = 1:n_class
        % samples predicted as classlist(i) whose actual class is classlist(j)
        confmatrix(i,j) = sum(predict(:) == classlist(i) & actual(:) == classlist(j));
    end
    line_two = [line_two, '---', num2str(classlist(i)), '-----'];
    line_three = [line_three, '__________'];
end

if (per == 1)
    confmatrix = (confmatrix ./ length(actual)).*100;
end

% output to screen
disp('------------------------------------------');
disp('             Actual Classes');
disp(line_two);
disp('Predicted|                     ');
disp('  Classes|                     ');
disp(line_three);

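for i = 1:n_class
    % label each row with its class value rather than its index
    temps = sprintf('       %d             ', classlist(i));
    for j = 1:n_class
        % concatenate with [] because strcat strips the trailing padding spaces
        temps = [temps, sprintf(' |    %2.1f    ', confmatrix(i,j))];
    end
    disp(temps);
end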
disp('------------------------------------------');
        
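Because cfmatrix places predicted classes in the rows, its output must be transposed before it matches the row-is-actual convention used earlier, and per = 1 divides every entry by the number of samples and multiplies by 100. A Python sketch of those two post-processing steps on a plain list-of-lists matrix; the helper names `transpose` and `as_percentages` are hypothetical.

```python
def transpose(cm):
    """Swap rows and columns, e.g. predicted-by-row -> actual-by-row."""
    return [list(col) for col in zip(*cm)]


def as_percentages(cm):
    """Express every entry as a percentage of all samples (like per = 1)."""
    total = sum(sum(row) for row in cm)
    if total == 0:
        raise ValueError("empty confusion matrix")
    return [[100.0 * x / total for x in row] for row in cm]


# predicted-by-row counts for a made-up 2-class problem
pred_by_row = [[3, 1],
               [2, 4]]
actual_by_row = transpose(pred_by_row)   # [[3, 2], [1, 4]]
print(as_percentages(actual_by_row))     # [[30.0, 20.0], [10.0, 40.0]]
```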

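The confusion matrix itself is easy to understand. From it we can derive several equally simple measures (P: number of positive tuples, N: number of negative tuples, so that P = TP + FN and N = TN + FP):

Accuracy: $\frac{TP+TN}{P+N}$

Error rate: $\frac{FP+FN}{P+N}$

Sensitivity (recall): $\frac{TP}{P}$

Precision: $\frac{TP}{TP+FP}$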
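The four formulas can be checked with a short Python sketch. It takes the counts directly, uses P = TP + FN and N = TN + FP, and returns NaN where a denominator would be zero; `rates_from_counts` is a hypothetical name, not part of PRTools.

```python
def rates_from_counts(tp, tn, fp, fn):
    """Return (accuracy, error_rate, recall, precision) from binary counts."""
    p = tp + fn  # number of positive tuples
    n = tn + fp  # number of negative tuples
    if p + n == 0:
        raise ValueError("no tuples to evaluate")
    accuracy = (tp + tn) / (p + n)
    error_rate = (fp + fn) / (p + n)
    recall = tp / p if p else float("nan")  # sensitivity
    precision = tp / (tp + fp) if (tp + fp) else float("nan")
    return accuracy, error_rate, recall, precision


print(rates_from_counts(tp=2, tn=2, fp=1, fn=1))
```

Accuracy and error rate always sum to 1, because together they account for every tuple.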

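This article used the confusion-matrix function in PRTools as a way to briefly introduce the concept of the confusion matrix. If anything here is incorrect, corrections are welcome.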
