(轉)自動微分(Automatic Differentiation)簡介——tensorflow核心原理

現代深度學習系統中(好比MXNet, TensorFlow等)都用到了一種技術——自動微分。在此以前,機器學習社區中不多發揮這個利器,通常都是用Backpropagation進行梯度求解,而後進行SGD等進行優化更新。手動實現過backprop算法的同窗應該能夠體會到其中的複雜性和易錯性,一個好的框架應該能夠很好地將這部分難點隱藏於用戶視角,而自動微分技術剛好能夠優雅解決這個問題。接下來咱們將一塊兒學習這個優雅的技術:-)。本文主要來源於陳天奇在華盛頓任教的課程CSE599G1: Deep Learning System和《Automatic differentiation in machine learning: a survey》。css

什麼是自動微分

微分求解大體能夠分爲4種方式:node

  • 手動求解法(Manual Differentiation)
  • 數值微分法(Numerical Differentiation)
  • 符號微分法(Symbolic Differentiation)
  • 自動微分法(Automatic Differentiation)

爲了講明白什麼是自動微分,咱們有必要了解其餘方法,作到有區分有對比,從而更加深刻理解自動微分技術。git

手動求解法

手動求解其實就對應咱們傳統的backprop算法,咱們求解出梯度公式,而後編寫代碼,代入實際數值,得出真實的梯度。在這樣的方式下,每一次咱們修改算法模型,都要修改對應的梯度求解算法,所以沒有很好的辦法解脫用戶手動編寫梯度求解的代碼,這也是爲何咱們須要自動微分技術的緣由。github

數值微分法

數值微分法是根據導數的原始定義: 算法

f′(x)=limh→0f(x+h)−f(x)hf′(x)=limh→0f(x+h)−f(x)hexpress


那麼只要hh取很小的數值,好比0.0001,那麼咱們能夠很方便求解導數,而且能夠對用戶隱藏求解過程,用戶只要給出目標函數和要求解的梯度的變量,程序能夠自動給出相應的梯度,這也是某種意義上的「自動微分」:-)。不幸的是,數值微分法計算量太大,求解速度是這四種方法中最慢的,更加雪上加霜的是,它引發的roundoff errortruncation error使其更加不具有實際應用場景,爲了彌補缺點,便有以下center difference approximation: 編程

f′(x)=limh→0f(x+h)−f(x−h)2hf′(x)=limh→0f(x+h)−f(x−h)2h數組


惋惜並不能徹底消除truncation error,只是將偏差減少。雖然數值微分法有如上缺點,可是因爲它實在是太簡單實現了,因而不少時候,咱們利用它來檢驗其餘算法的正確性,好比在實現backprop的時候,咱們用的」gradient check」就是利用數值微分法。緩存

 

符號微分法

符號微分是代替咱們第一種手動求解法的過程,利用代數軟件,實現微分的一些公式好比: 網絡

ddx(f(x)+g(x))=ddxf(x)+ddxg(x)ddxf(x)g(x)=(ddxf(x))g(x)+f(x)(ddxg(x))ddxf(x)g(x)=f′(x)g(x)−f(x)g′(x)g(x)2ddx(f(x)+g(x))=ddxf(x)+ddxg(x)ddxf(x)g(x)=(ddxf(x))g(x)+f(x)(ddxg(x))ddxf(x)g(x)=f′(x)g(x)−f(x)g′(x)g(x)2


而後對用戶提供的具備closed form的數學表達式進行「自動微分」求解,什麼是具備closed form的呢?也就是必須能寫成完整數學表達式的,不能有編程語言中的循環結構,條件結構等。所以若是能將問題轉化爲一個純數學符號問題,咱們能利用現有的代數軟件進行符號微分求解,這種程度意義上的「自動微分」其實已經很完美了。然而缺點咱們剛剛也說起過了,就是必需要有closed form的數學表達式,另外一個有名的缺點是「表達式膨脹」(expression swell)問題,若是不加當心就會使得問題符號微分求解的表達式急速「膨脹」,致使最終求解速度變慢,對於這個問題請看以下圖: 
這裏寫圖片描述
稍不注意,符號微分求解就會如上中間列所示,表達式急劇膨脹,致使問題求解也隨着變慢。

 

