本文先介紹通常的梯度降低法是如何更新參數的,而後介紹 Adam 如何更新參數,以及 Adam 如何和學習率衰減結合。網絡
梯度降低法參數更新公式:
\[ \theta_{t+1} = \theta_{t} - \eta \cdot \nabla J(\theta_t) \]框架
其中,\(\eta\) 是學習率,\(\theta_t\) 是第 \(t\) 輪的參數,\(J(\theta_t)\) 是損失函數,\(\nabla J(\theta_t)\) 是梯度。函數
在最簡單的梯度降低法中,學習率 \(\eta\) 是常數,是一個須要實現設定好的超參數,在每輪參數更新中都不變,在一輪更新中各個參數的學習率也都同樣。學習
爲了表示簡便,令 \(g_t = \nabla J(\theta_t)\),因此梯度降低法能夠表示爲:
\[ \theta_{t+1} = \theta_{t} - \eta \cdot g_t \]優化
Adam,全稱 Adaptive Moment Estimation,是一種優化器,是梯度降低法的變種,用來更新神經網絡的權重。spa
Adam 更新公式:
\[ \begin{aligned} m_{t} &=\beta_{1} m_{t-1}+\left(1-\beta_{1}\right) g_{t} \\ v_{t} &=\beta_{2} v_{t-1}+\left(1-\beta_{2}\right) g_{t}^{2} \\ \hat{m}_{t} &=\frac{m_{t}}{1-\beta_{1}^{t}} \\ \hat{v}_{t} &=\frac{v_{t}}{1-\beta_{2}^{t}} \\ \theta_{t+1}&=\theta_{t}-\frac{\eta}{\sqrt{\hat{v}_{t}}+\epsilon} \hat{m}_{t} \end{aligned} \]orm
在 Adam 原論文以及一些深度學習框架中,默認值爲 \(\eta = 0.001\),\(\beta_1 = 0.9\),\(\beta_2 = 0.999\),\(\epsilon = 1e-8\)。其中,\(\beta_1\) 和 \(\beta_2\) 都是接近 1 的數,\(\epsilon\) 是爲了防止除以 0。\(g_{t}\) 表示梯度。htm
咋一看很複雜,接下一一分解:blog
這是對梯度和梯度的平方進行滑動平均,即便得每次的更新都和歷史值相關。
中間兩行:
\[ \begin{aligned} \hat{m}_{t} &=\frac{m_{t}}{1-\beta_{1}^{t}} \\ \hat{v}_{t} &=\frac{v_{t}}{1-\beta_{2}^{t}} \end{aligned} \]
這是對初期滑動平均誤差較大的一個修正,叫作 bias correction,當 \(t\) 愈來愈大時,\(1-\beta_{1}^{t}\) 和 \(1-\beta_{2}^{t}\) 都趨近於 1,這時 bias correction 的任務也就完成了。
最後一行:
\[ \theta_{t+1}=\theta_{t}-\frac{\eta}{\sqrt{\hat{v}_{t}}+\epsilon} \hat{m}_{t} \]
這是參數更新公式。
學習率爲 \(\frac{\eta}{\sqrt{\hat{v}_{t}}+\epsilon}\),每輪的學習率再也不保持不變,在一輪中,每一個參數的學習率也不同了,這是由於 \(\eta\) 除以了每一個參數 \(\frac{1}{1- \beta_2} = 1000\) 輪梯度均方和的平方根,即 \(\sqrt{\frac{1}{1000}\sum_{k = t-999}^{t}g_k^2}\)。而每一個參數的梯度都是不一樣的,因此每一個參數的學習率即便在同一輪也就不同了。(可能會有疑問,\(t\) 前面沒有 999 輪更新怎麼辦,那就有多少輪就算多少輪,這個時候還有 bias correction 在。)
而參數更新的方向也不僅是當前輪的梯度 \(g_t\) 了,而是當前輪和過去共 \(\frac{1}{1- \beta_1} = 10\) 輪梯度的平均。
有關滑動平均的理解,能夠參考我以前的博客:理解滑動平均(exponential moving average)。
在 StackOverflow 上有一個問題 Should we do learning rate decay for adam optimizer - Stack Overflow,我也想過這個問題,對 Adam 這些自適應學習率的方法,還應不該該進行 learning rate decay?
論文 《DECOUPLED WEIGHT DECAY REGULARIZATION》的 Section 4.1 有提到:
Since Adam already adapts its parameterwise learning rates it is not as common to use a learning rate multiplier schedule with it as it is with SGD, but as our results show such schedules can substantially improve Adam’s performance, and we advocate not to overlook their use for adaptive gradient algorithms.
上述論文是建議咱們在用 Adam 的同時,也能夠用 learning rate decay。
我也簡單的作了個實驗,在 cifar-10 數據集上訓練 LeNet-5 模型,一個採用學習率衰減 tf.keras.callbacks.ReduceLROnPlateau(patience=5),另外一個不用。optimizer 爲 Adam 並使用默認的參數,\(\eta = 0.001\)。結果以下:
加入學習率衰減和不加兩種狀況在 test 集合上的 accuracy 分別爲: 0.5617 和 0.5476。(實驗結果取了兩次的平均,實驗結果的偶然性仍是有的)
經過上面的小實驗,咱們能夠知道,學習率衰減仍是有用的。(固然,這裏的小實驗僅能表明一小部分狀況,想要說明學習率衰減百分之百有效果,得有理論上的證實。)
固然,在設置超參數時就能夠調低 \(\eta\) 的值,使得不用學習率衰減也能夠達到很好的效果,只不過參數更新變慢了。
將學習率從默認的 0.001 改爲 0.0001,epoch 增大到 120,實驗結果以下所示:
加入學習率衰減和不加兩種狀況在 test 集合上的 accuracy 分別爲: 0.5636 和 0.5688。(三次實驗平均,實驗結果仍具備偶然性)
這個時候,使用學習率衰減帶來的影響可能很小。
那麼問題來了,Adam 作不作學習率衰減呢?
我我的會選擇作學習率衰減。(僅供參考吧。)在初始學習率設置較大的時候,作學習率衰減比不作要好;而當初始學習率設置就比較小的時候,作學習率衰減彷佛有點多餘,但從 validation set 上的效果看,作了學習率衰減仍是能夠有丁點提高的。
ReduceLROnPlateau 在 val_loss 正常降低的時候,對學習率是沒有影響的,只有在 patience(默認爲 10)個 epoch 內,val_loss 都不降低 1e-4 或者直接上升了,這個時候下降學習率確實是能夠很明顯提高模型訓練效果的,在 val_acc 曲線上看到一個快速上升的過程。對於其它類型的學習率衰減,這裏沒有過多地介紹。
從上述學習率曲線來看,Adam 作學習率衰減,是對 \(\eta\) 進行,而不是對 \(\frac{\eta}{\sqrt{\hat{v}_{t}}+\epsilon}\) 進行,但有區別嗎?
學習率衰減通常以下:
exponential_decay:
decayed_learning_rate = learning_rate * decay_rate ^ (global_step / decay_steps)
natural_exp_decay:
decayed_learning_rate = learning_rate * exp(-decay_rate * global_step / decay_steps)
ReduceLROnPlateau
若是被監控的值(如‘val_loss’)在 patience 個 epoch 內都沒有降低,那麼學習率衰減,乘以一個 factor
decayed_learning_rate = learning_rate * factor
這些學習率衰減都是直接在原學習率上乘以一個 factor ,對 \(\eta\) 或對 \(\frac{\eta}{\sqrt{\hat{v}_{t}}+\epsilon}\) 操做,結果都是同樣的。
[1] An overview of gradient descent optimization algorithms -- Sebastian Ruder
[2] Should we do learning rate decay for adam optimizer - Stack Overflow
[3] Tensorflow中learning rate decay的奇技淫巧 -- Elevanth
[4] Loshchilov, I., & Hutter, F. (2017). Decoupled Weight Decay Regularization. ICLR 2019. Retrieved from http://arxiv.org/abs/1711.05101