基於 JAX 咱們能夠輕鬆地實現一個深度學習的框架

這是我參與8月更文挑戰的第10天,活動詳情查看:8月更文挑戰html

有關 JAX 的概述

JAX 是什麼

JAX 的前身是 Autograd ,也就是說 JAX 是 Autograd 升級版本,JAX 能夠對 Python 和 NumPy 程序進行自動微分。能夠經過 Python的大量特徵子集進行區分,包括循環、分支、遞歸和閉包語句進行自動求導,也能夠求三階導數(三階導數是由原函數導數的導數的導數。 所謂三階導數,即原函數導數的導數的導數,將原函數進行三次求導)。經過 grad ,JAX 支持反向模式和正向模式的求導,並且這兩種模式能夠任意組合成任何順序,具備必定靈活性。python

001.jpeg

JAX 面向的人羣

JAX 相對於 Tensorflow 和 Pytorch 仍是顯得比較原始(底層),許多東西還需本身去實現,可能你會問有必要本身去實現深度學習框架嗎? 本身去實現好處就是出現問題更好控制,對於不一樣任務定製化更強,因此 JAX 是面向研究人員,而不是開發人員,這一點想你們在開始瞭解這個庫須要清楚的一點。c++

003.jpeg

學習 JAX 動機

  • 最求性能,當用到既有機器學習框架遇到性能的瓶頸,而又對底層 c++ 和 GPU 結構原理了解不足,能夠考慮一下 JAX 來重構本身模型(本身尚未去嘗試)
  • 本身想要實現一個基於 python 的深度學習框架,能夠考慮一下 JAX

要完成一個大規模數據算法

  • 硬件加速
  • 自動求導來進行優化運算
  • 融合操做 ,例如 np.sum((preds - targets) ** 2)
  • 並行處理數據和計算

006.png

JAX 安裝

pip install --upgrade jax jaxlib
複製代碼

安裝 GPUshell

pip install --upgrade jax jaxlib==0.1.61+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html
複製代碼

009.jpeg

JAX 能夠當作運行在 GPU 和 TPU 上 Numpy

Numpy 是一個科學計算庫,就是在今天,就是用 tensorflow 和 pytorch 這樣煊赫一時的深度學習庫去實現模型也好、網絡也好,也少不了幾行用到 numpy 的代碼。可見 numpy 的重要性,可是 numpy 誕生時候並無大量使用 GPU 來支持運算,因此 numpy 的程序是沒法跑在 GPU,可是 JAX 其實並非 numpy,只是借鑑 numpy API,讓開發人員用起來感受在用 Numpy,無差異地使用 numpy。api

import numpy as np
複製代碼

引入 numpy ,爲了方便使用爲引入 numpy 起一個別名 np。數組

import numpy as np
x = np.random.rand(2000,2000)
print(x)
複製代碼
[[0.56745022 0.4247945  0.32374621 ... 0.72424614 0.31471484 0.75709393]
     [0.76504917 0.41393967 0.1195595  ... 0.27311255 0.36763284 0.39811399]
     [0.30034904 0.8224698  0.0160814  ... 0.75720634 0.72237672 0.09741124]
     ...
     [0.14822982 0.918704   0.22328525 ... 0.67143212 0.91682163 0.65214596]
     [0.25847224 0.7675988  0.64836721 ... 0.19096599 0.89869396 0.22051008]
     [0.23031364 0.60925244 0.72548038 ... 0.63396252 0.13415147 0.0674989 ]]
複製代碼
2 * x
複製代碼

這樣直觀進行矩陣運算,例如給 x 每一個元素都乘以 2 能夠用上面這樣直觀操做,而無需遍歷矩陣每一個元素markdown

array([[1.13490044, 0.849589  , 0.64749241, ..., 1.44849228, 0.62942968,
            1.51418785],
           [1.53009834, 0.82787934, 0.239119  , ..., 0.54622511, 0.73526569,
            0.79622798],
           [0.60069808, 1.6449396 , 0.03216279, ..., 1.51441268, 1.44475343,
            0.19482249],
           ...,
           [0.29645964, 1.83740799, 0.4465705 , ..., 1.34286423, 1.83364326,
            1.30429192],
           [0.51694448, 1.5351976 , 1.29673443, ..., 0.38193199, 1.79738792,
            0.44102015],
           [0.46062729, 1.21850487, 1.45096075, ..., 1.26792504, 0.26830294,
            0.1349978 ]])

複製代碼
np.sin(x)
複製代碼

對於一些複雜的運算例如 np.sin numpy 也應付自如。網絡

