keras multi-label classification 多標籤分類

問題:一個數據又多個標籤,一個樣本數據多個類別中的某幾類;好比一個病人的數據有多個疾病,一個文本有多種題材,因此標籤就是: [1,0,0,0,1,0,1] 這種高維稀疏類型,如何計算分類準確率?python

 

分類問題:網絡

二分類函數

多分類性能

多標籤ui

 

Keras metrics (性能度量)lua

介紹的比較好的一個博客:spa

https://machinelearningmastery.com/custom-metrics-deep-learning-keras-python/code

還有一個介紹loss的博客:orm

https://machinelearningmastery.com/how-to-choose-loss-functions-when-training-deep-learning-neural-networks/blog

metrics:在訓練的每一個batch結束的時候計算訓練集acc,若是提供驗證集(一個epoch結束計算驗證集acc),也同時計算驗證集的性能度量,分爲迴歸任務和分類任務,有不一樣的acc計算辦法;metrics 裏面能夠放 loss (迴歸問題)或者acc(分類問題);

A metric is a function that is used to judge the performance of your model.

A metric function is similar to a loss function, except that the results from evaluating a metric are not used when training the model. You may use any of the loss functions as a metric function.

metrics其實和loss相似,只是不用來指導網絡的訓練;通常根據具體問題具體要求採用不一樣的 metric 函數,衡量性能;

分類問題的不一樣acc計算方法:

  • Binary Accuracy: binary_accuracy, acc
  • Categorical Accuracy: categorical_accuracy, acc
  • Sparse Categorical Accuracy: sparse_categorical_accuracy
  • Top k Categorical Accuracy: top_k_categorical_accuracy (requires you specify a k parameter)
  • Sparse Top k Categorical Accuracy: sparse_top_k_categorical_accuracy (requires you specify a k parameter)

 

 

keras metrics 默認的accuracy:

metrics["accuracy"] :   == categorical_accuracy; 最快的驗證方法,訓練一個簡單網絡,同時輸出默認accuracy,categorical_accuracy,,binaray_accuracy, 對比就能夠知道;

或者看keras源碼,找到metrics默認設置:

 

  

 

 

多標籤分類問題:

[1,0,0,1,0] , [1,0,0,0,0] 分別是 y_pred, y_true:

若是使用 binary_accuracy : acc = 0.8;

if the prediction would be [0, 0, 0, 0, 0, 1]. And if the actual labels were [0, 0, 0, 0, 0, 0], the accuracy would be 5/6.;

 

 

 

 

訓練過程常見坑:

1.自定義loss:

自定義loss寫成函數的時候,keras compile() 裏面,要調用自定義的loss函數而不是隻給函數名:

model.compile(optimizer="adam", loss=self_defined_loss(), metrics=["accuracy"])
 
2. 關於top5 , top1 ACC:--(針對多分類不是多標籤問題)
一個圖片多是 [貓,狗,大象,老鼠,小皮球,房子]裏面的一種;咱們對每一個圖片輸出一個機率分佈 [0.3,0.2,0.1,0.1,0.3] , 若是:
top1: 機率最高的預測類別是否和真實標籤一致;
top5:機率最高的5個預測類別是否包含了真實標籤;
相關文章
相關標籤/搜索