Knowledge Distillation (知識蒸餾)

stg.iobs.pingan.com.cn/download/mu…


Put simply, knowledge distillation replaces a large, complex model with a small one: the small (student) model learns directly from the predictions of the large (teacher) model.

What the temperature does:

As shown in the figure, the scores of y1 versus y2 and y3 differ by a huge margin; learning from such outputs would be no different from learning the hard labels directly. We want to amplify the relationships between the labels, so a temperature coefficient is introduced to solve this problem.
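To make this concrete, here is a minimal NumPy sketch (not from the original post; the logit values are made up) showing how dividing the logits by a temperature T before the softmax preserves the ranking but softens the distribution:

import numpy as np

def softmax(z, T=1.0):
    # divide the logits by the temperature T before exponentiating
    z = np.asarray(z, dtype=np.float64) / T
    z = z - z.max()              # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()

scores = [10.0, 2.0, 1.0]        # y1 far ahead of y2 and y3
print(softmax(scores, T=1.0))    # ~[0.9995, 0.0003, 0.0001] -> essentially a hard label
print(softmax(scores, T=5.0))    # ~[0.73, 0.15, 0.12]       -> same ordering, much softer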

import numpy as np
import sys
sys.path.append('utils/')

import tensorflow.keras
from tensorflow.keras import optimizers
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping

# use non standard flow_from_directory
from image_preprocessing_ver2 import ImageDataGenerator
# it outputs y_batch that contains onehot targets and logits
# logits came from xception

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Lambda, concatenate, Activation
from tensorflow.keras.losses import categorical_crossentropy as logloss
from tensorflow.keras.metrics import categorical_accuracy, top_k_categorical_accuracy
from tensorflow.keras import backend as K

from mobilenet import get_mobilenet
from tensorflow.keras.applications.mobilenet import preprocess_input

import matplotlib.pyplot as plt
%matplotlib inline
data_dir = '/home/ubuntu/data/'
# each .npy file stores a dict wrapped in a 0-d object array,
# hence allow_pickle=True and the [()] to unwrap it
train_logits = np.load(data_dir + 'train_logits.npy', allow_pickle=True)[()]
val_logits = np.load(data_dir + 'val_logits.npy', allow_pickle=True)[()]
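train_logits and val_logits map an image identifier to the 256 logits produced by the Xception teacher. Purely as an illustration (image_id and class_index are hypothetical placeholders), a single combined target row that the custom generator yields looks like this:

num_classes = 256
onehot = np.eye(num_classes)[class_index]          # hard one-hot target (class_index is a placeholder)
teacher_logits = train_logits[image_id]            # 256 teacher logits (image_id is a placeholder key)
y_row = np.concatenate([onehot, teacher_logits])   # the 512-dim target consumed by the loss below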
data_generator = ImageDataGenerator(
    data_format='channels_last',
    preprocessing_function=preprocess_input
)

# note: i'm also passing dicts of logits
train_generator = data_generator.flow_from_directory(
    data_dir + 'train', train_logits,
    target_size=(224, 224),
    batch_size=64
)

val_generator = data_generator.flow_from_directory(
    data_dir + 'val', val_logits,
    target_size=(224, 224),
    batch_size=64
)
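flow_from_directory expects the usual Keras layout with one sub-folder per class; the folder and file names below are only placeholders:

data/
    train/
        class_001/
            image_0001.jpg
            ...
        class_002/
            ...
    val/
        class_001/
            ...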
temperature = 5.0
model = get_mobilenet(224, alpha=0.25, weight_decay=1e-5, dropout=0.1)

# the student ends with a softmax layer; take the output of the layer just
# before it as the logits (popping from model.layers is unreliable in tf.keras)
logits = model.layers[-2].output

# usual probabilities
probabilities = Activation('softmax')(logits)

# softed probabilities
logits_T = Lambda(lambda x: x/temperature)(logits)
probabilities_T = Activation('softmax')(logits_T)

output = concatenate([probabilities, probabilities_T])
model = Model(model.input, output)
# now model outputs 512 dimensional vectors
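After this rewiring each prediction is a single 512-dimensional vector: the first 256 entries are the ordinary probabilities and the last 256 are the temperature-softened ones. A minimal sketch of reading them back out at inference time (some_images is a placeholder batch of preprocessed 224x224 images):

preds = model.predict(some_images)   # shape (batch_size, 512)
usual_probs = preds[:, :256]         # ordinary class probabilities; argmax gives the predicted label
soft_probs = preds[:, 256:]          # softened probabilities, only needed for the distillation loss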
def knowledge_distillation_loss(y_true, y_pred, lambda_const):    
    
    # split y_true into:
    #   one-hot hard true targets
    #   logits from the Xception teacher
    y_true, logits = y_true[:, :256], y_true[:, 256:]
    
    # convert logits to soft targets
    y_soft = K.softmax(logits/temperature)
    
    # split y_pred into:
    #   usual output probabilities
    #   probabilities made softer with temperature
    y_pred, y_pred_soft = y_pred[:, :256], y_pred[:, 256:]    
    
    return lambda_const*logloss(y_true, y_pred) + logloss(y_soft, y_pred_soft)
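Written out, with p the student's ordinary probabilities, p_T its temperature-softened probabilities, q_T = softmax(z_teacher / T) the softened teacher targets, and CE the categorical cross-entropy, the loss above is

loss = lambda_const * CE(y_true, p) + CE(q_T, p_T)

so lambda_const controls how much weight the hard one-hot targets get relative to the teacher's soft targets. The metrics below all look only at the first 256 outputs, i.e. the student's ordinary probabilities: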
def accuracy(y_true, y_pred):
    y_true = y_true[:, :256]
    y_pred = y_pred[:, :256]
    return categorical_accuracy(y_true, y_pred)
    
def top_5_accuracy(y_true, y_pred):
    y_true = y_true[:, :256]
    y_pred = y_pred[:, :256]
    return top_k_categorical_accuracy(y_true, y_pred)
    
def categorical_crossentropy(y_true, y_pred):
    y_true = y_true[:, :256]
    y_pred = y_pred[:, :256]
    return logloss(y_true, y_pred)
    
# logloss with only soft probabilities and targets
def soft_logloss(y_true, y_pred):     
    logits = y_true[:, 256:]
    y_soft = K.softmax(logits/temperature)
    y_pred_soft = y_pred[:, 256:]    
    return logloss(y_soft, y_pred_soft)
    
lambda_const = 0.07

model.compile(
    optimizer=optimizers.SGD(lr=1e-1, momentum=0.9, nesterov=True), 
    loss=lambda y_true, y_pred: knowledge_distillation_loss(y_true, y_pred, lambda_const), 
    metrics=[accuracy, top_5_accuracy, categorical_crossentropy, soft_logloss]
)

model.fit_generator(
    train_generator, 
    steps_per_epoch=400, epochs=30, verbose=1,
    callbacks=[
        EarlyStopping(monitor='val_accuracy', patience=4, min_delta=0.01), 
        ReduceLROnPlateau(monitor='val_accuracy', factor=0.1, patience=2, min_delta=0.007)
    ],
    validation_data=val_generator, validation_steps=80, workers=4
)

Epoch 20/30 400/400 [==============================] - 23s - loss: 5.5787 - accuracy: 0.7448 - top_5_accuracy: 0.9116 - categorical_crossentropy: 1.2227 - soft_logloss: 5.4866 - val_loss: 5.5967 - val_accuracy: 0.6607 - val_top_5_accuracy: 0.8658 - val_categorical_crossentropy: 1.4738 - val_soft_logloss: 5.4871
