Keras 提供 Callback 接口來追蹤訓練過程當中的每一步結果,包括每個 batch 和每個 epoch。雖然名爲「回調函數」,但實際上想要擴展這功能須要繼承 keras.callbacks.Callback
類,該類提供兩個與模型訓練過程相關的屬性:html
params
:compile 模型時設定的參數;model
:模型對象。經過這一接口能夠實時可視化 fit
過程當中每個 batch 和每個 epoch 迭代過程當中的偏差大小變化。以《Neural Networks and Deep Learning - Chap3 Improving the way neural networks learn》爲例,假設咱們要訓練一個最簡單的神經網絡:markdown
這個只有一個神經元的神經網絡只有一個權重 w
和一個偏置 b
兩個待訓練的參數,假設要訓練的數據只有 (1, 0)
,在這裏比較 MSE 和 Cross Entropy 兩種代價函數的學習效果。網絡
首先構建這個模型:app
from keras import Sequential, initializers, optimizers from keras.layers import Activation, Dense import numpy as np def viz_keras_fit(w, b, runtime_plot=False, loss="mean_squared_error", act="sigmoid"): d = DrawCallback(runtime_plot=runtime_plot) # 初始化參數 w = initializers.Constant([w]) b = initializers.Constant([b]) x_train, y_train = np.array([1]), np.array([0]) model = Sequential() model.add(Dense(1, activation=act, input_shape=(1,), kernel_initializer=w, bias_initializer=b)) # Learning Rate = 0.15 sgd = optimizers.SGD(lr=0.15) model.compile(optimizer=sgd, loss=loss) model.fit(x = x_train, y = y_train, epochs=150, verbose=0, callbacks=[d]) # Callback List return d 複製代碼
初始參數仍然是 (2, 2)
換成 Cross Entropy 做爲 loss function 以後:函數
雖然實現了實時可視化,但繪圖所用的時間可能比一個 epoch 耗時更久,所以先記錄每一步的 loss 再繪圖會更好一些:oop
實時觀察模型的學習狀況能夠幫助咱們在初期選擇損失函數、激活函數、模型結構以及超參數等。如下是 DrawCallback
的實現:學習
import pylab as pl from IPython import display from keras.callbacks import Callback class DrawCallback(Callback): def __init__(self, runtime_plot=True): super().__init__() self.init_loss = None self.runtime_plot = runtime_plot self.xdata = [] self.ydata = [] def _plot(self, epoch=None): epochs = self.params.get("epochs") pl.ylim(0, int(self.init_loss*2)) pl.xlim(0, epochs) pl.plot(self.xdata, self.ydata) pl.xlabel('Epoch {}/{}'.format(epoch or epochs, epochs)) pl.ylabel('Loss {:.4f}'.format(self.ydata[-1])) def _runtime_plot(self, epoch): self._plot(epoch) display.clear_output(wait=True) display.display(pl.gcf()) pl.gcf().clear() def plot(self): self._plot() pl.show() def on_epoch_end(self, epoch, logs = None): logs = logs or {} loss = logs.get("loss") if self.init_loss is None: self.init_loss = loss self.xdata.append(epoch) self.ydata.append(loss) if self.runtime_plot: self._runtime_plot(epoch) 複製代碼