去掉Attention的Softmax,複雜度降爲O(n)

衆所周知,儘管基於Attention機制的Transformer類模型有着良好的並行性能,但它的空間和時間複雜度都是 O ( n 2 ) \mathcal{O}(n^2) 級別的, n n 是序列長度,因此當 n n 比較大時Transformer模型的計算量難以承受。近來,也有很多工做致力於下降Transformer模型的計算量,好比模型剪枝、量化、蒸餾等精簡技術,又或者修改Attention結構,使得其複雜度能下降到 O ( n l o g n ) \mathcal{O}(nlog⁡n) 甚至 O ( n ) \mathcal{O}(n) php

論文《Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention》當中提到一種線性化Attention(Linear Attention)的方法,由此引起了個人興趣,繼而閱讀了一些相關博客,有一些不錯的收穫,最後將本身對線性化Attention的理解彙總在此文中html

Attention

當前最流行的Attention機制當屬Scaled-Dot Attention,即python

A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K ) V (1) \begin{aligned}Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V}) = softmax\left(\boldsymbol{Q}\boldsymbol{K}^{\top}\right)\boldsymbol{V}\tag{1}\end{aligned}

這裏的 Q R n × d k , K R m × d k , V R m × d v \boldsymbol{Q}\in \mathbb{R}^{n\times d_k}, \boldsymbol{K}\in \mathbb{R}^{m\times d_k}, \boldsymbol{V}\in \mathbb{R}^{m\times d_v} ,簡單起見我就沒顯示的寫出Attention的縮放因子 1 d \frac{1}{\sqrt{d}} 了。本文咱們主要關心Self Attention的場景,因此爲了介紹上的方便,統一設 Q , K , V R n × d \boldsymbol{Q},\boldsymbol{K},\boldsymbol{V}\in \mathbb{R}^{n\times d} markdown

摘掉Softmax

讀者也許想不到,制約Attention性能的關鍵因素,實際上是定義裏邊的Softmax!事實上,簡單地推導一下就能夠獲得這個結論。 Q K T QK^T 這一步咱們獲得一個 n × n n\times n 的矩陣,以後還要作一個Softmax網絡

對一個 1 × n 1\times n 的行向量進行Softmax,時間複雜度是 O ( n ) O(n) ,可是對一個 n × n n\times n 矩陣的每一行作一個Softmax,時間複雜度就是 O ( n 2 ) O(n^2) app

若是沒有Softmax,那麼Attention的公式就變爲三個矩陣連乘 Q K V \boldsymbol{QK^{\top}V} ,而矩陣乘法是知足結合率的,因此咱們能夠先算 K V \boldsymbol{K^{\top}V} ,獲得一個 d × d d\times d 的矩陣(這一步的時間複雜度是 O ( d 2 n ) O(d^2n) ),而後再用 Q Q 左乘它(這一步的時間複雜度是 O ( d 2 n ) O(d^2n) ),因爲 d n d \ll n ,因此這樣算大體的時間複雜度只是 O ( n ) O(n) ide

對於BERT base來講, d = 64 d=64 而不是768,why?由於768其實是經過Multi-Head拼接獲得的,而每一個head的 d = 64 d=64 svg

也就是說,去掉Softmax的Attention複雜度能夠降到最理想的線性級別 O ( n ) \mathcal{O}(n) !這顯然就是咱們的終極追求:Linear Attention函數

通常的定義

問題是,直接去掉Softmax還能算是Attention嗎?他還能有標準的Attention的效果嗎?爲了回答這個問題,咱們先將Scaled-Dot Attention的定義等價的改寫爲(本文的向量都是列向量)oop

A t t e n t i o n ( Q , K , V ) i = j = 1 n e q i k j v j j = 1 n e q i k j (2) \begin{aligned}Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V})_i = \frac{\sum\limits_{j=1}^n e^{\boldsymbol{q}_i^{\top}\boldsymbol{k}_j}\boldsymbol{v}_j}{\sum\limits_{j=1}^n e^{\boldsymbol{q}_i^{\top}\boldsymbol{k}_j}}\tag{2}\end{aligned}

這裏稍微解釋下,首先咱們知道 Q , K R n × d \boldsymbol{Q},\boldsymbol{K}\in \mathbb{R}^{n\times d} ,令 M = Q × K \boldsymbol{M} = \boldsymbol{Q}\times \boldsymbol{K^{\top}} ,由矩陣乘法法則可知, M \boldsymbol{M} 的第一行是由 Q \boldsymbol{Q} 的第一行乘以 K \boldsymbol{K^{\top}} 的全部列獲得的

