可視化 Keras 訓練過程

Keras 提供 Callback 接口來追蹤訓練過程當中的每一步結果,包括每個 batch 和每個 epoch。雖然名爲「回調函數」,但實際上想要擴展這功能須要繼承 keras.callbacks.Callback 類,該類提供兩個與模型訓練過程相關的屬性:html

  • params:compile 模型時設定的參數;
  • model:模型對象。

經過這一接口能夠實時可視化 fit 過程當中每個 batch 和每個 epoch 迭代過程當中的偏差大小變化。以《Neural Networks and Deep Learning - Chap3 Improving the way neural networks learn》爲例,假設咱們要訓練一個最簡單的神經網絡:markdown

network

這個只有一個神經元的神經網絡只有一個權重 w 和一個偏置 b 兩個待訓練的參數,假設要訓練的數據只有 (1, 0),在這裏比較 MSE 和 Cross Entropy 兩種代價函數的學習效果。網絡

首先構建這個模型:app

from keras import Sequential, initializers, optimizers
from keras.layers import Activation, Dense

import numpy as np

def viz_keras_fit(w, b, runtime_plot=False, loss="mean_squared_error", act="sigmoid"):
    d = DrawCallback(runtime_plot=runtime_plot)
    
    # 初始化參數
    w = initializers.Constant([w])
    b = initializers.Constant([b])

    x_train, y_train = np.array([1]), np.array([0])
    
    model = Sequential()
    model.add(Dense(1, 
        activation=act,
        input_shape=(1,),
        kernel_initializer=w,
        bias_initializer=b))

    # Learning Rate = 0.15
    sgd = optimizers.SGD(lr=0.15)
    model.compile(optimizer=sgd, loss=loss)

    model.fit(x = x_train,
        y = y_train,
        epochs=150,
        verbose=0,
        callbacks=[d]) # Callback List
    return d
複製代碼

初始參數仍然是 (2, 2) 換成 Cross Entropy 做爲 loss function 以後:函數

雖然實現了實時可視化,但繪圖所用的時間可能比一個 epoch 耗時更久,所以先記錄每一步的 loss 再繪圖會更好一些:oop

實時觀察模型的學習狀況能夠幫助咱們在初期選擇損失函數、激活函數、模型結構以及超參數等。如下是 DrawCallback 的實現:學習

import pylab as pl
from IPython import display
from keras.callbacks import Callback

class DrawCallback(Callback):
    def __init__(self, runtime_plot=True):
        super().__init__()
        self.init_loss = None
        self.runtime_plot = runtime_plot
        
        self.xdata = []
        self.ydata = []
    def _plot(self, epoch=None):
        epochs = self.params.get("epochs")
        pl.ylim(0, int(self.init_loss*2))
        pl.xlim(0, epochs)
    
        pl.plot(self.xdata, self.ydata)
        pl.xlabel('Epoch {}/{}'.format(epoch or epochs, epochs))
        pl.ylabel('Loss {:.4f}'.format(self.ydata[-1]))
        
    def _runtime_plot(self, epoch):
        self._plot(epoch)
        
        display.clear_output(wait=True)
        display.display(pl.gcf())
        pl.gcf().clear()
        
    def plot(self):
        self._plot()
        pl.show()
    
    def on_epoch_end(self, epoch, logs = None):
        logs = logs or {}
        loss = logs.get("loss")
        if self.init_loss is None:
            self.init_loss = loss
        self.xdata.append(epoch)
        self.ydata.append(loss)
        if self.runtime_plot:
            self._runtime_plot(epoch)
複製代碼

Notion 筆記:可視化 Keras 訓練過程spa

相關文章
相關標籤/搜索