array([[0.53748363, 0.41213356, 0.31812038, ..., 0.66257099, 0.30954533,
            0.68681211],
           [0.69257247, 0.40221938, 0.11927486, ..., 0.26972993, 0.35940746,
            0.38768052],
           [0.29585364, 0.73282855, 0.0160807 , ..., 0.68689382, 0.66116964,
            0.09725726],
           ...,
           [0.14768759, 0.79481581, 0.2214345 , ..., 0.62210787, 0.79367208,
            0.60689338],
           [0.25560384, 0.69440939, 0.60388576, ..., 0.18980742, 0.78251439,
            0.21872738],
           [0.2282829 , 0.57225456, 0.66349493, ..., 0.59234195, 0.13374945,
            0.06744766]])
複製代碼
x - x.mean(0)
複製代碼
array([[ 0.05966959, -0.07397188, -0.18537367, ...,  0.21733322,
            -0.18467283,  0.25997255],
           [ 0.25726854, -0.08482671, -0.38956037, ..., -0.23380037,
            -0.13175483, -0.09900739],
           [-0.20743159,  0.32370341, -0.49303848, ...,  0.25029342,
             0.22298905, -0.39971013],
           ...,
           [-0.35955081,  0.41993761, -0.28583463, ...,  0.1645192 ,
             0.41743396,  0.15502459],
           [-0.24930839,  0.26883241,  0.13924734, ..., -0.31594693,
             0.39930629, -0.2766113 ],
           [-0.27746699,  0.11048605,  0.2163605 , ...,  0.1270496 ,
            -0.3652362 , -0.42962248]])

複製代碼
np.dot(x,x)
複製代碼

矩陣間點乘也十分方便閉包

array([[499.08919102, 490.98247709, 495.18751355, ..., 498.40635521,
            494.50937914, 485.34695773],
           [510.29685902, 499.95239357, 511.85978277, ..., 509.82817989,
            495.05226925, 507.41925595],
           [502.82328413, 501.8213885 , 506.67580735, ..., 508.35889233,
            492.64972834, 493.06081799],
           ...,
           [502.20453325, 496.38140482, 508.98725444, ..., 505.05666502,
            490.64576912, 491.95629717],
           [515.66634283, 498.26014692, 516.70676734, ..., 508.06152946,
            506.435225  , 500.36645682],
           [509.67692906, 502.64662385, 509.47906271, ..., 509.0583251 ,
            505.48856182, 493.5220343 ]])
複製代碼

接下來咱們看一看 jax 的 numpy 模塊提供方法相似於 numpy 的方法,咱們對比去上面 numpy 操做都用 jax.numpy 去實現一遍。

import jax.numpy as jnp
複製代碼
y = jnp.array(x)
複製代碼
y
複製代碼

將 numpy 對象來 DeviceArray

DeviceArray([[0.5674502 , 0.4247945 , 0.3237462 , ..., 0.72424614,
                  0.31471485, 0.7570939 ],
                 [0.76504916, 0.41393968, 0.1195595 , ..., 0.27311257,
                  0.36763284, 0.398114  ],
                 [0.30034903, 0.8224698 , 0.0160814 , ..., 0.7572063 ,
                  0.7223767 , 0.09741125],
                 ...,
                 [0.14822982, 0.918704  , 0.22328524, ..., 0.67143214,
                  0.9168216 , 0.652146  ],
                 [0.25847223, 0.7675988 , 0.6483672 , ..., 0.190966  ,
                  0.898694  , 0.22051008],
                 [0.23031364, 0.60925245, 0.7254804 , ..., 0.6339625 ,
                  0.13415147, 0.0674989 ]], dtype=float32)
複製代碼
2 * y
複製代碼
DeviceArray([[1.1349005 , 0.849589  , 0.6474924 , ..., 1.4484923 ,
                  0.6294297 , 1.5141878 ],
                 [1.5300983 , 0.82787937, 0.23911901, ..., 0.54622513,
                  0.7352657 , 0.796228  ],
                 [0.60069805, 1.6449395 , 0.03216279, ..., 1.5144126 ,
                  1.4447534 , 0.19482249],
                 ...,
                 [0.29645965, 1.837408  , 0.4465705 , ..., 1.3428643 ,
                  1.8336432 , 1.304292  ],
                 [0.51694447, 1.5351976 , 1.2967345 , ..., 0.381932  ,
                  1.797388  , 0.44102016],
                 [0.4606273 , 1.2185049 , 1.4509608 , ..., 1.267925  ,
                  0.26830295, 0.1349978 ]], dtype=float32)
