【tf.keras】Implementing F1 score, precision, recall and other metrics

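At first it seemed hard to believe that tf.keras.metrics does not implement F1 score, recall, precision and similar metrics. There is a reason, though. These metrics are not meaningful when computed batch-wise, because the average of per-batch values is in general not the value over the whole set, so they have to be computed on the entire validation set. tf.keras, on the other hand, computes acc and loss during training (including on the validation set) once per batch and averages the results at the end. Keras 2.0 removed the precision, recall, fbeta_score and fmeasure metrics.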
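
A small worked example makes the batch-wise problem concrete. The sketch below is plain Python with made-up labels: two batches of binary labels whose per-batch F1 scores average to a different value than the F1 computed over all eight samples at once. It uses F1 = 2TP / (2TP + FP + FN), which is algebraically equal to 2PR / (P + R).

```python
def binary_f1(y_true, y_pred):
    """F1 for the positive class 1, computed as 2TP / (2TP + FP + FN)."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0

# two hypothetical validation batches: (labels, predictions)
batches = [
    ([1, 1, 0, 0], [1, 0, 0, 0]),  # TP=1, FP=0, FN=1 -> F1 = 2/3
    ([1, 0, 0, 0], [1, 1, 1, 0]),  # TP=1, FP=2, FN=0 -> F1 = 1/2
]

# what a batch-wise metric reports: the mean of per-batch values
per_batch = [binary_f1(t, p) for t, p in batches]
batch_avg = sum(per_batch) / len(per_batch)

# what we actually want: F1 over all samples together (TP=2, FP=2, FN=1)
all_true = [t for bt, _ in batches for t in bt]
all_pred = [p for _, bp in batches for p in bp]
full = binary_f1(all_true, all_pred)

print(f"mean of per-batch F1: {batch_avg:.4f}")  # 0.5833
print(f"F1 over all samples:  {full:.4f}")       # 0.5714
```

The gap here is small, but it depends on how positives happen to be distributed across batches, which is why the value has to be computed once over the whole validation set.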

Although tf.keras.metrics does not implement F1 score, precision and recall, they can be implemented with a tf.keras.callbacks.Callback: at the end of each epoch, compute F1, precision and recall over the entire validation set.

Several blog posts implement F1 score, precision and recall for binary classification.
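
In the binary case, the core of such a callback is thresholding the sigmoid outputs and counting TP, FP and FN over the whole validation set. The NumPy sketch below shows only that computation; the function name binary_prf and the 0.5 threshold are assumptions for illustration, not code taken from any particular blog.

```python
import numpy as np

def binary_prf(y_true, y_prob, threshold=0.5):
    """Precision, recall and F1 of the positive class over a whole set.

    y_true: 0/1 labels; y_prob: sigmoid outputs in [0, 1].
    The 0.5 default threshold is an assumption.
    """
    y_true = np.asarray(y_true).ravel()
    y_pred = (np.asarray(y_prob).ravel() >= threshold).astype(int)
    tp = int(np.sum((y_pred == 1) & (y_true == 1)))
    fp = int(np.sum((y_pred == 1) & (y_true == 0)))
    fn = int(np.sum((y_pred == 0) & (y_true == 1)))
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1

# hypothetical labels and model outputs
p, r, f1 = binary_prf([1, 0, 1, 1, 0, 0], [0.9, 0.6, 0.4, 0.8, 0.1, 0.3])
print(f"precision={p:.4f} recall={r:.4f} f1={f1:.4f}")  # all three are 0.6667
```

Inside a callback, y_prob would be the output of self.model.predict on the validation inputs, computed once per epoch.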

The code below computes F1, precision and recall on the validation set for multi-class classification, and saves the model with the best val_f1:

import tensorflow as tf

from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, recall_score, precision_score
import numpy as np
import os


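class Metrics(tf.keras.callbacks.Callback):
    """Computes macro-averaged F1, precision and recall on the whole
    validation set at the end of every epoch and writes them into `logs`."""

    def __init__(self, valid_data):
        super(Metrics, self).__init__()
        self.validation_data = valid_data  # (x_val, y_val)

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        # predict on the full validation set, then take the most likely class
        val_predict = np.argmax(self.model.predict(self.validation_data[0]), -1)
        val_targ = self.validation_data[1]
        # one-hot labels -> class indices; sparse (N,) or (N, 1) labels are used as-is
        if len(val_targ.shape) == 2 and val_targ.shape[1] != 1:
            val_targ = np.argmax(val_targ, -1)

        # macro average: compute each metric per class, then take the unweighted mean
        _val_f1 = f1_score(val_targ, val_predict, average='macro')
        _val_recall = recall_score(val_targ, val_predict, average='macro')
        _val_precision = precision_score(val_targ, val_predict, average='macro')

        # the same logs dict is passed on to later callbacks in the list (e.g. ModelCheckpoint)
        logs['val_f1'] = _val_f1
        logs['val_recall'] = _val_recall
        logs['val_precision'] = _val_precision
        print(" - val_f1: %f - val_precision: %f - val_recall: %f" % (_val_f1, _val_precision, _val_recall))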


(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, test_size=10000, random_state=32)

# LeNet-5
model = tf.keras.models.Sequential([
    tf.keras.layers.Input(shape=(32, 32, 3)),
    tf.keras.layers.Conv2D(6, 5, activation='relu'),
    tf.keras.layers.AveragePooling2D(),
    tf.keras.layers.Conv2D(16, 5, activation='relu'),
    tf.keras.layers.AveragePooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(120, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(84, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

if not os.path.exists('./checkpoints'):
    os.makedirs('./checkpoints')

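# save the weights of the model with the best val_f1
# (Keras 3 requires the file name to end in .weights.h5 when save_weights_only=True)
ck_callback = tf.keras.callbacks.ModelCheckpoint('./checkpoints/weights.{epoch:02d}-{val_f1:.4f}.hdf5',
                                                 monitor='val_f1',
                                                 mode='max', verbose=1,
                                                 save_best_only=True,
                                                 save_weights_only=True)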
tb_callback = tf.keras.callbacks.TensorBoard(log_dir='./logs', profile_batch=0)
model.fit(x_train, y_train,
          validation_data=(x_val, y_val),
          epochs=100,
          callbacks=[Metrics(valid_data=(x_val, y_val)),
                     ck_callback,
                     tb_callback])
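
The callback above relies on average='macro' in sklearn: each metric is computed per class and the per-class values are averaged with equal weight, regardless of class size. To illustrate what that means, the hypothetical helper below (not part of sklearn or tf.keras) computes the same quantities from a confusion matrix in NumPy. Like sklearn's default, it averages only over labels that occur in y_true or y_pred, and it scores an undefined per-class precision or recall as 0.

```python
import numpy as np

def macro_prf(y_true, y_pred):
    """Macro-averaged precision, recall and F1 from a confusion matrix."""
    y_true = np.asarray(y_true, dtype=np.int64).ravel()  # also flattens (N, 1) labels
    y_pred = np.asarray(y_pred, dtype=np.int64).ravel()
    n = int(max(y_true.max(), y_pred.max())) + 1
    cm = np.zeros((n, n), dtype=np.int64)
    np.add.at(cm, (y_true, y_pred), 1)  # rows: true class, columns: predicted class

    tp = np.diag(cm).astype(np.float64)
    pred_count = cm.sum(axis=0)  # TP + FP per class
    true_count = cm.sum(axis=1)  # TP + FN per class
    present = (pred_count + true_count) > 0  # labels seen in y_true or y_pred

    zeros = np.zeros_like(tp)
    precision = np.divide(tp, pred_count, out=zeros.copy(), where=pred_count > 0)
    recall = np.divide(tp, true_count, out=zeros.copy(), where=true_count > 0)
    denom = precision + recall
    f1 = np.divide(2 * precision * recall, denom, out=zeros.copy(), where=denom > 0)

    # unweighted mean over classes = macro average
    return precision[present].mean(), recall[present].mean(), f1[present].mean()

# hypothetical 3-class example
p, r, f1 = macro_prf([0, 0, 1, 1, 2, 2], [0, 1, 1, 1, 2, 0])
print(f"macro P={p:.4f} R={r:.4f} F1={f1:.4f}")  # macro P=0.7222 R=0.6667 F1=0.6556
```

Note that macro F1 is the mean of per-class F1 values, not the F1 of the macro precision and macro recall.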

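Note the order of the two callbacks, Metrics() and ck_callback: Metrics() must come first. ModelCheckpoint reads val_f1 from logs, both as the monitored quantity and to fill {val_f1:.4f} in the file name, and that key only exists after Metrics() has written it, so swapping the two raises an error.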
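
The error comes from ModelCheckpoint building the checkpoint file name out of the shared logs dict. Assuming, as in TF 2.x, that it fills the template with str.format(epoch=epoch + 1, **logs), the failure can be reproduced in plain Python; the helper checkpoint_path and the log values are hypothetical.

```python
template = './checkpoints/weights.{epoch:02d}-{val_f1:.4f}.hdf5'

def checkpoint_path(epoch, logs):
    # mimics how the file name is filled in (epoch is 0-based inside callbacks)
    return template.format(epoch=epoch + 1, **logs)

# hypothetical logs at the end of epoch 0, before Metrics has run
logs = {'loss': 1.52, 'accuracy': 0.45, 'val_loss': 1.38, 'val_accuracy': 0.51}
try:
    checkpoint_path(0, logs)
except KeyError as err:
    print('KeyError:', err)  # KeyError: 'val_f1'

# Metrics.on_epoch_end adds the key, after which the name can be built
logs['val_f1'] = 0.4567
path = checkpoint_path(0, logs)
print(path)  # ./checkpoints/weights.01-0.4567.hdf5
```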

References

How to calculate F1 Macro in Keras? -- StackOverflow
How to compute f1 score for each epoch in Keras -- Thong Nguyen
How does Keras compute precision and recall for classification problems? (answer by 魚塘鄧少) -- Zhihu
Keras 2.0 release notes -- keras-team/keras
