1.本身寫的計算auc的代碼,用scikit-learn的auc計算函數sklearn.metrics.
auc
(x, y, reorder=False)作了一些測試,結果是同樣的,若有錯誤,歡迎指正。python
思路:1.首先對預測值進行排序,排序的方式用了python自帶的函數sorted,詳見註釋。數組
2.對全部樣本按照預測值從小到大標記rank,rank其實就是index+1,index是排序後的sorted_pred數組中的索引函數
3.將全部正樣本的rank相加,遇到預測值相等的狀況,無論樣本的正負性,對rank要取平均值再相加測試
4.將rank相加的和減去正樣本排在正樣本以後的狀況,再除以總的組合數,獲得aucspa
1 # -*- coding: utf-8 -*- 2 """ 3 Created on Wed May 3 10:48:28 2017 4 5 @author: Vincent 6 """ 7 import numpy as np 8 from sklearn import metrics 9 y = np.array( [1, 0, 0, 1, 1, 1, 0, 1, 1, 1]) 10 pred = np.array([0.9, 0.9,0.8, 0.8, 0.7,0.7,0.7,0.6,0.5,0.4]) 11 fpr, tpr, thresholds = metrics.roc_curve(y, pred, pos_label=1) 12 print(metrics.auc(fpr, tpr)) 13 def getAuc(labels, pred) : 14 '''將pred數組的索引值按照pred[i]的大小正序排序,返回的sorted_pred是一個新的數組, 15 sorted_pred[0]就是pred[i]中值最小的i的值,對於這個例子,sorted_pred[0]=8 16 ''' 17 sorted_pred = sorted(range(len(pred)), key = lambda i : pred[i]) 18 pos = 0.0 #正樣本個數 19 neg = 0.0 #負樣本個數 20 auc = 0.0 21 last_pre = pred[sorted_pred[0]] 22 count = 0.0 23 pre_sum = 0.0 #當前位置以前的預測值相等的rank之和,rank是從1開始的,因此在下面的代碼中就是i+1 24 pos_count = 0.0 #記錄預測值相等的樣本中標籤是正的樣本的個數 25 for i in range(len(sorted_pred)) : 26 if labels[sorted_pred[i]] > 0: 27 pos += 1 28 else: 29 neg += 1 30 if last_pre != pred[sorted_pred[i]]: #當前的預測機率值與前一個值不相同 31 #對於預測值相等的樣本rank須要取平均值,而且對rank求和 32 auc += pos_count * pre_sum / count 33 count = 1 34 pre_sum = i + 1 #更新爲當前的rank 35 last_pre = pred[sorted_pred[i]] 36 if labels[sorted_pred[i]] > 0: 37 pos_count = 1 #若是當前樣本是正樣本 ,則置爲1 38 else: 39 pos_count = 0 #反之置爲0 40 else: 41 pre_sum += i + 1 #記錄rank的和 42 count += 1 #記錄rank和對應的樣本數,pre_sum / count就是平均值了 43 if labels[sorted_pred[i]] > 0:#若是是正樣本 44 pos_count += 1 #正樣本數加1 45 auc += pos_count * pre_sum / count #加上最後一個預測值相同的樣本組 46 auc -= pos *(pos + 1) / 2 #減去正樣本在正樣本以前的狀況 47 auc = auc / (pos * neg) #除以總的組合數 48 return auc 49 print(getAuc(y, pred))
2.awk代碼code
1 #計算auc,輸入分別爲預測值(能夠乘以一個倍數以後轉化爲整數),該相同預測值的樣本個數,該相同預測值的正樣本個數 2 sort -t $'\t' -k 1,1n | awk -F"\t" 'BEGIN{ 3 OFS="\t"; 4 now_q=""; 5 begin_rank=1; 6 now_pos_num=0; 7 now_neg_num=0; 8 total_pos_rank=0; 9 total_pos_num=0; 10 total_neg_num=0; 11 }function clear(){ 12 begin_rank += now_pos_num + now_neg_num; 13 now_pos_num=0; 14 now_neg_num=0; 15 }function update(){ 16 now_pos_num += pos_num; 17 now_neg_num += neg_num; 18 }function output(){ 19 n = now_pos_num + now_neg_num; 20 avg_rank = begin_rank + (n-1)/2; 21 tmp_all_pos_rank = avg_rank * now_pos_num; 22 total_pos_rank += tmp_all_pos_rank; 23 total_pos_num += now_pos_num; 24 total_neg_num += now_neg_num; 25 }{ 26 q=$1; 27 show=$2; 28 clk=$3; 29 pos_num=clk; 30 neg_num=show-clk; 31 if(now_q!=q){ 32 if(now_q!=""){ 33 output(); 34 clear(); 35 } 36 now_q=q; 37 } 38 update(); 39 40 }END{ 41 output(); 42 auc=0; 43 m=total_pos_num; 44 n=total_neg_num; 45 if(m>0 && n>0){ 46 auc = (total_pos_rank-m*(m+1)/2) / (m*n); 47 } 48 print auc; 49 }'