自動微分法

終於輪到咱們的主角登場,自動微分的存在依賴於它識破以下事實:

全部數值計算歸根結底是一系列有限的可微算子的組合

自動微分法是一種介於符號微分和數值微分的方法:數值微分強調一開始直接代入數值近似求解;符號微分強調直接對代數進行求解,最後才代入問題數值;自動微分將符號微分法應用於最基本的算子,好比常數,冪函數,指數函數,對數函數,三角函數等,而後代入數值,保留中間結果,最後再應用於整個函數。所以它應用至關靈活,能夠作到徹底向用戶隱藏微分求解過程,因爲它只對基本函數或常數運用符號微分法則,因此它能夠靈活結合編程語言的循環結構,條件結構等,使用自動微分和不使用自動微分對代碼整體改動很是小,而且因爲它的計算實際是一種圖計算,能夠對其作不少優化,這也是爲何該方法在現代深度學習系統中得以普遍應用。

自動微分Forward Mode

考察以下函數: 

f(x1,x2)=ln(x1)+x1x2−sin(x2)f(x1,x2)=ln(x1)+x1x2−sin(x2)


咱們能夠將其轉化爲以下計算圖: 
這裏寫圖片描述
轉化成如上DAG(有向無環圖)結構以後,咱們能夠很容易分步計算函數的值,並求取它每一步的導數值: 
這裏寫圖片描述
上表中左半部分是從左往右每一個圖節點的求值結果,右半部分是每一個節點對於x1x1的求導結果,好比v1˙=dvdx1v1˙=dvdx1,注意到每一步的求導都利用到上一步的求導結果,這樣不至於重複計算,所以也不會產生像符號微分法的」expression swell」問題。 
自動微分的forward mode很是符合咱們高數裏面學習的求導過程,只要您對求導法則還有印象,理解forward mode自不在話下。若是函數輸入輸出爲: 

R→RmR→Rm


那麼利用forward mode只需計算一次如上表右邊過程便可,很是高效。對於輸入輸出映射爲以下的: 

Rn→RmRn→Rm


這樣一個有nn個輸入的函數,求解函數梯度須要nn遍如上計算過程。然而實際算法模型中,好比神經網絡,一般輸入輸出是極其不成比例的,也就是: 

n>>mn>>m


那麼利用forward mode進行自動微分就過低效了,所以便有下面要介紹的reverse mode。

 

自動微分Reverse Mode

若是您理解神經網絡的backprop算法,那麼恭喜你,自動微分的backward mode其實就是一種通用的backprop算法,也就是backprop是reverse mode自動微分的一種特殊形式。從名字能夠看出,reverse mode和forward mode是一對相反過程,reverse mode從最終結果開始求導,利用最終輸出對每個節點進行求導,其過程以下紅色箭頭所示: 
這裏寫圖片描述 
其具體計算過程以下表所示: 
這裏寫圖片描述
上表左邊和以前的forward mode一致,用於求解函數值,右邊則是reverse mode的計算過程,注意必須從下網上看,也就是一開始先計算輸出yy對於節點v5v5的導數,用v¯¯¯5v¯5表示dydv5dydv5,這樣的記號能夠強調咱們對當前計算結果進行緩存,以便用於後續計算,而沒必要重複計算。由鏈式法則咱們能夠計算輸出對於每一個節點的導數。 
好比對於節點v3v3: 

dydv3=dydv5dv5dv3dydv3=dydv5dv5dv3


用另外一種記法變獲得: 

dydv3=v5¯¯¯¯¯dv5dv3dydv3=v5¯dv5dv3


