Scaled Dot-Product Attention is a building block of the multi-head attention module in the Transformer encoder.
Because Scaled Dot-Product Attention sits inside multi-head attention, its inputs q, k, and v are usually reshaped to the following shape:
(batch, n_head, seqLen, dim), where n_head is the number of attention heads and n_head * dim = embedSize.
From input to output, the tensor dimensions stay unchanged.
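A minimal sketch of this reshaping (the projection layer and the concrete sizes here are assumptions for illustration, not from the original post):

import torch
import torch.nn as nn

batch, seqLen, embedSize, n_head = 2, 10, 512, 8
dim = embedSize // n_head                      # per-head dimension, so n_head * dim = embedSize

x = torch.randn(batch, seqLen, embedSize)      # token embeddings
w_q = nn.Linear(embedSize, embedSize)          # projection for q (k and v work the same way)
q = w_q(x)                                     # (batch, seqLen, embedSize)
q = q.view(batch, seqLen, n_head, dim).transpose(1, 2)   # (batch, n_head, seqLen, dim)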
temperature is the scaling factor, i.e., dim ** 0.5 (the square root of the per-head dimension).
mask marks the padded positions of each sample in the batch: wherever the sequence is padding, the mask is False. Its initial shape is (batchSize, seqLen); for the computation it is expanded to (batchSize, 1, 1, seqLen) so it broadcasts over heads and query positions.
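A small sketch of how such a mask might be built (using pad id 0 is an assumption here):

import torch

tokens = torch.tensor([[5, 7, 3, 0, 0],
                       [4, 9, 6, 2, 0]])            # (batchSize, seqLen), 0 = pad
mask = (tokens != 0).unsqueeze(1).unsqueeze(2)      # (batchSize, 1, 1, seqLen), False at pad
print(mask.shape)                                   # torch.Size([2, 1, 1, 5])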
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class ScaledDotProductAttention(nn.Module):
    """Compute 'Scaled Dot-Product Attention'."""

    def forward(self, query, key, value, mask=None, dropout=None):
        # query/key/value: (batch, n_head, seq_len, dim)
        scores = torch.matmul(query, key.transpose(-2, -1)) / np.sqrt(query.size(-1))
        # scores: (batch, n_head, seq_len_q, seq_len_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        p_attn = F.softmax(scores, dim=-1)
        if dropout is not None:
            p_attn = dropout(p_attn)
        # output: (batch, n_head, seq_len_q, dim)
        return torch.matmul(p_attn, value), p_attn
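A possible usage sketch of the module above (random tensors; the shapes are chosen for illustration only):

attn = ScaledDotProductAttention()
q = torch.randn(2, 8, 10, 64)                  # (batch, n_head, seq_len, dim)
k = torch.randn(2, 8, 10, 64)
v = torch.randn(2, 8, 10, 64)
mask = torch.ones(2, 1, 1, 10, dtype=torch.bool)
out, p_attn = attn(q, k, v, mask=mask, dropout=nn.Dropout(0.1))
print(out.shape, p_attn.shape)                 # torch.Size([2, 8, 10, 64]) torch.Size([2, 8, 10, 10])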
Note:
When Q, K, and V all come from linear projections of the same input vector, this is called self-attention;
when Q comes from one input and K, V come from projections of a different input, this is usually called cross-attention (e.g., the encoder-decoder attention in the Transformer).
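A rough sketch of the difference (the linear layers and tensor names here are assumptions for illustration):

embedSize = 512
w_q, w_k, w_v = nn.Linear(embedSize, embedSize), nn.Linear(embedSize, embedSize), nn.Linear(embedSize, embedSize)

x = torch.randn(2, 10, embedSize)              # one sequence
enc = torch.randn(2, 12, embedSize)            # a different sequence (e.g., encoder output)

# self-attention: q, k, v are all projections of the same input x
q, k, v = w_q(x), w_k(x), w_v(x)

# cross-attention: q is projected from x, while k and v are projected from the other input
q2, k2, v2 = w_q(x), w_k(enc), w_v(enc)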