如今深度學習中通常咱們學習的參數都是連續的,由於這樣在反向傳播的時候才能夠對梯度進行更新。可是有的時候咱們也會遇到參數是離>散的狀況,這樣就沒有辦法進行反向傳播了,好比二值神經網絡。本文中講解了如何用pytorch
對二值化的參數進行梯度更新的straight-through estimator
算法。html
STE
核心的思想就是咱們的參數初始化的時候就是float
這樣的連續值,當咱們forward
的時候就將原來的連續的參數映射到{-1,, 1}帶入到網絡進行計算,這樣就能夠計算網絡的輸出。而後backward
的時候直接對原來float
的參數進行更新,而不是對二值化的參數更新。這樣能夠完成對整個網絡的更新了。
首先咱們對上面問題進行一下數學的講解。git
torch.sign
函數, 能夠理解爲取符號函數backward
的過程當中對$q$求梯度可得 $\frac{\partial loss}{\partial q}$loss
對r
梯度都是0backward
的過程咱們須要修改$\frac{\partial q}{\partial r}$這部分纔可使梯度繼續更新下去,因此對$\frac{\partial loss}{\partial r}$進行以下修改: $\frac{\partial q}{\partial r} = \frac{\partial loss}{\partial q} * 1\_{|r| \leq 1}$, 其中$1\_{|r| \leq 1}$ 能夠看做$Htanh(x) = Clip(x, -1, 1) = max(-1, min(1, x))$對$x$的求導過程, 也就是是說:
$$\frac{\partial loss}{\partial r} = \frac{\partial loss}{\partial q} \frac{\partial Htanh}{\partial r}$$github
首先咱們驗證一下使用torch.sign
會是參數的梯度基本上都是0:算法
>>> input = torch.randn(4, requires_grad = True) >>> output = torch.sign(input) >>> loss = output.mean() >>> loss.backward() >>> input tensor([-0.8673, -0.0299, -1.1434, -0.6172], requires_grad=True) >>> input.grad tensor([0., 0., 0., 0.])
咱們須要重寫sign
這個函數,就好像寫一個激活函數同樣。先看一下代碼, github源碼:LBSign.py
網絡
import torch class LBSign(torch.autograd.Function): @staticmethod def forward(ctx, input): return torch.sign(input) @staticmethod def backward(ctx, grad_output): return grad_output.clamp_(-1, 1)
接下來咱們作一下測試main.py
app
import torch from LBSign import LBSign if __name__ == '__main__': sign = LBSign.apply params = torch.randn(4, requires_grad = True) output = sign(params) loss = output.mean() loss.backward()
而後咱們發現有梯度了函數
>>> params tensor([-0.9143, 0.8993, -1.1235, -0.7928], requires_grad=True) >>> params.grad tensor([0.2500, 0.2500, 0.2500, 0.2500])
接下來咱們對代碼就行一下解釋pytorch文檔連接:學習
ctx
是保存的上下文信息,input
是輸入ctx
是保存的上下文信息,grad_output
能夠理解成 $\frac{\partial loss}{\partial q}$這一步的梯度信息,咱們須要作的就是讓
$$grad\_output * \frac{\partial Htanh}{\partial r}$$ 而不是讓pytorch
繼續默認的 $$grad\_output * \frac{\partial q}{\partial r}$$
可是咱們能夠從上面的公式能夠看出函數$Htanh$對$x$求導是1, 當$x \in [-1, 1]$,因此程序就能夠化簡成保留原來的梯度就好了,而後裁剪到其餘範圍的。測試
torch.autograd.Function
Binarized Neural Networks: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1
二值網絡,圍繞STE的那些事兒
Custom binarization layer with straight through estimator gives error
定義torch.autograd.Function的子類,本身定義某些操做,且定義反向求導函數ui