複製代碼
jnp.sin(y)
複製代碼
DeviceArray([[0.53748363, 0.41213354, 0.31812036, ..., 0.662571  ,
                  0.30954534, 0.6868121 ],
                 [0.6925725 , 0.40221938, 0.11927487, ..., 0.26972994,
                  0.35940745, 0.38768053],
                 [0.2958536 , 0.73282856, 0.0160807 , ..., 0.6868938 ,
                  0.66116965, 0.09725726],
                 ...,
                 [0.1476876 , 0.79481584, 0.2214345 , ..., 0.6221079 ,
                  0.7936721 , 0.6068934 ],
                 [0.25560382, 0.69440943, 0.60388577, ..., 0.18980742,
                  0.78251445, 0.21872738],
                 [0.2282829 , 0.5722546 , 0.66349494, ..., 0.59234196,
                  0.13374946, 0.06744765]], dtype=float32)
複製代碼
y - y.mean(0)
複製代碼
DeviceArray([[ 0.05966955, -0.0739719 , -0.18537366, ...,  0.2173332 ,
                  -0.18467283,  0.2599725 ],
                 [ 0.2572685 , -0.08482671, -0.38956037, ..., -0.23380038,
                  -0.13175485, -0.0990074 ],
                 [-0.20743164,  0.32370338, -0.49303848, ...,  0.25029337,
                   0.22298902, -0.39971015],
                 ...,
                 [-0.35955083,  0.41993758, -0.2858346 , ...,  0.16451919,
                   0.41743392,  0.15502459],
                 [-0.24930844,  0.26883242,  0.13924736, ..., -0.31594694,
                   0.3993063 , -0.27661133],
                 [-0.277467  ,  0.11048606,  0.21636051, ...,  0.12704957,
                  -0.36523622, -0.4296225 ]], dtype=float32)
複製代碼
jnp.dot(y,y)
複製代碼
DeviceArray([[499.08923, 490.98248, 495.18756, ..., 498.4064 , 494.50937,
                  485.347  ],
                 [510.2968 , 499.95236, 511.85983, ..., 509.8281 , 495.0523 ,
                  507.4191 ],
                 [502.82324, 501.82147, 506.67572, ..., 508.35886, 492.64972,
                  493.06076],
                 ...,
                 [502.20465, 496.3814 , 508.9873 , ..., 505.0567 , 490.6458 ,
                  491.95618],
                 [515.66626, 498.2601 , 516.70667, ..., 508.06168, 506.43524,
                  500.3665 ],
                 [509.67685, 502.64664, 509.47913, ..., 509.0583 , 505.48856,
                  493.52206]], dtype=float32)

複製代碼
%timeit np.dot(x,x)
%timeit jnp.dot(y,y)
複製代碼
47.2 ms ± 5.78 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
    2.16 ms ± 21.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
複製代碼

從執行兩個方法耗時對比來看,jpn.dot 仍是具備很大優點。

007.jpeg

JIT compilation(即便編譯)

爲了利用 XLA 的強大功能,將代碼編譯到 XLA 內核中。這就是 jit 發揮做用的地方。要使用 XLA 和 jit,可使用 jit() 函數或 @jit 註釋。

def f(x):
    for i in range(10):
        x -= 0.1 * x
    return x
複製代碼

這裏咱們定義函數f(x),函數自己並無實際意義,旨在說明 JIT compilation

f(x)
複製代碼

咱們能夠用 numpy 來運行執行對矩陣運算

array([[0.19785766, 0.14811668, 0.11288332, ..., 0.25252901, 0.10973428,
            0.26398233],
           [0.26675615, 0.14433184, 0.04168782, ..., 0.09522846, 0.12818565,
            0.13881376],
           [0.10472524, 0.28677749, 0.00560724, ..., 0.26402153, 0.25187719,
            0.0339652 ],
           ...,
           [0.05168454, 0.32033228, 0.07785475, ..., 0.2341139 , 0.31967594,
            0.22738924],
           [0.0901237 , 0.26764515, 0.22607167, ..., 0.06658573, 0.31335521,
            0.07688711],
           [0.0803054 , 0.21243319, 0.25295937, ..., 0.22104906, 0.04677572,
            0.02353541]])

複製代碼
f(y)
複製代碼

能夠用 jax.numpy 來執行這一些對矩陣的操做。

