深刻淺出PyTorch(算子篇)

Tensor

自從張量(Tensor)計算這個概念出現後,神經網絡的算法就能夠看做是一系列的張量計算。所謂的張量,它本來是個數學概念,表示各類向量或者數值之間的關係。PyTorch的張量(torch.Tensor)表示的是N維矩陣與一維數組的關係。html

http://web.mit.edu/~ezyang/Public/pytorch-internals.pdf

torch.Tensor的使用方法和numpy很類似(https://pytorch.org/...tensor-tutorial-py),二者惟一的區別在於torch.Tensor可使用GPU來計算,這就比用CPU的numpy要快不少。web

張量計算的種類有不少,好比加法、乘法、矩陣相乘、矩陣轉置等,這些計算被稱爲算子(Operator),它們是PyTorch的核心組件。算法

算子的backend通常是C/C++的拓展程序,PyTorch的backend是稱爲"ATen"的C/C++庫,ATen是"A Tensor"的縮寫。數組

Operator

PyTorch全部的Operator都定義在Declarations.cwrap和native_functions.yaml這兩個文件中,前者定義了從Torch那繼承來的legacy operator(aten/src/TH),後者定義的是native operator,是PyTorch的operator。網絡

相比於用C++開發的native code,legacy code是在PyTorch編譯時由gen.py根據Declarations.cwrap的內容動態生成的。所以,若是你想要trace這些code,須要先編譯PyTorch。ide

legacy code的開發要比native code複雜得多。若是能夠的話,建議你儘可能避開它們。函數

aten/src/ATen/Declarations.cwrap

MatMul

本文會以矩陣相乘--torch.matmul()爲例來分析PyTorch算子的工做流程。學習

我在深刻淺出全鏈接層(fully connected layer)中有講在GPU層面是如何進行矩陣相乘的。Nvidia、AMD等公司提供了優化好的線性代數計算庫--cuBLAS/rocBLAS/openBLAS,PyTorch只須要調用它們的API便可。優化

Figure 1: function flow of torch.matmul()

Figure 1是torch.matmul()在ATen中的function flow。能夠看到,這個flow可不短,這主要是由於不一樣類型的tensor(2d or Nd, batched gemm or not,with or without bias,cuda or cpu)的操做也不盡相同。spa

at::matmul()主要負責將Tensor轉換成cuBLAS須要的格式。前面說過,Tensor能夠是N維矩陣,若是tensor A是3d矩陣,tensor B是2d矩陣,就須要先將3d轉成2d;若是它們都是>=3d的矩陣,就要考慮batched matmul的狀況;若是bias=True,後續就應該交給at::addmm()來處理;總之,matmul要考慮的事情比想象中要多。

除此以外,不一樣的dtype、device和layout須要調用不一樣的操做函數,這部分工做交由c10::dispatcher來完成。

Dispatcher

dispatcher主要用於動態調用dtype、device以及layout等方法函數。用過numpy的都知道,np.array()的數據類型有:float32, float16,int8,int32,.... 若是你瞭解C++就會知道,這類程序最適合用模板(template)來實現。

很遺憾,因爲ATen有一部分operator是用C語言寫的(從Torch繼承過來),不支持模板功能,所以,就須要dispatcher這樣的動態調度器。

相似地,PyTorch的tensor不只能夠運行在GPU上,還能夠跑在CPU、mkldnn和xla等設備,Figure 1中的dispatcher4就根據tensor的device調用了mm的GPU實現。

layout是指tensor中元素的排布。通常來講,矩陣的排布都是緊湊型的,也就是strided layout。而那些有着大量0的稀疏矩陣,相應地就是sparse layout。

Figure 2: strided layout example

Figure 2是strided layout的演示實例,這裏建立了一個2行2列的矩陣a,它的數據實際存放在一維數組(a.storage)裏,2行2列只是這個數組的視圖。

stride充當了從數組到視圖的橋樑,好比,要打印第2行第2列的元素時,能夠經過公式:\(1 * stride(0) + 1 * stride(1)\)來計算該元素在數組中的索引。

除了dtype、device、layout以外,dispatcher還能夠用來調用legacy operator。好比說addmm這個operator,它的GPU實現就是經過dispatcher來跳轉到legacy::cuda::_th_addmm。

aten/src/ATen/native/native_functions.yaml

END

到此,就完成了對PyTorch算子的學習。若是你要學習其餘算子,能夠先從aten/src/ATen/native目錄的相關函數入手,從native_functions.yaml中找到dispatch目標函數,詳情能夠參考Figure 1。


更多精彩文章,歡迎掃碼關注下方的公衆號, 並訪問個人簡書博客:https://www.jianshu.com/u/c0fe8671254e

歡迎轉發至朋友圈,工做號轉載請後臺留言申請受權~

AI實戰:一個有料有深度的公衆號

相關文章
相關標籤/搜索