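Note: the code in this article is based on the PRTools pattern recognition toolbox (http://www.prtools.org), and all experiments were run in MATLAB.
The confusion matrix is a commonly used tool in pattern recognition, and PRTools provides a function, confmat, that computes it directly. Its usage is as follows: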
[C,NE,LABLIST] = CONFMAT(LAB1,LAB2,METHOD,FID)

INPUT
  LAB1     Set of labels
  LAB2     Set of labels
  METHOD   'count' (default) to count number of co-occurrences in LAB1 and LAB2,
           'disagreement' to count relative non-co-occurrence.
  FID      Write text result to file

OUTPUT
  C        Confusion matrix
  NE       Total number of errors (empty labels are neglected)
  LABLIST  Unique labels in LAB1 and LAB2
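First, a few terms:
TP (True Positive): positive tuples that the classifier labeled correctly.
TN (True Negative): negative tuples that the classifier labeled correctly.
FP (False Positive): negative tuples that were incorrectly labeled as positive.
FN (False Negative): positive tuples that were incorrectly labeled as negative.
TP and TN tell us when the classifier gets things right; FP and FN tell us when it gets things wrong.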
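The four counts can be obtained directly from two label vectors. The snippet below is a minimal sketch in base MATLAB; the label vectors and the coding (1 = positive, 0 = negative) are made up for illustration and do not come from the Ionosphere experiment.

```matlab
% Hypothetical labels: 1 = positive, 0 = negative (illustrative data only).
actual    = [1 1 1 0 0 0 1 0 1 0];
predicted = [1 0 1 0 1 0 1 0 0 0];

TP = sum(actual == 1 & predicted == 1);   % positives labeled positive
TN = sum(actual == 0 & predicted == 0);   % negatives labeled negative
FP = sum(actual == 0 & predicted == 1);   % negatives labeled positive
FN = sum(actual == 1 & predicted == 0);   % positives labeled negative

fprintf('TP=%d TN=%d FP=%d FN=%d\n', TP, TN, FP, FN);
```

Every tuple falls into exactly one of the four groups, so TP + FN equals the number of positive tuples and TN + FP equals the number of negative tuples.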
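For a data set with M classes, the confusion matrix is a table of size at least M×M whose entry in row i, column j is the number of tuples of class i that were labeled as class j.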
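To make the row/column definition concrete, here is a small sketch that builds such a table with the base MATLAB function accumarray. The label vectors are hypothetical, and the classes are assumed to be numbered 1 to M.

```matlab
% Hypothetical class labels in 1..M (illustrative data only).
actual    = [1 1 1 2 2 2 3 3 3 3];
predicted = [1 1 2 2 2 3 3 3 3 1];
M = 3;

% C(i,j) = number of tuples of class i that were labeled as class j.
C = accumarray([actual(:) predicted(:)], 1, [M M]);
disp(C);

n_correct = trace(C);               % diagonal entries are correct decisions
n_errors  = sum(C(:)) - n_correct;  % everything off the diagonal is an error
```

Each row sums to the number of tuples that truly belong to that class, which is a quick sanity check on any confusion matrix laid out this way.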
As an example, take the Ionosphere data set from the UCI repository and call the PRTools confusion matrix function on it:
(1) First, load and prepare the Ionosphere data set:
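% Load the data; the last column is assumed to hold numeric class labels.
data = load('ionosphere.txt');
[m, k] = size(data);

% Min-max scale every feature column to [0, 1].
data1 = zeros(m, k-1);
for i = 1:k-1
    lo = min(data(:,i));
    span = max(data(:,i)) - lo;
    if span > 0
        data1(:,i) = (data(:,i) - lo) / span;
    end   % a constant column would otherwise give 0/0 = NaN; it stays at 0
end

% Shift the labels so that they start at 1 instead of 0.
label = data(:,k);
if min(label) == 0
    label = label + 1;
end

a = dataset(data1, label);   % wrap features and labels in a PRTools dataset
(2) Then call the confmat function:
[train, test] = gendat(a, 0.5);   % random split: half for training, half for testing
w = treec(train);                 % train a decision tree classifier
conf = confmat(test*w)            % confusion matrix of the tree on the test set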
In the result, conf is the confusion matrix; its entries have the meaning described above.
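For readers without PRTools, the same split / train / test / tabulate workflow can be sketched in base MATLAB. This is only an illustration under several assumptions: the data points are hand-made rather than taken from Ionosphere, the split is fixed instead of random, and a simple nearest-mean rule stands in for the decision tree built by treec.

```matlab
% Tiny hand-made 2-D data set (hypothetical values, two classes).
X = [0.10 0.20; 0.20 0.10; 0.15 0.30; 0.30 0.25; 0.60 0.50; ...
     0.80 0.90; 0.90 0.80; 0.85 0.70; 0.70 0.75];
y = [1; 1; 1; 1; 1; 2; 2; 2; 2];

% Fixed split so the result is reproducible (an assumption of this sketch).
train_idx = [1 2 6 7];
test_idx  = [3 4 5 8 9];

% "Train": one mean vector per class (nearest-mean rule, not a decision tree).
M = 2;
mu = zeros(M, size(X, 2));
for c = 1:M
    rows = train_idx(y(train_idx) == c);
    mu(c, :) = mean(X(rows, :), 1);
end

% "Test": assign each test point to the class with the closest mean.
pred = zeros(numel(test_idx), 1);
for t = 1:numel(test_idx)
    d = sum(bsxfun(@minus, mu, X(test_idx(t), :)).^2, 2);
    [~, pred(t)] = min(d);
end

% Rows: true class, columns: assigned class.
conf = accumarray([y(test_idx) pred], 1, [M M]);
disp(conf);
```

In this toy case one class-1 point lies closer to the class-2 mean, so it appears off the diagonal in row 1, column 2.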
If you would rather not use the PRTools function, you can write your own confusion matrix code, as shown below, and call it directly.
function [confmatrix] = cfmatrix(actual, predict, classlist, per)
% CFMATRIX calculates the confusion matrix for any prediction
% algorithm that generates a list of classes to which the test
% feature vectors are assigned
%
% Outputs: confusion matrix
%
%                 Actual Classes
%                   p     n
%              ___|_____|______|
% Predicted p'|     |      |
%   Classes n'|     |      |
%
% Inputs:
% 1. actual / 2. predict
%    The inputs provided are the 'actual' classes vector
%    and the 'predict'ed classes vector. The actual classes are the classes
%    to which the input feature vectors belong. The predicted classes are the
%    class to which the input feature vectors are predicted to belong to,
%    based on a prediction algorithm.
%    The length of actual class vector and the predicted class vector need to
%    be the same. If they are not the same, an error message is displayed.
% 3. classlist
%    The third input provides the list of all the classes {p,n,...} for which
%    the classification is being done. All classes are numbers.
% 4. per = 1/0 (default = 0)
%    This parameter when set to 1 provides the values in the confusion matrix
%    as percentages. The default provides the values in numbers.
%
% Example:
% >> a = [ 1 2 3 1 2 3 1 1 2 3 2 1 1 2 3];
% >> b = [ 1 2 3 1 2 3 1 1 1 2 2 1 2 1 3];
% >> Cf = cfmatrix(a, b);
%
% [Avinash Uppuluri: avinash_uv@yahoo.com: Last modified: 08/21/08]

% If classlist not entered: make classlist equal to all
% unique elements of actual
if (nargin < 2)
    error('Not enough input arguments.');
elseif (nargin == 2)
    classlist = unique(actual); % default values from actual
    per = 0;                    % default is numbers and input 1 for percentage
elseif (nargin == 3)
    per = 0;                    % default is numbers and input 1 for percentage
end

if (length(actual) ~= length(predict))
    error('First two inputs need to be vectors with equal size.');
elseif ((size(actual,1) ~= 1) && (size(actual,2) ~= 1))
    error('First input needs to be a vector and not a matrix');
elseif ((size(predict,1) ~= 1) && (size(predict,2) ~= 1))
    error('Second input needs to be a vector and not a matrix');
end

format short g;
n_class = length(classlist);
confmatrix = zeros(n_class);   % preallocate the n_class-by-n_class result
line_two = '----------';
line_three = '_________|';

for i = 1:n_class
    obind_class_i = find(actual == classlist(i));
    prind_class_i = find(predict == classlist(i));
    confmatrix(i,i) = length(intersect(obind_class_i,prind_class_i));
    for j = 1:n_class
        %if (j ~= i)
        if (j < i)
            % observed j predicted i
            confmatrix(i,j) = length(find(actual(prind_class_i) == classlist(j)));
            % observed i predicted j
            confmatrix(j,i) = length(find(predict(obind_class_i) == classlist(j)));
        end
    end
    line_two = strcat(line_two,'---',num2str(classlist(i)),'-----');
    line_three = strcat(line_three,'__________');
end

if (per == 1)
    confmatrix = (confmatrix ./ length(actual)).*100;
end

% output to screen
disp('------------------------------------------');
disp(' Actual Classes');
disp(line_two);
disp('Predicted| ');
disp(' Classes| ');
disp(line_three);
for i = 1:n_class
    temps = sprintf(' %d ',i);
    for j = 1:n_class
        temps = strcat(temps,sprintf(' | %2.1f ',confmatrix(i,j)));
    end
    disp(temps);
    clear temps
end
disp('------------------------------------------');
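One detail worth checking is the orientation: according to its help text, cfmatrix puts the predicted classes on the rows and the actual classes on the columns. The sketch below reproduces that layout independently with accumarray, using the two label vectors from the help text example. Tracing the loops of cfmatrix by hand suggests it yields the same counts, but this snippet does not call cfmatrix itself, and the variable names are my own.

```matlab
% Label vectors taken from the example in the cfmatrix help text.
a = [1 2 3 1 2 3 1 1 2 3 2 1 1 2 3];   % actual classes
b = [1 2 3 1 2 3 1 1 1 2 2 1 2 1 3];   % predicted classes
n_class = 3;

% Rows = predicted class, columns = actual class (the cfmatrix layout).
C_pred_by_actual = accumarray([b(:) a(:)], 1, [n_class n_class]);

% Rows = actual class, columns = predicted class (the transposed layout).
C_actual_by_pred = C_pred_by_actual';

% The same counts as percentages of all samples, as requested with per = 1.
C_percent = 100 * C_pred_by_actual / numel(a);

disp(C_pred_by_actual);
```

When comparing results from different tools, it helps to confirm which of the two layouts each one uses, since one is simply the transpose of the other.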
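The idea of the confusion matrix is easy to grasp. Several equally simple measures follow from it (P: number of positive tuples, N: number of negative tuples):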
Accuracy: (TP+TN)/(P+N)
Error rate: (FP+FN)/(P+N)
Sensitivity (recall): TP/P
Precision: TP/(TP+FP)
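These measures translate directly into code. The sketch below uses hypothetical counts chosen only for illustration. Note that the numerators and denominators must be grouped with parentheses, because division binds more tightly than addition.

```matlab
% Hypothetical counts (illustrative only).
TP = 3; TN = 4; FP = 1; FN = 2;

P = TP + FN;   % number of positive tuples
N = TN + FP;   % number of negative tuples

accuracy    = (TP + TN) / (P + N);
error_rate  = (FP + FN) / (P + N);
sensitivity = TP / P;            % also called recall
precision   = TP / (TP + FP);

fprintf('accuracy=%.2f error=%.2f recall=%.2f precision=%.2f\n', ...
        accuracy, error_rate, sensitivity, precision);
```

Accuracy and error rate always sum to 1, since every tuple is either classified correctly or not.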
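This article gave a brief introduction to the confusion matrix by way of the confmat function in the PRTools toolbox. If anything here is incorrect, corrections are welcome.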