A t t e n t i o n ( Q , K , V ) i Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V})_i 表示最終輸出結果矩陣的第 i i

q i \boldsymbol{q}_i^{\top} 表示 Q R n × d \boldsymbol{Q}\in \mathbb{R}^{n\times d} 矩陣的第 i i 行(行向量)

k j \boldsymbol{k}_j 表示 K R d × n \boldsymbol{K^{\top}}\in \mathbb{R}^{d\times n} 矩陣的第 j j 列(列向量)

v j \boldsymbol{v}_j 表示 V R d × n V^{\top}\in \mathbb{R}^{d\times n} 矩陣的的第 j j 列(列向量)

因此,Scaled-Dot Attention其實就是以 e q i k j e^{\boldsymbol{q}_i^{\top}\boldsymbol{k}_j} 爲權重對 v j \boldsymbol{v}_j 作加權平均。因此咱們能夠提出一個Attention的通常化定義

A t t e n t i o n ( Q , K , V ) i = j = 1 n sim ( q i , k j ) v j j = 1 n sim ( q i , k j ) (3) \begin{aligned}Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V})_i = \frac{\sum\limits_{j=1}^n \text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)\boldsymbol{v}_j}{\sum\limits_{j=1}^n \text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)}\tag{3}\end{aligned}

也就是把 e q i k j e^{\boldsymbol{q}_i^{\top}\boldsymbol{k}_j} 換成 q i , k i \boldsymbol{q}_i,\boldsymbol{k}_i 的通常函數 sim ( q i , k j ) \text{sim}(\boldsymbol{q}_i,\boldsymbol{k}_j) ,爲了保留Attention類似的分佈特性,咱們要求 sim ( q i , k j ) 0 \text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)\geq 0 恆成立。也就是說,咱們若是要定義新的Attention,必需要保留式(3)的形式,而且知足 sim ( q i , k j ) 0 \text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)\geq 0

這種通常形式的Attention在CV中也被稱爲Non-Local網絡,出自論文《Non-local Neural Networks》

幾個例子

若是直接去掉Softmax,那麼就是 sim ( q i , k j ) = q i k j \text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j) = \boldsymbol{q}_i^{\top}\boldsymbol{k}_j ,問題是內積沒法保證非負性,因此這還不是一個合理的選擇。下面咱們介紹幾種可取的方案

值得一提的是,下面介紹的這幾種Linear Attention,前兩種來自CV領域,第三種是蘇劍林大佬構思的(除了下面的介紹外,還有EMANet等CV領域對Attention的改進工做)

核函數形式

一個天然的想法是:若是 q i , k j \boldsymbol{q}_i, \boldsymbol{k}_j 的每一個元素都是非負的,那麼內積天然也是非負的。爲了完成這點,咱們能夠給 q i , k j \boldsymbol{q}_i, \boldsymbol{k}_j 各自加個激活函數 ϕ , φ \phi,\varphi ,即

sim ( q i , k j ) = ϕ ( q i ) φ ( k j ) (4) \begin{aligned}\text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j) = \phi(\boldsymbol{q}_i)^{\top} \varphi(\boldsymbol{k}_j)\tag{4}\end{aligned}

其中 ϕ ( ) , φ ( ) \phi(\cdot), \varphi(\cdot) 是值域非負的激活函數。本文開頭提到的論文《Transformers are RNNs》選擇的是 ϕ ( x ) = φ ( x ) = elu ( x ) + 1 \phi(x)=\varphi(x)=\text{elu}(x)+1 ,其中

elu ( x ) = { x if  x > 0 α ( e x 1 ) if  x < 0 \text{elu}(x)=\begin{cases}x& \text{if} \ x>0\\ \alpha (e^x-1) & \text{if}\ x<0\end{cases}

常見的 α \alpha 取值爲 [ 0.1 , 0.3 ] [0.1, 0.3]

非要講故事的話,式(4)能夠聯想到"核方法",尤爲是 ϕ = φ \phi=\varphi 時, ϕ \phi 就至關於一個核函數,而 ϕ ( q i ) , ϕ ( k j ) \langle \phi(\boldsymbol{q}_i), \phi(\boldsymbol{k}_j)\rangle 就是經過核函數所定義的內積。這方面的思考能夠參考論文《Transformer dissection: An unified understanding for transformer’s attention via the lens of kernel》,此處不作過多延伸

妙用Softmax

另外一篇更早的文章《Efficient Attention: Attention with Linear Complexities》則給出了一個更有意思的選擇。它留意到在 Q K \boldsymbol{QK^{\top}} 中, Q , K R n × d \boldsymbol{Q},\boldsymbol{K}\in \mathbb{R}^{n\times d} ,若是「 Q \boldsymbol{Q} d d 那一維是歸一化的,而且 K \boldsymbol{K} n n 那一維是歸一化的」,那麼 Q K \boldsymbol{QK^{\top}} 就是自動知足歸一化了,因此它給出的選擇是

A t t e n t i o n ( Q , K , V ) = s o f t m a x 2 ( Q ) s o f t m a x 1 ( K ) V (5) \begin{aligned}Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V}) = softmax_2\left(\boldsymbol{Q}\right)softmax_1(\boldsymbol{K})^{\top}\boldsymbol{V}\tag{5}\end{aligned}