好比對於節點v0v0: 

dydv0=dydv2dv2dv0+dydv3dv3dv0dydv0=dydv2dv2dv0+dydv3dv3dv0


若是用另外一種記法,即可得出: 

dydv0=v¯¯¯2dv2dv0+v¯¯¯3dv3dv0dydv0=v¯2dv2dv0+v¯3dv3dv0


和backprop算法同樣,咱們必須記住前向時當前節點發出的邊,而後在後向傳播的時候,能夠蒐集全部受到當前節點影響節點。 
如上的計算過程,對於像神經網絡這種模型,一般輸入是上萬到上百萬維,而輸出損失函數是1維的模型,只須要一遍reverse mode的計算過程,即可以求出輸出對於各個輸入的導數,從而輕鬆求取梯度用於後續優化更新。

 

自動微分的實現

這裏主要講解reverse mode的實現方式,forward mode的實現基本和reverse mode一致,可是因爲機器學習算法中大部分用reverse mode才能夠高效求解,因此它是咱們理解的重心。代碼設計輪廓來源於CSE599G1的做業,經過分析完成做業,能夠展現自動微分的簡潔性和靈活可用性。 
首先自動微分會將問題轉化成一種有向無環圖,所以咱們必須構造基本的圖部件,包括節點和邊。能夠先看看節點是如何實現的: 
這裏寫圖片描述
首先節點能夠分爲三種:

  • 常數節點
  • 變量節點
  • 帶操做算子節點

所以Node類中定義了op成員用於存儲節點的操做算子,const_attr表明節點的常數值,name是節點的標識,主要用於調試。 
對於邊的實現則簡單的多,每一個節點只要知道自己的輸入節點便可,所以用inputs來描述節點的關係。 
有了如上的定義,利用操做符重載,咱們能夠很簡單構造一個計算圖,舉一個簡單的例子: 

f(x1,x2)=x1x2+x2f(x1,x2)=x1x2+x2


對於如上函數,只要重載加法和乘法操做符,咱們能夠輕鬆獲得以下計算圖: 
這裏寫圖片描述

 

操做算子是自動微分最重要的組成部分,接下來咱們重點介紹,先上代碼: 
這裏寫圖片描述
從定義能夠看出,全部實際計算都落在各個操做算子中,上面代碼應該抽象一些,咱們來舉一個乘法算子的例子加以說明: 
這裏寫圖片描述
咱們重點講解一下gradient方法,它接收兩個參數,一個是node,也就是當前要計算的節點,而output_grad則是後面節點傳來的,咱們來看看它究竟是啥玩意,對於以下例子: 

y=f(x1∗x2)y=f(x1∗x2)


那麼要求yy關於x1x1的導數,那麼根據鏈式法則可得: 

∂y∂x1=∂y∂f∂f∂x1=∂y∂x1x2∂x1x2∂x1=output_grad∗x2∂y∂x1=∂y∂f∂f∂x1=∂y∂x1x2∂x1x2∂x1=output_grad∗x2


則output_grad就是上面的∂y∂f∂y∂f,計算yy對於x2x2相似。所以在程序中咱們會返回以下:

 

return [node.inputs[1] * output_grad, node.inputs[0] * output_grad]
  • 1

再來介紹一個特殊的op——PlaceHolderOp,它的做用就如同名字,起到佔位符的做用,也就是自動微分中的變量,它不會參與實際計算,只等待用戶給他提供實際值,所以他的實現以下: 
這裏寫圖片描述

瞭解了節點和操做算子的定義,接下來咱們考慮如何協調執行運算。首先是如何計算函數值,對於一幅計算圖,因爲節點與節點之間的計算有必定的依賴關係,好比必須先計算node1以後才能夠計算node2,那麼如何能正確處理好計算關係呢?一個簡單的方式是對圖節點進行拓撲排序,這樣能夠保證須要先計算的節點先獲得計算。這部分代碼由Executor掌控: 
這裏寫圖片描述

