衆所周知,儘管基於Attention機制的Transformer類模型有着良好的並行性能,但它的空間和時間複雜度都是
O(n2)級別的,
n是序列長度,因此當
n比較大時Transformer模型的計算量難以承受。近來,也有很多工做致力於下降Transformer模型的計算量,好比模型剪枝、量化、蒸餾等精簡技術,又或者修改Attention結構,使得其複雜度能下降到
O(nlogn)甚至
O(n)php
論文《Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention》當中提到一種線性化Attention(Linear Attention)的方法,由此引起了個人興趣,繼而閱讀了一些相關博客,有一些不錯的收穫,最後將本身對線性化Attention的理解彙總在此文中html
Attention
當前最流行的Attention機制當屬Scaled-Dot Attention,即python
Attention(Q,K,V)=softmax(QK⊤)V(1)
這裏的
Q∈Rn×dk,K∈Rm×dk,V∈Rm×dv,簡單起見我就沒顯示的寫出Attention的縮放因子
d
1了。本文咱們主要關心Self Attention的場景,因此爲了介紹上的方便,統一設
Q,K,V∈Rn×dmarkdown
摘掉Softmax
讀者也許想不到,制約Attention性能的關鍵因素,實際上是定義裏邊的Softmax!事實上,簡單地推導一下就能夠獲得這個結論。
QKT這一步咱們獲得一個
n×n的矩陣,以後還要作一個Softmax網絡
對一個
1×n的行向量進行Softmax,時間複雜度是
O(n),可是對一個
n×n矩陣的每一行作一個Softmax,時間複雜度就是
O(n2)app
若是沒有Softmax,那麼Attention的公式就變爲三個矩陣連乘
QK⊤V,而矩陣乘法是知足結合率的,因此咱們能夠先算
K⊤V,獲得一個
d×d的矩陣(這一步的時間複雜度是
O(d2n)),而後再用
Q左乘它(這一步的時間複雜度是
O(d2n)),因爲
d≪n,因此這樣算大體的時間複雜度只是
O(n)ide
對於BERT base來講,
d=64而不是768,why?由於768其實是經過Multi-Head拼接獲得的,而每一個head的
d=64svg
也就是說,去掉Softmax的Attention複雜度能夠降到最理想的線性級別
O(n)!這顯然就是咱們的終極追求:Linear Attention函數
通常的定義
問題是,直接去掉Softmax還能算是Attention嗎?他還能有標準的Attention的效果嗎?爲了回答這個問題,咱們先將Scaled-Dot Attention的定義等價的改寫爲(本文的向量都是列向量)oop
Attention(Q,K,V)i=j=1∑neqi⊤kjj=1∑neqi⊤kjvj(2)
這裏稍微解釋下,首先咱們知道
Q,K∈Rn×d,令
M=Q×K⊤,由矩陣乘法法則可知,
M的第一行是由
Q的第一行乘以
K⊤的全部列獲得的
Attention(Q,K,V)i表示最終輸出結果矩陣的第
i行
qi⊤表示
Q∈Rn×d矩陣的第
i行(行向量)
kj表示
K⊤∈Rd×n矩陣的第
j列(列向量)
vj表示
V⊤∈Rd×n矩陣的的第
j列(列向量)
因此,Scaled-Dot Attention其實就是以
eqi⊤kj爲權重對
vj作加權平均。因此咱們能夠提出一個Attention的通常化定義
Attention(Q,K,V)i=j=1∑nsim(qi,kj)j=1∑nsim(qi,kj)vj(3)
也就是把
eqi⊤kj換成
qi,ki的通常函數
sim(qi,kj),爲了保留Attention類似的分佈特性,咱們要求
sim(qi,kj)≥0恆成立。也就是說,咱們若是要定義新的Attention,必需要保留式(3)的形式,而且知足
sim(qi,kj)≥0
這種通常形式的Attention在CV中也被稱爲Non-Local網絡,出自論文《Non-local Neural Networks》
幾個例子
若是直接去掉Softmax,那麼就是
sim(qi,kj)=qi⊤kj,問題是內積沒法保證非負性,因此這還不是一個合理的選擇。下面咱們介紹幾種可取的方案
值得一提的是,下面介紹的這幾種Linear Attention,前兩種來自CV領域,第三種是蘇劍林大佬構思的(除了下面的介紹外,還有EMANet等CV領域對Attention的改進工做)
核函數形式
一個天然的想法是:若是
qi,kj的每一個元素都是非負的,那麼內積天然也是非負的。爲了完成這點,咱們能夠給
qi,kj各自加個激活函數
ϕ,φ,即
sim(qi,kj)=ϕ(qi)⊤φ(kj)(4)
其中
ϕ(⋅),φ(⋅)是值域非負的激活函數。本文開頭提到的論文《Transformers are RNNs》選擇的是
ϕ(x)=φ(x)=elu(x)+1,其中
elu(x)={xα(ex−1)if x>0if x<0
常見的
α取值爲
[0.1,0.3]
非要講故事的話,式(4)能夠聯想到"核方法",尤爲是
ϕ=φ時,
ϕ就至關於一個核函數,而
⟨ϕ(qi),ϕ(kj)⟩就是經過核函數所定義的內積。這方面的思考能夠參考論文《Transformer dissection: An unified understanding for transformer’s attention via the lens of kernel》,此處不作過多延伸
妙用Softmax
另外一篇更早的文章《Efficient Attention: Attention with Linear Complexities》則給出了一個更有意思的選擇。它留意到在
QK⊤中,
Q,K∈Rn×d,若是「
Q在
d那一維是歸一化的,而且
K在
n那一維是歸一化的」,那麼
QK⊤就是自動知足歸一化了,因此它給出的選擇是
Attention(Q,K,V)=softmax2(Q)softmax1(K)⊤V(5)
其中
softmax1、
softmax2分別表示在第一個
(n)、第二個維度
(d)進行Softmax運算。也就是說,這時候咱們是各自給
Q,K加Softmax,而不是算完
QK⊤以後再加Softmax
其實能夠證實這個形式也是式(4)的一個特例,此時對應於
ϕ(qi)=softmax(qi),φ(kj)=ekj,讀者能夠自行推導一下
蘇神的構思
在這裏,蘇神給出了一種構思。這個構思的出發點再也不是式(4),而是源於咱們對原始定義(2)的泰勒展開。由泰勒展開咱們有
eqi⊤kj≈1+qi⊤kj(6)
若是
qi⊤kj≥−1,那麼就能夠保證右端的非負性,從而可讓
sim(qi,kj)=1+qi⊤kj。到這裏讀者可能已經想到了,想要保證
qi⊤kj≥−1,只須要分別對
qi,kj作
l2歸一化。因此,蘇神最終提出的方案就是:
sim(qi,kj)=1+(∥qi∥qi)⊤(∥kj∥kj)(7)
若
x=[x1,x2,...,xn],則
∥x∥=x12+x22+⋅⋅⋅+xn2
這不一樣於式(4),但理論上它更加接近原始的Scaled-Dot Attention
實現
這裏主要是針對蘇神所提出的方法進行實現,可是因爲筆者本人水平有限,所以最終實現的代碼當中其實存在一些問題,主要是:
- 從測試結果來看,改進後的計算速度並無提高
- 沒法作到求和爲1
代碼實現主要是針對BERT的PyTorch實現這篇文章的代碼,更具體的說,其實僅修改了ScaledDotProductAttention
這個函數,所以下面只放出這部分代碼
class ScaledDotProductAttention(nn.Module):
def __init__(self):
super(ScaledDotProductAttention, self).__init__()
def forward(self, Q, K, V, attn_mask):
Q = F.normalize(Q, dim=3)
K = F.normalize(K, dim=3)
M = (torch.ones(Q.shape[0], Q.shape[1], Q.shape[2], K.shape[2]) + torch.matmul(Q, K.transpose(-1, -2)))
M_sum = torch.sum(M, dim=3)
M = M / M_sum.unsqueeze(3).repeat(1, 1, 1, M.shape[3])
attn = M.masked_fill(attn_mask, 0)
context = torch.matmul(attn, V)
return context
複製代碼
若是您有更好的實現方法,還望不吝賜教
Reference