聊一聊 Numpy 的終結者 JAX

這是我參與8月更文挑戰的第11天,活動詳情查看:8月更文挑戰html

從根本上說,JAX 是一個庫,提供 API 相似 NumPy,主要用於編寫的數組操縱程序進行轉換。甚至有人認爲 JAX 能夠看作 Numpy v2,不只加快 Numpy 並且爲 Numpy 提供自動求導(grad)功能,讓咱們僅憑藉 JAX 就能夠去實現一個機器學習框架。python

022.png

接下來主要就是來解釋一下爲何說 JAX 提供 API 相似 NumPy,。如今,你能夠把 JAX 看做是在加速器上運行支持自動求導的 NumPy。編程

import jax
import jax.numpy as jnp

x = jnp.arange(10)
print(x)
複製代碼

若是你們熟悉或用過 numpy 寫過點東西,上面的代碼應該不會陌生,這也就是 JAX的魅力,能夠從 numpy 無縫過渡到 JAX 在於你不須要學習一個新的 API。能夠將之前用用 numpy 實現的代碼,能夠用 jnp 代替 np,程序也能夠運行起來,固然也有不一樣之處,隨後會介紹。在 jnp 是 DeviceArray 類型的變量,這也是 JAX 表示數組的方式。數組

咱們如今將計算兩個向量的點積,block_until_ready 在無需更改代碼在 GPU 的設備運行代碼,而不須要改變代碼。使用%timeit來檢查性能。markdown

技術細節:當一個 JAX 函數被調用時,相應的操做被派發到一個加速器上,經過是進行異步計算。所以,計算返回的數組不必定在函數返回時就被「填滿"。所以,若是不須要當即獲得結果,由於是異步計算,因此不會阻塞 Python 的執行。所以,除非設置 block_until_ready,不然咱們將只爲調度計時,而不是爲實際計算計時。參見 JAX 文檔中的異步調度數據結構

long_vector = jnp.arange(int(1e7))

%timeit jnp.dot(long_vector, long_vector).block_until_ready()
複製代碼
The slowest run took 4.37 times longer than the fastest. This could mean that an intermediate result is being cached.
100 loops, best of 5: 6.37 ms per loop
複製代碼

JAX 的第一次轉換:grad

JAX的一個基本特徵是容許轉換函數。最經常使用的轉換之一 是 jax.grad,接收一個用 Python 編寫的數值函數,並返回一個新的 Python 函數,計算原函數的梯度。定義一個函數sum_of_squares,接收一個數組並返回對數組每一個元素平方後求和。app

def sum_of_squares(x):
  return jnp.sum(x**2)
複製代碼

sum_of_squares應用 jax.grad將返回一個不一樣的函數,這個函數就是sum_of_squares 相對於其第一個參數 x 的梯度。框架

而後,將數組輸入這個求導函數來返回相對於數組中每一個元素的導數。機器學習

sum_of_squares_dx = jax.grad(sum_of_squares)

x = jnp.asarray([1.0, 2.0, 3.0, 4.0])

print(sum_of_squares(x))

print(sum_of_squares_dx(x))
複製代碼
0.0
[2. 4. 6. 8.]
複製代碼

你能夠經過類比向量微積分中的 n a b l a nabla 運算符爲 jax.grad,若是函數 f ( x ) f(x) 輸入給了 jax.grad ,也就等同於返回 n a b l a nabla 函數用於計算𝑓梯度的函數。異步

( f ) ( x i ) = f x i ( x i ) (\nabla f)(x_i) = \frac{\partial f}{\partial x_i}(x_i)

相似地,jax.grad(f) 是計算梯度的函數,因此 jax.grad(f)(x)fx 處的梯度。(和 \nabla 同樣,jax.grad只對有標量輸出的函數起做用,不然會引起錯誤)

這樣一來 JAX API 與其餘支持自動求導如 Tensorflow 和 PyTorch 深度學習框架就有很大的不一樣,在後者中,咱們可使用損失張量自己來計算梯度( 例如經過調用 loss.backward() 來計算梯度)。JAX API 直接與函數一塊兒工做,更接近於底層數學。一旦你習慣了這種作事方式,就會感受很天然:你在代碼中的損失函數確實是一個參數和數據的函數,你就像在數學中那樣找到它的梯度。

這種作事方式使得控制諸如對哪些變量進行微分的事情變得簡單明瞭。默認狀況下,jax.grad會找到與第一個參數有關的梯度。在下面的例子中,sum_squared_error_dx的結果將是sum_squared_error相對於x的梯度。

def sum_squared_error(x, y):
  return jnp.sum((x-y)**2)

sum_squared_error_dx = jax.grad(sum_squared_error)

y = jnp.asarray([1.1, 2.1, 3.1, 4.1])

