pytorch實現簡單的straight-through estimator(STE)

如今深度學習中通常咱們學習的參數都是連續的,由於這樣在反向傳播的時候才能夠對梯度進行更新。可是有的時候咱們也會遇到參數是離>散的狀況,這樣就沒有辦法進行反向傳播了,好比二值神經網絡。本文中講解了如何用pytorch對二值化的參數進行梯度更新的straight-through estimator算法。html

Question

STE核心的思想就是咱們的參數初始化的時候就是float這樣的連續值,當咱們forward的時候就將原來的連續的參數映射到{-1,, 1}帶入到網絡進行計算,這樣就能夠計算網絡的輸出。而後backward的時候直接對原來float的參數進行更新,而不是對二值化的參數更新。這樣能夠完成對整個網絡的更新了。
首先咱們對上面問題進行一下數學的講解。git

  • 咱們但願參數的範圍是$r \in \mathbb{R}$
  • 咱們能夠獲得二值化的參數 $q = Sign(r)$, $Sign$函數能夠參考torch.sign函數, 能夠理解爲取符號函數
  • backward的過程當中對$q$求梯度可得 $\frac{\partial loss}{\partial q}$
  • 對於$\frac{\partial q}{\partial r} = 0$, 因此能夠得出 $\frac{\partial loss}{\partial r} = 0$, 這樣的話咱們就沒法完成對參>數的更新,由於每次lossr梯度都是0
  • 因此backward的過程咱們須要修改$\frac{\partial q}{\partial r}$這部分纔可使梯度繼續更新下去,因此對$\frac{\partial loss}{\partial r}$進行以下修改: $\frac{\partial q}{\partial r} = \frac{\partial loss}{\partial q} * 1\_{|r| \leq 1}$, 其中

$1\_{|r| \leq 1}$ 能夠看做$Htanh(x) = Clip(x, -1, 1) = max(-1, min(1, x))$對$x$的求導過程, 也就是是說:
$$\frac{\partial loss}{\partial r} = \frac{\partial loss}{\partial q} \frac{\partial Htanh}{\partial r}$$github

Example

torch.sign

首先咱們驗證一下使用torch.sign會是參數的梯度基本上都是0:算法

>>> input = torch.randn(4, requires_grad = True)
>>> output = torch.sign(input)
>>> loss = output.mean()
>>> loss.backward()
>>> input
tensor([-0.8673, -0.0299, -1.1434, -0.6172], requires_grad=True)
>>> input.grad
tensor([0., 0., 0., 0.])

demo

咱們須要重寫sign這個函數,就好像寫一個激活函數同樣。先看一下代碼, github源碼:
LBSign.py網絡

import torch

class LBSign(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input):
        return torch.sign(input)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.clamp_(-1, 1)

接下來咱們作一下測試
main.pyapp

import torch
from LBSign import LBSign

if __name__ == '__main__':

    sign = LBSign.apply
    params = torch.randn(4, requires_grad = True)                                                                           
    output = sign(params)
    loss = output.mean()
    loss.backward()

而後咱們發現有梯度了函數

>>> params
tensor([-0.9143,  0.8993, -1.1235, -0.7928], requires_grad=True)
>>> params.grad
tensor([0.2500, 0.2500, 0.2500, 0.2500])

explain

接下來咱們對代碼就行一下解釋pytorch文檔連接:學習

  • forward中的參數ctx是保存的上下文信息,input是輸入
  • backward中的參數ctx是保存的上下文信息,grad_output能夠理解成 $\frac{\partial loss}{\partial q}$這一步的梯度信息,

咱們須要作的就是讓
$$grad\_output * \frac{\partial Htanh}{\partial r}$$ 而不是讓pytorch繼續默認的 $$grad\_output * \frac{\partial q}{\partial r}$$
可是咱們能夠從上面的公式能夠看出函數$Htanh$對$x$求導是1, 當$x \in [-1, 1]$,因此程序就能夠化簡成保留原來的梯度就好了,而後裁剪到其餘範圍的。測試

reference

torch.autograd.Function
Binarized Neural Networks: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1
二值網絡,圍繞STE的那些事兒
Custom binarization layer with straight through estimator gives error
定義torch.autograd.Function的子類,本身定義某些操做,且定義反向求導函數ui

相關文章
相關標籤/搜索