其中 s o f t m a x 1 softmax_1 s o f t m a x 2 softmax_2 分別表示在第一個 ( n ) (n) 、第二個維度 ( d ) (d) 進行Softmax運算。也就是說,這時候咱們是各自給 Q , K \boldsymbol{Q},\boldsymbol{K} 加Softmax,而不是算完 Q K \boldsymbol{QK^{\top}} 以後再加Softmax

其實能夠證實這個形式也是式(4)​的一個特例,此時對應於 ϕ ( q i ) = s o f t m a x ( q i ) , φ ( k j ) = e k j \phi(\boldsymbol{q}_i)=softmax(\boldsymbol{q}_i),\varphi(\boldsymbol{k}_j)=e^{\boldsymbol{k}_j} ,讀者能夠自行推導一下

蘇神的構思

在這裏,蘇神給出了一種構思。這個構思的出發點再也不是式(4),而是源於咱們對原始定義(2)​的泰勒展開。由泰勒展開咱們有

e q i k j 1 + q i k j (6) \begin{aligned}e^{\boldsymbol{q}_i^{\top}\boldsymbol{k}_j} \approx 1 + \boldsymbol{q}_i^{\top}\boldsymbol{k}_j\tag{6}\end{aligned}

若是 q i k j 1 \boldsymbol{q}_i^{\top}\boldsymbol{k}_j\geq -1 ,那麼就能夠保證右端的非負性,從而可讓 sim ( q i , k j ) = 1 + q i k j \text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)=1 + \boldsymbol{q}_i^{\top}\boldsymbol{k}_j 。到這裏讀者可能已經想到了,想要保證 q i k j 1 \boldsymbol{q}_i^{\top}\boldsymbol{k}_j\geq -1 ,只須要分別對 q i , k j \boldsymbol{q}_i,\boldsymbol{k}_j l 2 l_2 歸一化。因此,蘇神最終提出的方案就是:

sim ( q i , k j ) = 1 + ( q i q i ) ( k j k j ) (7) \begin{aligned}\text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j) = 1 + \left( \frac{\boldsymbol{q}_i}{\Vert \boldsymbol{q}_i\Vert}\right)^{\top}\left(\frac{\boldsymbol{k}_j}{\Vert \boldsymbol{k}_j\Vert}\right)\tag{7}\end{aligned}

x = [ x 1 , x 2 , . . . , x n ] \boldsymbol{x}=[x_1,x_2,...,x_n] ,則 x = x 1 2 + x 2 2 + + x n 2 \Vert x\Vert=\sqrt{x_1^2+x_2^2+···+x_n^2}

這不一樣於式(4),但理論上它更加接近原始的Scaled-Dot Attention

實現

這裏主要是針對蘇神所提出的方法進行實現,可是因爲筆者本人水平有限,所以最終實現的代碼當中其實存在一些問題,主要是:

  1. 從測試結果來看,改進後的計算速度並無提高
  2. 沒法作到求和爲1

代碼實現主要是針對BERT的PyTorch實現這篇文章的代碼,更具體的說,其實僅修改了ScaledDotProductAttention這個函數,所以下面只放出這部分代碼

class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()

    def forward(self, Q, K, V, attn_mask):
        Q = F.normalize(Q, dim=3)
        K = F.normalize(K, dim=3)
        M = (torch.ones(Q.shape[0], Q.shape[1], Q.shape[2], K.shape[2]) + torch.matmul(Q, K.transpose(-1, -2))) # scores : [batch_size, n_heads, seq_len, seq_len]
        M_sum = torch.sum(M, dim=3)
        M = M / M_sum.unsqueeze(3).repeat(1, 1, 1, M.shape[3])
        attn = M.masked_fill(attn_mask, 0) # Fills elements of self tensor with value where mask is one.
        context = torch.matmul(attn, V)
        return context
複製代碼

若是您有更好的實現方法,還望不吝賜教

Reference

相關文章
相關標籤/搜索