print(sum_squared_error_dx(x, y))
複製代碼

若是須要計算不一樣參數(或幾個參數)的梯度,能夠設置 argnums 來實現。

[-0.20000005 -0.19999981 -0.19999981 -0.19999981]
複製代碼
jax.grad(sum_squared_error, argnums=(0, 1))(x, y)  # Find gradient wrt both x & y
複製代碼
(DeviceArray([-0.20000005, -0.19999981, -0.19999981, -0.19999981], dtype=float32),
 DeviceArray([0.20000005, 0.19999981, 0.19999981, 0.19999981], dtype=float32))
複製代碼

這是否意味着在進行機器學習時,模型須要用巨大的參數列表來編寫函數,每一個模型參數陣列都有一個參數?JAX 配備了將數組捆綁在稱爲 "pytrees " 的數據結構中的機制,jax.grad的使用是這樣的。

Value 和 Grad

jax.value_and_grad(sum_squared_error)(x, y)
複製代碼
(DeviceArray(0.03999995, dtype=float32),
 DeviceArray([-0.20000005, -0.19999981, -0.19999981, -0.19999981], dtype=float32))
複製代碼

輔助數據

除了想要記錄數值以外,咱們還常常想要報告在計算損失函數時得到的一些中間結果。可是若是咱們試圖用普通的jax.grad來作這個,就會遇到麻煩。

def squared_error_with_aux(x, y):
  return sum_squared_error(x, y), x-y

jax.grad(squared_error_with_aux)(x, y)
複製代碼

上面代碼執行會報錯,還需在grad函數中設置一個參數。

jax.grad(squared_error_with_aux, has_aux=True)(x, y)
複製代碼

這是由於jax.grad只定義在標量函數上,轉換後獲得函數會返回一個元組。由於組員中包含一些輔助數據, 這就是has_aux的做用。

JAX 與 NumPy 不一樣之處

經過上面例子咱們已經發現 jax.numpy 在 API 設計上基本能夠說與 NumPy 的 API 保持一致。然而,並不是所有也有一些的區別。接下來咱們就 JAX 與 Numpy 不一樣之處給你們介紹一下。最重要的區別,就是 JAX 更偏向於函數式編程的風格,這是 Numpy 和 JAX 在某些點不只相同主要緣由。對函數式編程(FP)的介紹不在本指南的範圍以內。若是已經熟悉了 FP,那麼用起來 JAX 就會更加順手,由於 JAX 就是面向函數式編程設計的。

import numpy as np

x = np.array([1, 2, 3])

def in_place_modify(x):
  x[0] = 123
  return None

in_place_modify(x)
x
複製代碼

若是熟悉函數式編程,當看出輸出array([123, 2, 3])時,就會發現問題了,in_place_modify 作了一些側邊效應的事,在其內部更新 x 的值。由於在函數式編程中數據應該是 immutable(不可變),每次修改數據不是在源數據上進行修改,而是 copy 一份在進行修改。

in_place_modify(jnp.array(x)
複製代碼

有用的是,這個錯誤給指出了 JAX 經過 jax.ops.index_* ops 作是一個無反作用的方法。相似於不該該經過索引在原數組上進行的就地修改(in-place modification),而是建立一個新的數組並進行相應的修改。因此上面操做在 JAX 中會報錯

def jax_in_place_modify(x):
  return jax.ops.index_update(x, 0, 123)

y = jnp.array([1, 2, 3])
jax_in_place_modify(y)
複製代碼
DeviceArray([123,   2,   3], dtype=int32)
複製代碼

這時咱們再次查看 y 發現並無改變。

y #DeviceArray([1, 2, 3], dtype=int32)
複製代碼

Side-effect-free code is sometimes called functionally pure, or just pure.

無反作用的代碼有時被稱爲功能上的 pure,不是功能單一意思,而是不作一些更新應用狀態,或者 IO 等等其餘工做。

pure 版本的效率不是更低嗎?嚴格地說,是的。這是咱們不是在原有數據進行修改而是建立一個新的數組在其上進行修改。然而,JAX 計算在運行前一般會使用另外一個程序轉換,即 jax.jit 進行編譯。若是咱們在使用 jax.ops.index_update()對原數組進行 "就地 "修改後不使用,編譯器就能識別出實際上能夠編譯爲就地修改,從而最終獲得高效的代碼。

固然,有可能將有反作用的 Python 代碼和函數式支持存函數的 JAX 代碼混合在一塊兒,其實很難寫出或者幾乎作不到,寫出純函數式編程的程序,隨着你對 JAX 愈來愈熟悉,就會逐漸熟練知道何時該用 JAX,在後面有關這一點還會調到,暫時咱們就記住在 JAX 中避免發生側邊效用。

025.png

相關文章
相關標籤/搜索