這是我參與8月更文挑戰的第10天,活動詳情查看:8月更文挑戰html
JAX 的前身是 Autograd ,也就是說 JAX 是 Autograd 升級版本,JAX 能夠對 Python 和 NumPy 程序進行自動微分。能夠經過 Python的大量特徵子集進行區分,包括循環、分支、遞歸和閉包語句進行自動求導,也能夠求三階導數(三階導數是由原函數導數的導數的導數。 所謂三階導數,即原函數導數的導數的導數,將原函數進行三次求導)。經過 grad ,JAX 支持反向模式和正向模式的求導,並且這兩種模式能夠任意組合成任何順序,具備必定靈活性。python
JAX 相對於 Tensorflow 和 Pytorch 仍是顯得比較原始(底層),許多東西還需本身去實現,可能你會問有必要本身去實現深度學習框架嗎? 本身去實現好處就是出現問題更好控制,對於不一樣任務定製化更強,因此 JAX 是面向研究人員,而不是開發人員,這一點想你們在開始瞭解這個庫須要清楚的一點。c++
要完成一個大規模數據算法
pip install --upgrade jax jaxlib
複製代碼
安裝 GPUshell
pip install --upgrade jax jaxlib==0.1.61+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html
複製代碼
Numpy 是一個科學計算庫,就是在今天,就是用 tensorflow 和 pytorch 這樣煊赫一時的深度學習庫去實現模型也好、網絡也好,也少不了幾行用到 numpy 的代碼。可見 numpy 的重要性,可是 numpy 誕生時候並無大量使用 GPU 來支持運算,因此 numpy 的程序是沒法跑在 GPU,可是 JAX 其實並非 numpy,只是借鑑 numpy API,讓開發人員用起來感受在用 Numpy,無差異地使用 numpy。api
import numpy as np
複製代碼
引入 numpy ,爲了方便使用爲引入 numpy 起一個別名 np。數組
import numpy as np
x = np.random.rand(2000,2000)
print(x)
複製代碼
[[0.56745022 0.4247945 0.32374621 ... 0.72424614 0.31471484 0.75709393]
[0.76504917 0.41393967 0.1195595 ... 0.27311255 0.36763284 0.39811399]
[0.30034904 0.8224698 0.0160814 ... 0.75720634 0.72237672 0.09741124]
...
[0.14822982 0.918704 0.22328525 ... 0.67143212 0.91682163 0.65214596]
[0.25847224 0.7675988 0.64836721 ... 0.19096599 0.89869396 0.22051008]
[0.23031364 0.60925244 0.72548038 ... 0.63396252 0.13415147 0.0674989 ]]
複製代碼
2 * x
複製代碼
這樣直觀進行矩陣運算,例如給 x 每一個元素都乘以 2 能夠用上面這樣直觀操做,而無需遍歷矩陣每一個元素markdown
array([[1.13490044, 0.849589 , 0.64749241, ..., 1.44849228, 0.62942968,
1.51418785],
[1.53009834, 0.82787934, 0.239119 , ..., 0.54622511, 0.73526569,
0.79622798],
[0.60069808, 1.6449396 , 0.03216279, ..., 1.51441268, 1.44475343,
0.19482249],
...,
[0.29645964, 1.83740799, 0.4465705 , ..., 1.34286423, 1.83364326,
1.30429192],
[0.51694448, 1.5351976 , 1.29673443, ..., 0.38193199, 1.79738792,
0.44102015],
[0.46062729, 1.21850487, 1.45096075, ..., 1.26792504, 0.26830294,
0.1349978 ]])
複製代碼
np.sin(x)
複製代碼
對於一些複雜的運算例如 np.sin
numpy 也應付自如。網絡
array([[0.53748363, 0.41213356, 0.31812038, ..., 0.66257099, 0.30954533,
0.68681211],
[0.69257247, 0.40221938, 0.11927486, ..., 0.26972993, 0.35940746,
0.38768052],
[0.29585364, 0.73282855, 0.0160807 , ..., 0.68689382, 0.66116964,
0.09725726],
...,
[0.14768759, 0.79481581, 0.2214345 , ..., 0.62210787, 0.79367208,
0.60689338],
[0.25560384, 0.69440939, 0.60388576, ..., 0.18980742, 0.78251439,
0.21872738],
[0.2282829 , 0.57225456, 0.66349493, ..., 0.59234195, 0.13374945,
0.06744766]])
複製代碼
x - x.mean(0)
複製代碼
array([[ 0.05966959, -0.07397188, -0.18537367, ..., 0.21733322,
-0.18467283, 0.25997255],
[ 0.25726854, -0.08482671, -0.38956037, ..., -0.23380037,
-0.13175483, -0.09900739],
[-0.20743159, 0.32370341, -0.49303848, ..., 0.25029342,
0.22298905, -0.39971013],
...,
[-0.35955081, 0.41993761, -0.28583463, ..., 0.1645192 ,
0.41743396, 0.15502459],
[-0.24930839, 0.26883241, 0.13924734, ..., -0.31594693,
0.39930629, -0.2766113 ],
[-0.27746699, 0.11048605, 0.2163605 , ..., 0.1270496 ,
-0.3652362 , -0.42962248]])
複製代碼
np.dot(x,x)
複製代碼
矩陣間點乘也十分方便閉包
array([[499.08919102, 490.98247709, 495.18751355, ..., 498.40635521,
494.50937914, 485.34695773],
[510.29685902, 499.95239357, 511.85978277, ..., 509.82817989,
495.05226925, 507.41925595],
[502.82328413, 501.8213885 , 506.67580735, ..., 508.35889233,
492.64972834, 493.06081799],
...,
[502.20453325, 496.38140482, 508.98725444, ..., 505.05666502,
490.64576912, 491.95629717],
[515.66634283, 498.26014692, 516.70676734, ..., 508.06152946,
506.435225 , 500.36645682],
[509.67692906, 502.64662385, 509.47906271, ..., 509.0583251 ,
505.48856182, 493.5220343 ]])
複製代碼
接下來咱們看一看 jax 的 numpy 模塊提供方法相似於 numpy 的方法,咱們對比去上面 numpy 操做都用 jax.numpy
去實現一遍。
import jax.numpy as jnp
複製代碼
y = jnp.array(x)
複製代碼
y
複製代碼
將 numpy 對象來 DeviceArray
DeviceArray([[0.5674502 , 0.4247945 , 0.3237462 , ..., 0.72424614,
0.31471485, 0.7570939 ],
[0.76504916, 0.41393968, 0.1195595 , ..., 0.27311257,
0.36763284, 0.398114 ],
[0.30034903, 0.8224698 , 0.0160814 , ..., 0.7572063 ,
0.7223767 , 0.09741125],
...,
[0.14822982, 0.918704 , 0.22328524, ..., 0.67143214,
0.9168216 , 0.652146 ],
[0.25847223, 0.7675988 , 0.6483672 , ..., 0.190966 ,
0.898694 , 0.22051008],
[0.23031364, 0.60925245, 0.7254804 , ..., 0.6339625 ,
0.13415147, 0.0674989 ]], dtype=float32)
複製代碼
2 * y
複製代碼
DeviceArray([[1.1349005 , 0.849589 , 0.6474924 , ..., 1.4484923 ,
0.6294297 , 1.5141878 ],
[1.5300983 , 0.82787937, 0.23911901, ..., 0.54622513,
0.7352657 , 0.796228 ],
[0.60069805, 1.6449395 , 0.03216279, ..., 1.5144126 ,
1.4447534 , 0.19482249],
...,
[0.29645965, 1.837408 , 0.4465705 , ..., 1.3428643 ,
1.8336432 , 1.304292 ],
[0.51694447, 1.5351976 , 1.2967345 , ..., 0.381932 ,
1.797388 , 0.44102016],
[0.4606273 , 1.2185049 , 1.4509608 , ..., 1.267925 ,
0.26830295, 0.1349978 ]], dtype=float32)
複製代碼
jnp.sin(y)
複製代碼
DeviceArray([[0.53748363, 0.41213354, 0.31812036, ..., 0.662571 ,
0.30954534, 0.6868121 ],
[0.6925725 , 0.40221938, 0.11927487, ..., 0.26972994,
0.35940745, 0.38768053],
[0.2958536 , 0.73282856, 0.0160807 , ..., 0.6868938 ,
0.66116965, 0.09725726],
...,
[0.1476876 , 0.79481584, 0.2214345 , ..., 0.6221079 ,
0.7936721 , 0.6068934 ],
[0.25560382, 0.69440943, 0.60388577, ..., 0.18980742,
0.78251445, 0.21872738],
[0.2282829 , 0.5722546 , 0.66349494, ..., 0.59234196,
0.13374946, 0.06744765]], dtype=float32)
複製代碼
y - y.mean(0)
複製代碼
DeviceArray([[ 0.05966955, -0.0739719 , -0.18537366, ..., 0.2173332 ,
-0.18467283, 0.2599725 ],
[ 0.2572685 , -0.08482671, -0.38956037, ..., -0.23380038,
-0.13175485, -0.0990074 ],
[-0.20743164, 0.32370338, -0.49303848, ..., 0.25029337,
0.22298902, -0.39971015],
...,
[-0.35955083, 0.41993758, -0.2858346 , ..., 0.16451919,
0.41743392, 0.15502459],
[-0.24930844, 0.26883242, 0.13924736, ..., -0.31594694,
0.3993063 , -0.27661133],
[-0.277467 , 0.11048606, 0.21636051, ..., 0.12704957,
-0.36523622, -0.4296225 ]], dtype=float32)
複製代碼
jnp.dot(y,y)
複製代碼
DeviceArray([[499.08923, 490.98248, 495.18756, ..., 498.4064 , 494.50937,
485.347 ],
[510.2968 , 499.95236, 511.85983, ..., 509.8281 , 495.0523 ,
507.4191 ],
[502.82324, 501.82147, 506.67572, ..., 508.35886, 492.64972,
493.06076],
...,
[502.20465, 496.3814 , 508.9873 , ..., 505.0567 , 490.6458 ,
491.95618],
[515.66626, 498.2601 , 516.70667, ..., 508.06168, 506.43524,
500.3665 ],
[509.67685, 502.64664, 509.47913, ..., 509.0583 , 505.48856,
493.52206]], dtype=float32)
複製代碼
%timeit np.dot(x,x)
%timeit jnp.dot(y,y)
複製代碼
47.2 ms ± 5.78 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
2.16 ms ± 21.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
複製代碼
從執行兩個方法耗時對比來看,jpn.dot
仍是具備很大優點。
爲了利用 XLA 的強大功能,將代碼編譯到 XLA 內核中。這就是 jit 發揮做用的地方。要使用 XLA 和 jit,可使用 jit() 函數或 @jit 註釋。
def f(x):
for i in range(10):
x -= 0.1 * x
return x
複製代碼
這裏咱們定義函數f(x)
,函數自己並無實際意義,旨在說明 JIT compilation
f(x)
複製代碼
咱們能夠用 numpy 來運行執行對矩陣運算
array([[0.19785766, 0.14811668, 0.11288332, ..., 0.25252901, 0.10973428,
0.26398233],
[0.26675615, 0.14433184, 0.04168782, ..., 0.09522846, 0.12818565,
0.13881376],
[0.10472524, 0.28677749, 0.00560724, ..., 0.26402153, 0.25187719,
0.0339652 ],
...,
[0.05168454, 0.32033228, 0.07785475, ..., 0.2341139 , 0.31967594,
0.22738924],
[0.0901237 , 0.26764515, 0.22607167, ..., 0.06658573, 0.31335521,
0.07688711],
[0.0803054 , 0.21243319, 0.25295937, ..., 0.22104906, 0.04677572,
0.02353541]])
複製代碼
f(y)
複製代碼
能夠用 jax.numpy 來執行這一些對矩陣的操做。
DeviceArray([[0.19785768, 0.1481167 , 0.11288333, ..., 0.25252903,
0.10973427, 0.26398236],
[0.26675615, 0.14433186, 0.04168782, ..., 0.09522847,
0.12818564, 0.13881375],
[0.10472523, 0.2867775 , 0.00560724, ..., 0.26402152,
0.2518772 , 0.0339652 ],
...,
[0.05168454, 0.32033232, 0.07785475, ..., 0.23411393,
0.31967595, 0.22738926],
[0.09012369, 0.26764515, 0.22607167, ..., 0.06658573,
0.31335524, 0.07688711],
[0.08030539, 0.21243319, 0.25295934, ..., 0.22104907,
0.04677573, 0.02353541]], dtype=float32)
複製代碼
這是咱們來計算一些 f(y)
執行耗時,由於是同步執行,因此事件上看上去比較長,接下來咱們來用 JIT 來執行這個函數
%timeit f(y)
複製代碼
3.42 ms ± 31.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
複製代碼
在使用以前 JIT 以前,咱們須要引入 jit 包,使用起來也比較方便,用 jit
對函數 f
進行包裹一下就獲得 JIT
from jax import jit
g = jit(f)
複製代碼
%timeit g(y)
複製代碼
88.2 µs ± 560 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
複製代碼
包含多個 numpy 運算的函數能夠經過jax.jit()
進行 just-in-time
編譯,變成一個單一的 CUDA 程序後來執行,進一步加快運算速度。
經過 grad() 函數自動微分,這對深度學習應用很是有用,這樣就能夠很容易地運行反向傳播。在深度學習咱們經過梯度去更新參數,因此自動求導是深度學習框架實現的重點也是難點。
def f(x):
return x * jnp.sin(x)
複製代碼
這裏定義函數
f(3)
複製代碼
DeviceArray(0.42336, dtype=float32)
複製代碼
對於這個函數求導,咱們鏈式法則和經常使用函數求導能夠獲得以下
def grad_f(x):
return jnp.sin(x) + x * jnp.cos(x)
複製代碼
grad_f(3)
複製代碼
DeviceArray(-2.8288574, dtype=float32)
複製代碼
引入 jax 的 grad ,而後 grad 包裹 f 返回一個求導函數,自動求導幫助支持鏈式求導,將反向傳播對於程序設計變得簡答,其實深度學習框架難點就在於反向求導
from jax import grad
複製代碼
grad_f_jax = grad(f)
grad_f_jax(3.0)
複製代碼
DeviceArray(-2.8288574, dtype=float32)
複製代碼
vmap 是一種函數轉換,JAX 經過 vmap 變換提供了自動向量化算法,大大簡化了這種類型的計算,這使得研究人員在處理新算法時無需再去處理批量化的問題。示例以下:
def square(x):
return jnp.sum(x ** 2)
複製代碼
定義 square
函數對向量每一個元素求平方,而後對這個向量進行求和,能夠想一下這是先對向量作 map 而後在作 reduce 的操做。
square(jnp.arange(100))
複製代碼
DeviceArray(328350, dtype=int32)
複製代碼
爲了解釋一下 vmap 咱們可看如何 numpy 來實現一下什麼是 vmap。JAX 的 API 中還有一個轉換,可能你尚未意識 vmap() 向量映射的好處。可能熟悉 map 函數式沿着數組軸來操做數組中每個元素,在 vmap 中不是把循環放在外面,而是把循環推到函數的原始操做中進行,從而得到更好的性能。
x = jnp.arange(100).reshape(10,10)
[square(row) for row in x]
複製代碼
[DeviceArray(285, dtype=int32),
DeviceArray(2185, dtype=int32),
DeviceArray(6085, dtype=int32),
DeviceArray(11985, dtype=int32),
DeviceArray(19885, dtype=int32),
DeviceArray(29785, dtype=int32),
DeviceArray(41685, dtype=int32),
DeviceArray(55585, dtype=int32),
DeviceArray(71485, dtype=int32),
DeviceArray(89385, dtype=int32)]
複製代碼
from jax import vmap
vmap(square)
複製代碼
<function __main__.square(x)>
複製代碼
vmap(square)(x)
複製代碼
DeviceArray([ 285, 2185, 6085, 11985, 19885, 29785, 41685, 55585,
71485, 89385], dtype=int32)
複製代碼