Executor是實際計算圖的引擎,用戶提供須要計算的圖和實際輸入,Executor計算相應的值和梯度。

如何從計算圖中計算函數的值,上面咱們已經介紹了,接下來是如何自動計算梯度。reverse mode的自動微分,要求從輸出到輸入節點,按照前後依賴關係,對各個節點求取輸出對於當前節點的梯度,那麼和咱們上面介紹的恰好相反,爲了獲得正確計算節點順序,咱們能夠將圖節點的拓撲排序倒序便可。代碼也很簡單,以下所示: 
這裏寫圖片描述
這裏先介紹一個新的算子——oneslike_op。他是一個和numpy自帶的oneslike函數同樣的算子,做用是構造reverse梯度圖的起點,由於最終輸出關於自己的梯度就是一個和輸出shape同樣的全1數組,引入oneslike_op可使得真實計算得以延後,所以gradients方法最終返回的不是真實的梯度,而是梯度計算圖,而後能夠複用Executor,計算實際的梯度值。 
緊接着是根據輸出節點,得到倒序的拓撲排序序列,而後遍歷序列,構造實際的梯度計算圖。咱們重點來介紹node_to_output_grad和node_to_output_grads_list這兩個字典的意義。 
先關注node_to_output_grads_list,他key是節點,value是一個梯度列表,表明什麼含義呢?先看以下部分計算圖: 
這裏寫圖片描述 
此時咱們要計算輸出yy關於節點n1n1的導數,那麼咱們觀察到他的發射邊鏈接的節點有n3,n4n3,n4,而對應n3,n4n3,n4節點調用相應op的gradient方法,會返回輸出yy關於各個輸入節點的導數。此時爲了準確計算輸出yy關於節點n1n1的導數,咱們須要將其發射邊關聯節點的計算梯度搜集起來,好比上面的例子,咱們須要蒐集: 

node_to_output_grads_list={n1:[∂y∂n3∂n3∂n1,∂y∂n4∂n4∂n1]}node_to_output_grads_list={n1:[∂y∂n3∂n3∂n1,∂y∂n4∂n4∂n1]}


一旦蒐集好對應輸出邊節點關於當前節點導數,那麼當前節點的導數即可以由鏈式法則計算得出,也就是: 

∂y∂n1=∂y∂n3∂n3∂n1+∂y∂n4∂n4∂n1∂y∂n1=∂y∂n3∂n3∂n1+∂y∂n4∂n4∂n1


所以node_to_output_grad字典存儲的就是節點對應的輸出關於節點的導數。通過gradients函數執行後,會返回須要求取輸出關於某節點的梯度計算圖: 
這裏寫圖片描述 
而對於Executor而言,它並不知道此時的圖是否被反轉,它只關注用戶實際輸入,還有計算相應的值而已。

 

自動梯度的應用

