[tf.keras] AdamW: Adam with Weight Decay

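The paper Decoupled Weight Decay Regularization points out that, when Adam is used, L2 regularization and weight decay are not equivalent, and it proposes AdamW. According to the paper, when a neural network needs a regularization term, replacing Adam + L2 with AdamW gives better performance.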
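To make the non-equivalence concrete, here is a minimal NumPy sketch of a single update step under both schemes. It is an illustration rather than the code of any library: the helper names (`adam_direction`, `adam_l2_step`, `adamw_step`) and the hyperparameter values are assumptions, and the decoupled decay is applied as `wd * w` without multiplying by the learning rate, which is one of several conventions in use.

```python
import numpy as np

LR, BETA1, BETA2, EPS = 1e-3, 0.9, 0.999, 1e-8  # illustrative values
COEFF = 1e-4  # same coefficient used for the L2 penalty and for weight decay


def adam_direction(grad, m, v, t):
    """Return Adam's bias-corrected update direction and the new moments."""
    m = BETA1 * m + (1 - BETA1) * grad
    v = BETA2 * v + (1 - BETA2) * grad ** 2
    m_hat = m / (1 - BETA1 ** t)
    v_hat = v / (1 - BETA2 ** t)
    return m_hat / (np.sqrt(v_hat) + EPS), m, v


def adam_l2_step(w, grad, m, v, t, l2):
    # L2 regularization: the penalty gradient l2 * w is added to the loss
    # gradient, so it passes through Adam's adaptive normalization.
    direction, m, v = adam_direction(grad + l2 * w, m, v, t)
    return w - LR * direction, m, v


def adamw_step(w, grad, m, v, t, wd):
    # Decoupled weight decay: the moments only see the loss gradient, and
    # the decay wd * w is subtracted separately (assumed convention).
    direction, m, v = adam_direction(grad, m, v, t)
    return w - LR * direction - wd * w, m, v


# Two equal weights whose loss gradients differ by three orders of magnitude.
w = np.array([1.0, 1.0])
grad = np.array([0.01, 10.0])
zeros = np.zeros_like(w)

w_plain = w - LR * adam_direction(grad, zeros, zeros, t=1)[0]
w_l2, _, _ = adam_l2_step(w, grad, zeros, zeros, t=1, l2=COEFF)
w_adamw, _, _ = adamw_step(w, grad, zeros, zeros, t=1, wd=COEFF)

print("extra shrinkage from Adam + L2:", w_plain - w_l2)
print("extra shrinkage from AdamW:    ", w_plain - w_adamw)
```

At the first step Adam's normalization makes the update size close to the learning rate whatever the gradient's magnitude, so the L2 term barely changes the result, while the decoupled decay shrinks both weights by the same relative amount.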

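For TensorFlow 2.0, AdamW is implemented in the tensorflow_addons library. At the time of writing, it can be installed on Mac and Linux with pip install tensorflow_addons. Windows is not yet supported, but you can download the repository and use it directly.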

Below is an example program that uses AdamW (TF 2.0, tf.keras). In this program AdamW performs worse than Adam, because the model is fairly simple and adding regularization hurts its performance.

import tensorflow as tf
import os
from tensorflow_addons.optimizers import AdamW

import numpy as np

from tensorflow.keras import backend as K
from tensorflow.python.util.tf_export import keras_export
from tensorflow.keras.callbacks import Callback


def lr_schedule(epoch):
    """Learning Rate Schedule
    Learning rate is scheduled to be reduced after 20, 30 epochs.
    Called automatically every epoch as part of callbacks during training.
    # Arguments
        epoch (int): The number of epochs
    # Returns
        lr (float32): learning rate
    """
    lr = 1e-3

    if epoch >= 30:
        lr *= 1e-2
    elif epoch >= 20:
        lr *= 1e-1
    print('Learning rate: ', lr)
    return lr


def wd_schedule(epoch):
    """Weight Decay Schedule
    Weight decay is scheduled to be reduced after 20, 30 epochs.
    Called automatically every epoch as part of callbacks during training.
    # Arguments
        epoch (int): The number of epochs
    # Returns
        wd (float32): weight decay
    """
    wd = 1e-4

    if epoch >= 30:
        wd *= 1e-2
    elif epoch >= 20:
        wd *= 1e-1
    print('Weight decay: ', wd)
    return wd


# Adapted from the implementation of LearningRateScheduler, with the learning rate replaced by weight_decay.
class WeightDecayScheduler(Callback):
    """Weight Decay Scheduler.

    Arguments:
        schedule: a function that takes an epoch index as input
            (integer, indexed from 0) and returns a new
            weight decay as output (float).
        verbose: int. 0: quiet, 1: update messages.

    ```python
    # This function keeps the weight decay at 0.001 for the first ten epochs
    # and decreases it exponentially after that.
    def scheduler(epoch):
      if epoch < 10:
        return 0.001
      else:
        return 0.001 * np.exp(0.1 * (10 - epoch))

    callback = WeightDecayScheduler(scheduler)
    model.fit(data, labels, epochs=100, callbacks=[callback],
              validation_data=(val_data, val_labels))
    ```
    """

    def __init__(self, schedule, verbose=0):
        super(WeightDecayScheduler, self).__init__()
        self.schedule = schedule
        self.verbose = verbose

    def on_epoch_begin(self, epoch, logs=None):
        if not hasattr(self.model.optimizer, 'weight_decay'):
            raise ValueError('Optimizer must have a "weight_decay" attribute.')
        try:  # new API
            weight_decay = float(K.get_value(self.model.optimizer.weight_decay))
            weight_decay = self.schedule(epoch, weight_decay)
        except TypeError:  # Support for old API for backward compatibility
            weight_decay = self.schedule(epoch)
        if not isinstance(weight_decay, (float, np.float32, np.float64)):
            raise ValueError('The output of the "schedule" function '
                             'should be float.')
        K.set_value(self.model.optimizer.weight_decay, weight_decay)
        if self.verbose > 0:
            print('\nEpoch %05d: WeightDecayScheduler reducing weight '
                  'decay to %s.' % (epoch + 1, weight_decay))

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        logs['weight_decay'] = K.get_value(self.model.optimizer.weight_decay)


if __name__ == '__main__':
    # Expose only the GPU with index 1; change or remove this line to match your machine.
    os.environ["CUDA_VISIBLE_DEVICES"] = '1'

    gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, enable=True)
    print(gpus)
    cifar10 = tf.keras.datasets.cifar10

    (x_train, y_train), (x_test, y_test) = cifar10.load_data()
    x_train, x_test = x_train / 255.0, x_test / 255.0

    model = tf.keras.models.Sequential([
        tf.keras.layers.Conv2D(16, (3, 3), padding='same', activation='relu', input_shape=(32, 32, 3)),
        tf.keras.layers.AveragePooling2D(),
        tf.keras.layers.Conv2D(32, (3, 3), padding='same', activation='relu'),
        tf.keras.layers.AveragePooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(10, activation='softmax')
    ])

    optimizer = AdamW(learning_rate=lr_schedule(0), weight_decay=wd_schedule(0))
    # optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

    tb_callback = tf.keras.callbacks.TensorBoard(os.path.join('logs', 'adamw'),
                                                 profile_batch=0)
    lr_callback = tf.keras.callbacks.LearningRateScheduler(lr_schedule)
    wd_callback = WeightDecayScheduler(wd_schedule)

    model.compile(optimizer=optimizer,
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    model.fit(x_train, y_train, epochs=40, validation_split=0.1,
              callbacks=[tb_callback, lr_callback, wd_callback])

    model.evaluate(x_test, y_test, verbose=2)

The code above uses AdamW together with learning rate decay, although the decay can only be applied at the epoch level.
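If finer-grained decay is wanted, one option is to express the boundaries in optimizer steps instead of epochs. The sketch below, in plain Python, converts the epoch boundaries from the example (20 and 30) into step boundaries. It assumes 45,000 training samples (the 50,000 CIFAR-10 training images minus the 10% validation split) and a batch size of 32, the Keras default when `model.fit` is called without `batch_size`. The helper `multiplier_at_step` is hypothetical, and wiring it into the optimizer is not shown.

```python
import math

TRAIN_SAMPLES = 45_000  # assumption: 50,000 images minus validation_split=0.1
BATCH_SIZE = 32         # assumption: Keras default batch size for model.fit
STEPS_PER_EPOCH = math.ceil(TRAIN_SAMPLES / BATCH_SIZE)

# Epoch boundaries from the example, converted to optimizer steps.
STEP_BOUNDARIES = [20 * STEPS_PER_EPOCH, 30 * STEPS_PER_EPOCH]
# One multiplier before the first boundary, then one after each boundary.
MULTIPLIERS = [1.0, 1e-1, 1e-2]


def multiplier_at_step(step):
    """Hypothetical helper: piecewise-constant multiplier for a global step."""
    for boundary, multiplier in zip(STEP_BOUNDARIES, MULTIPLIERS):
        if step < boundary:
            return multiplier
    return MULTIPLIERS[-1]


for step in (0, STEP_BOUNDARIES[0], STEP_BOUNDARIES[1]):
    print(step, multiplier_at_step(step))
```

The same multiplier would scale both the base learning rate and the base weight decay, and the boundaries no longer have to fall on epoch edges.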

When using AdamW with learning rate decay, the weight_decay value must be decayed on the same schedule as the learning rate; otherwise training will collapse.
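One way to keep the two schedules in lockstep is to derive both from a single multiplier, so the ratio of weight decay to learning rate never changes. The sketch below restates the two schedule functions from the example in that form. The names `decay_multiplier`, `BASE_LR` and `BASE_WD` are introduced here for illustration. The motivation rests on an assumption about the optimizer: that the decay is applied as `wd * w` without being multiplied by the learning rate, so a fixed weight decay would dominate the update once the learning rate shrinks.

```python
BASE_LR = 1e-3  # initial learning rate used in the example
BASE_WD = 1e-4  # initial weight decay used in the example


def decay_multiplier(epoch):
    """Shared piecewise-constant multiplier (boundaries at epochs 20 and 30)."""
    if epoch >= 30:
        return 1e-2
    if epoch >= 20:
        return 1e-1
    return 1.0


def lr_schedule(epoch):
    return BASE_LR * decay_multiplier(epoch)


def wd_schedule(epoch):
    return BASE_WD * decay_multiplier(epoch)


for epoch in (0, 20, 30):
    print(epoch, lr_schedule(epoch), wd_schedule(epoch))
```

Because both functions keep the one-argument signature, they can be passed to `LearningRateScheduler` and `WeightDecayScheduler` in the same way as the originals.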

References

How to use AdamW correctly? -- wuliytTaotao
Loshchilov, I., & Hutter, F. Decoupled Weight Decay Regularization. ICLR 2019. Retrieved from http://arxiv.org/abs/1711.05101