DeviceArray([[0.19785768, 0.1481167 , 0.11288333, ..., 0.25252903,
                  0.10973427, 0.26398236],
                 [0.26675615, 0.14433186, 0.04168782, ..., 0.09522847,
                  0.12818564, 0.13881375],
                 [0.10472523, 0.2867775 , 0.00560724, ..., 0.26402152,
                  0.2518772 , 0.0339652 ],
                 ...,
                 [0.05168454, 0.32033232, 0.07785475, ..., 0.23411393,
                  0.31967595, 0.22738926],
                 [0.09012369, 0.26764515, 0.22607167, ..., 0.06658573,
                  0.31335524, 0.07688711],
                 [0.08030539, 0.21243319, 0.25295934, ..., 0.22104907,
                  0.04677573, 0.02353541]], dtype=float32)

複製代碼

這是咱們來計算一些 f(y) 執行耗時,由於是同步執行,因此事件上看上去比較長,接下來咱們來用 JIT 來執行這個函數

%timeit f(y)
複製代碼
3.42 ms ± 31.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
複製代碼

在使用以前 JIT 以前,咱們須要引入 jit 包,使用起來也比較方便,用 jit 對函數 f 進行包裹一下就獲得 JIT

from jax import jit
g = jit(f)
複製代碼
%timeit g(y)
複製代碼
88.2 µs ± 560 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
複製代碼

包含多個 numpy 運算的函數能夠經過jax.jit()進行 just-in-time編譯,變成一個單一的 CUDA 程序後來執行,進一步加快運算速度。

自動微分

經過 grad() 函數自動微分,這對深度學習應用很是有用,這樣就能夠很容易地運行反向傳播。在深度學習咱們經過梯度去更新參數,因此自動求導是深度學習框架實現的重點也是難點。

def f(x):
    return x * jnp.sin(x)
複製代碼

這裏定義函數

f ( x ) = x sin ( x ) f(x) = x \sin(x)
f(3)
複製代碼
DeviceArray(0.42336, dtype=float32)
複製代碼

對於這個函數求導,咱們鏈式法則和經常使用函數求導能夠獲得以下

f ( x ) = sin ( x ) + x cos ( x ) f^{\prime}(x) = \sin(x) + x \cos(x)
def grad_f(x):
    return jnp.sin(x) + x * jnp.cos(x)
複製代碼
grad_f(3)
複製代碼
DeviceArray(-2.8288574, dtype=float32)
複製代碼

引入 jax 的 grad ,而後 grad 包裹 f 返回一個求導函數,自動求導幫助支持鏈式求導,將反向傳播對於程序設計變得簡答,其實深度學習框架難點就在於反向求導

from jax import grad
複製代碼
grad_f_jax = grad(f)
grad_f_jax(3.0)
複製代碼
DeviceArray(-2.8288574, dtype=float32)
複製代碼

向量化(vectorization)

vmap 是一種函數轉換,JAX 經過 vmap 變換提供了自動向量化算法,大大簡化了這種類型的計算,這使得研究人員在處理新算法時無需再去處理批量化的問題。示例以下:

def square(x):
    return jnp.sum(x ** 2)
複製代碼

定義 square 函數對向量每一個元素求平方,而後對這個向量進行求和,能夠想一下這是先對向量作 map 而後在作 reduce 的操做。

square(jnp.arange(100))
複製代碼
DeviceArray(328350, dtype=int32)
複製代碼

爲了解釋一下 vmap 咱們可看如何 numpy 來實現一下什麼是 vmap。JAX 的 API 中還有一個轉換,可能你尚未意識 vmap() 向量映射的好處。可能熟悉 map 函數式沿着數組軸來操做數組中每個元素,在 vmap 中不是把循環放在外面,而是把循環推到函數的原始操做中進行,從而得到更好的性能。

x = jnp.arange(100).reshape(10,10)
[square(row) for row in x]
複製代碼
[DeviceArray(285, dtype=int32),
     DeviceArray(2185, dtype=int32),
     DeviceArray(6085, dtype=int32),
     DeviceArray(11985, dtype=int32),
     DeviceArray(19885, dtype=int32),
     DeviceArray(29785, dtype=int32),
     DeviceArray(41685, dtype=int32),
     DeviceArray(55585, dtype=int32),
     DeviceArray(71485, dtype=int32),
     DeviceArray(89385, dtype=int32)]
複製代碼
from jax import vmap
vmap(square)
複製代碼
<function __main__.square(x)>
複製代碼
vmap(square)(x)
複製代碼
DeviceArray([  285,  2185,  6085, 11985, 19885, 29785, 41685, 55585,
                 71485, 89385], dtype=int32)
複製代碼
相關文章
相關標籤/搜索