有了上面的大篇幅介紹,咱們其實已經實現了一個簡單的自動微分引擎了,接下來看如何使用: 
這裏寫圖片描述
使用至關簡單,咱們像編寫普通程序同樣,對變量進行各類操做,只要提供要求導數的變量,還有提供實際輸入,引擎能夠正確給出相應的梯度值。 
下面給出一個根據自動微分訓練Logistic Regression的例子:

  1.  
    import autodiff as ad
  2.  
    import numpy as np
  3.  
     
  4.  
     
  5.  
    def logistic_prob(_w):
  6.  
    def wrapper(_x):
  7.  
    return 1 / (1 + np.exp(-np.sum(_x * _w)))
  8.  
    return wrapper
  9.  
     
  10.  
     
  11.  
    def test_accuracy(_w, _X, _Y):
  12.  
    prob = logistic_prob(_w)
  13.  
    correct = 0
  14.  
    total = len(_Y)
  15.  
    for i in range(len(_Y)):
  16.  
    x = _X[i]
  17.  
    y = _Y[i]
  18.  
    p = prob(x)
  19.  
    if p >= 0.5 and y == 1.0:
  20.  
    correct += 1
  21.  
    elif p < 0.5 and y == 0.0:
  22.  
    correct += 1
  23.  
    print( "總數:%d, 預測正確:%d" % (total, correct))
  24.  
     
  25.  
     
  26.  
    def plot(N, X_val, Y_val, w, with_boundary=False):
  27.  
    import matplotlib.pyplot as plt
  28.  
    for i in range(N):
  29.  
    __x = X_val[i]
  30.  
    if Y_val[i] == 1:
  31.  
    plt.plot(__x[ 1], __x[2], marker='x')
  32.  
    else:
  33.  
    plt.plot(__x[ 1], __x[2], marker='o')
  34.  
    if with_boundary:
  35.  
    min_x1 = min(X_val[:, 1])
  36.  
    max_x1 = max(X_val[:, 1])
  37.  
    min_x2 = float(-w[ 0] - w[1] * min_x1) / w[2]
  38.  
    max_x2 = float(-w[ 0] - w[1] * max_x1) / w[2]
  39.  
    plt.plot([min_x1, max_x1], [min_x2, max_x2], '-r')
  40.  
     
  41.  
    plt.show()
  42.  
     
  43.  
     
  44.  
    def gen_2d_data(n):
  45.  
    x_data = np.random.random([n, 2])
  46.  
    y_data = np.ones(n)
  47.  
    for i in range(n):
  48.  
    d = x_data[i]
  49.  
    if d[0] + d[1] < 1:
  50.  
    y_data[i] = 0
  51.  
    x_data_with_bias = np.ones([n, 3])
  52.  
    x_data_with_bias[:, 1:] = x_data
  53.  
    return x_data_with_bias, y_data
  54.  
     
  55.  
     
  56.  
    def auto_diff_lr():
  57.  
    x = ad.Variable(name= 'x')
  58.  
    w = ad.Variable(name= 'w')
  59.  
    y = ad.Variable(name= 'y')
  60.  
     
  61.  
    # 注意,如下實現某些狀況會有很大的數值偏差,
  62.  
    # 因此通常真實系統實現會提供高階算子,從而減小數值偏差
  63.  
     
  64.  
    h = 1 / (1 + ad.exp(-ad.reduce_sum(w * x)))
  65.  
    L = y * ad.log(h) + ( 1 - y) * ad.log(1 - h)
  66.  
    w_grad, = ad.gradients(L, [w])
  67.  
    executor = ad.Executor([L, w_grad])
  68.  
     
  69.  
    N = 100
  70.  
    X_val, Y_val = gen_2d_data(N)
  71.  
    w_val = np.ones( 3)
  72.  
     
  73.  
    plot(N, X_val, Y_val, w_val)
  74.  
    executor = ad.Executor([L, w_grad])
  75.  
    test_accuracy(w_val, X_val, Y_val)
  76.  
    alpha = 0.01
  77.  
    max_iters = 300
  78.  
    for iteration in range(max_iters):
  79.  
    acc_L_val = 0
  80.  
    for i in range(N):
  81.  
    x_val = X_val[i]
  82.  
    y_val = np.array(Y_val[i])
  83.  
    L_val, w_grad_val = executor.run(feed_dict={w: w_val, x: x_val, y: y_val})
  84.  
    w_val += alpha * w_grad_val
  85.  
    acc_L_val += L_val
  86.  
    print( "iter = %d, likelihood = %s, w = %s" % (iteration, acc_L_val, w_val))
  87.  
    test_accuracy(w_val, X_val, Y_val)
  88.  
    plot(N, X_val, Y_val, w_val, True)
  89.  
     
  90.  
     
  91.  
    if __name__ == '__main__':
  92.  
    auto_diff_lr()
相關文章
相關標籤/搜索