Scaled Dot-Product Attention (Transformer)

Scaled Dot-Product Attention is a building block of the multi-head attention module in the Transformer encoder.

Because Scaled Dot-Product Attention is used inside multi-head attention, the inputs q, k, v are usually reshaped to the following layout:

(batch, n_head, seqLen, dim), where n_head is the number of attention heads and n_head * dim = embedSize.

From input to output, the tensor shape stays the same.
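A minimal sketch of this head split; the concrete batch size, sequence length, and embedding size below are illustrative assumptions, not values from this post:

import torch

batch, seq_len, embed_size, n_head = 2, 5, 16, 4
dim = embed_size // n_head                    # n_head * dim = embedSize

x = torch.randn(batch, seq_len, embed_size)   # (batch, seqLen, embedSize)
x = x.view(batch, seq_len, n_head, dim)       # (batch, seqLen, n_head, dim)
x = x.transpose(1, 2)                         # (batch, n_head, seqLen, dim)
print(x.shape)                                # torch.Size([2, 4, 5, 4])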

temperature is the scaling factor, i.e. dim ** 0.5; in the code below this scaling is applied directly as np.sqrt(query.size(-1)).

mask marks the padded positions of each sample in the batch: positions that are padding get the value False. The mask therefore starts with shape (batchSize, seqLen) and, for the computation, is expanded to (batchSize, 1, 1, seqLen) so it can be broadcast against the attention scores.
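A padding mask of this kind can be built roughly as follows; the pad_idx value and the toy token ids are assumptions for illustration:

import torch

pad_idx = 0                                # assumed id of the padding token
tokens = torch.tensor([[5, 7, 2, 0, 0],
                       [3, 9, 4, 6, 0]])   # (batchSize, seqLen)
mask = (tokens != pad_idx)                 # (batchSize, seqLen), False at pad positions
mask = mask.unsqueeze(1).unsqueeze(2)      # (batchSize, 1, 1, seqLen), broadcasts over heads and query positions
print(mask.shape)                          # torch.Size([2, 1, 1, 5])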

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class ScaledDotProductAttention(nn.Module):
    """
    Compute Scaled Dot-Product Attention.
    """
    def forward(self, query, key, value, mask=None, dropout=None):  # each: (batch, n_head, seq_len, dim)
        # Similarity scores scaled by sqrt(dim): (batch, n_head, seq_len_q, seq_len_k)
        scores = torch.matmul(query, key.transpose(-2, -1)) / np.sqrt(query.size(-1))

        # Give padded positions a large negative score so softmax assigns them ~0 weight
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        # Attention weights over the key positions
        p_attn = F.softmax(scores, dim=-1)

        if dropout is not None:
            p_attn = dropout(p_attn)

        # Weighted sum of values: (batch, n_head, seq_len_q, dim)
        return torch.matmul(p_attn, value), p_attn
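A quick usage sketch, reusing the imports and class above; the shapes are assumptions chosen to match the layout described earlier:

batch, n_head, seq_len, dim = 2, 4, 5, 8
q = torch.randn(batch, n_head, seq_len, dim)
k = torch.randn(batch, n_head, seq_len, dim)
v = torch.randn(batch, n_head, seq_len, dim)
mask = torch.ones(batch, 1, 1, seq_len, dtype=torch.bool)   # no padding in this toy example

attn = ScaledDotProductAttention()
out, p_attn = attn(q, k, v, mask=mask, dropout=nn.Dropout(0.1))
print(out.shape)     # torch.Size([2, 4, 5, 8]) -- same layout as the input
print(p_attn.shape)  # torch.Size([2, 4, 5, 5])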

Note:

When Q, K, and V are all linear projections of the same input vectors, this is called self-attention;

when Q comes from one source and K, V come from projections of a different one, it is called cross-attention (encoder-decoder attention).
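A rough sketch of the difference; the projection layers w_q, w_k, w_v and the tensors x and enc below are assumptions for illustration, not part of the code above:

import torch
import torch.nn as nn

embed_size = 16
w_q = nn.Linear(embed_size, embed_size)
w_k = nn.Linear(embed_size, embed_size)
w_v = nn.Linear(embed_size, embed_size)

x = torch.randn(2, 5, embed_size)      # one sequence
enc = torch.randn(2, 7, embed_size)    # a different sequence, e.g. encoder output

# self-attention: Q, K, V are all projections of the same vectors
q, k, v = w_q(x), w_k(x), w_v(x)

# cross-attention: Q comes from x, while K and V come from the other source
q, k, v = w_q(x), w_k(enc), w_v(enc)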
