Transformer 模型的 PyTorch 實現

本文由羅周楊原創,轉載請註明做者和出處。未經受權,不得用於商業用途。html

Google 2017年的論文 Attention is all you need 闡釋了什麼叫作大道至簡!該論文提出了Transformer模型,徹底基於Attention mechanism,拋棄了傳統的RNNCNNpython

咱們根據論文的結構圖,一步一步使用 PyTorch 實現這個Transformer模型。git

Transformer架構

首先看一下transformer的結構圖: github

transformer_architecture

解釋一下這個結構圖。首先,Transformer模型也是使用經典的encoer-decoder架構,由encoder和decoder兩部分組成。網絡

上圖的左半邊用Nx框出來的,就是咱們的encoder的一層。encoder一共有6層這樣的結構。架構

上圖的右半邊用Nx框出來的,就是咱們的decoder的一層。decoder一共有6層這樣的結構。app

輸入序列通過word embeddingpositional encoding相加後,輸入到encoder。函數

輸出序列通過word embeddingpositional encoding相加後,輸入到decoder。學習

最後,decoder輸出的結果,通過一個線性層,而後計算softmax。ui

word embeddingpositional encoding我後面會解釋。咱們首先詳細地分析一下encoder和decoder的每一層是怎麼樣的。

Encoder

encoder由6層相同的層組成,每一層分別由兩部分組成:

  • 第一部分是一個multi-head self-attention mechanism
  • 第二部分是一個position-wise feed-forward network,是一個全鏈接層

兩個部分,都有一個 殘差鏈接(residual connection),而後接着一個Layer Normalization

若是你是一個新手,你可能會問:

  • multi-head self-attention 是什麼呢?
  • 參差結構是什麼呢?
  • Layer Normalization又是什麼?

這些問題咱們在後面會一一解答。

Decoder

和encoder相似,decoder由6個相同的層組成,每個層包括如下3個部分:

  • 第一個部分是multi-head self-attention mechanism
  • 第二部分是multi-head context-attention mechanism
  • 第三部分是一個position-wise feed-forward network

仍是和encoder相似,上面三個部分的每個部分,都有一個殘差鏈接,後接一個Layer Normalization

可是,decoder出現了一個新的東西multi-head context-attention mechanism。這個東西其實也不復雜,理解了multi-head self-attention你就能夠理解multi-head context-attention。這個咱們後面會講解。

Attention機制

在講清楚各類attention以前,咱們得先把attention機制說清楚。

通俗來講,attention是指,對於某個時刻的輸出y,它在輸入x上各個部分的注意力。這個注意力實際上能夠理解爲權重

attention機制也能夠分紅不少種。Attention? Attention! 一問有一張比較全面的表格:

attention_mechanism
Figure 2. a summary table of several popular attention mechanisms.

上面第一種additive attention你可能聽過。之前咱們的seq2seq模型裏面,使用attention機制,這種**加性注意力(additive attention)**用的不少。Google的項目 tensorflow/nmt 裏面使用的attention就是這種。

爲何這種attention叫作additive attention呢?很簡單,對於輸入序列隱狀態h_i和輸出序列的隱狀態s_t,它的處理方式很簡單,直接合併,變成[s_t;h_i]

可是咱們的transformer模型使用的不是這種attention機制,使用的是另外一種,叫作乘性注意力(multiplicative attention)

那麼這種乘性注意力機制是怎麼樣的呢?從上表中的公式也能夠看出來:兩個隱狀態進行點積

Self-attention是什麼?

到這裏就能夠解釋什麼是self-attention了。

上面咱們說attention機制的時候,都會說到兩個隱狀態,分別是h_is_t,前者是輸入序列第i個位置產生的隱狀態,後者是輸出序列在第t個位置產生的隱狀態。

所謂self-attention實際上就是,輸出序列就是輸入序列!所以,計算本身的attention得分,就叫作self-attention

Context-attention是什麼?

知道了self-attention,那你確定猜到了context-attention是什麼了:它是encoder和decoder之間的attention!因此,你也能夠稱之爲encoder-decoder attention!

context-attention一詞並非本人原創,有些文章或者代碼會這樣描述,我以爲挺形象的,因此在此沿用這個稱呼。其餘文章可能會有其餘名稱,可是沒關係,咱們抓住了重點便可,那就是兩個不一樣序列之間的attention,與self-attention相區別。

無論是self-attention仍是context-attention,它們計算attention分數的時候,能夠選擇不少方式,好比上面表中提到的:

  • additive attention
  • local-base
  • general
  • dot-product
  • scaled dot-product

那麼咱們的Transformer模型,採用的是哪一種呢?答案是:scaled dot-product attention

Scaled dot-product attention是什麼?

論文Attention is all you need裏面對於attention機制的描述是這樣的:

An attention function can be described as a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility of the query with the corresponding key.

這句話描述得很清楚了。翻譯過來就是:經過肯定Q和K之間的類似程度來選擇V

用公式來描述更加清晰:

\text{Attention}(Q,K,V)=softmax(\frac{QK^T}{\sqrt d_k})V

scaled dot-product attentiondot-product attention惟一的區別就是,scaled dot-product attention有一個縮放因子\frac{1}{\sqrt d_k}

上面公式中的d_k表示的是K的維度,在論文裏面,默認是64

那麼爲何須要加上這個縮放因子呢?論文裏給出瞭解釋:對於d_k很大的時候,點積獲得的結果維度很大,使得結果處於softmax函數梯度很小的區域。

咱們知道,梯度很小的狀況,這對反向傳播不利。爲了克服這個負面影響,除以一個縮放因子,能夠必定程度上減緩這種狀況。

爲何是\frac{1}{\sqrt d_k}呢?論文沒有進一步說明。我的以爲你可使用其餘縮放因子,看看模型效果有沒有提高。

論文也提供了一張很清晰的結構圖,供你們參考:

scaled_dot_product_attention_arch
Figure 3. Scaled dot-product attention architecture.

首先說明一下咱們的K、Q、V是什麼:

  • 在encoder的self-attention中,Q、K、V都來自同一個地方(相等),他們是上一層encoder的輸出。對於第一層encoder,它們就是word embedding和positional encoding相加獲得的輸入。
  • 在decoder的self-attention中,Q、K、V都來自於同一個地方(相等),它們是上一層decoder的輸出。對於第一層decoder,它們就是word embedding和positional encoding相加獲得的輸入。可是對於decoder,咱們不但願它能得到下一個time step(即未來的信息),所以咱們須要進行sequence masking
  • 在encoder-decoder attention中,Q來自於decoder的上一層的輸出,K和V來自於encoder的輸出,K和V是同樣的。
  • Q、K、V三者的維度同樣,即 d_q=d_k=d_v

上面scaled dot-product attention和decoder的self-attention都出現了masking這樣一個東西。那麼這個mask究竟是什麼呢?這兩處的mask操做是同樣的嗎?這個問題在後面會有詳細解釋。

Scaled dot-product attention的實現

我們先把scaled dot-product attention實現了吧。代碼以下:

import torch
import torch.nn as nn


class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention mechanism."""

    def __init__(self, attention_dropout=0.0):
        super(ScaledDotProductAttention, self).__init__()
        self.dropout = nn.Dropout(attention_dropout)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, q, k, v, scale=None, attn_mask=None):
        """前向傳播. Args: q: Queries張量,形狀爲[B, L_q, D_q] k: Keys張量,形狀爲[B, L_k, D_k] v: Values張量,形狀爲[B, L_v, D_v],通常來講就是k scale: 縮放因子,一個浮點標量 attn_mask: Masking張量,形狀爲[B, L_q, L_k] Returns: 上下文張量和attetention張量 """
        attention = torch.bmm(q, k.transpose(1, 2))
        if scale:
        	attention = attention * scale
        if attn_mask:
        	# 給須要mask的地方設置一個負無窮
        	attention = attention.masked_fill_(attn_mask, -np.inf)
		# 計算softmax
        attention = self.softmax(attention)
		# 添加dropout
        attention = self.dropout(attention)
		# 和V作點積
        context = torch.bmm(attention, v)
        return context, attention
複製代碼

Multi-head attention又是什麼呢?

理解了Scaled dot-product attention,Multi-head attention也很簡單了。論文提到,他們發現將Q、K、V經過一個線性映射以後,分紅 h 份,對每一份進行scaled dot-product attention效果更好。而後,把各個部分的結果合併起來,再次通過線性映射,獲得最終的輸出。這就是所謂的multi-head attention。上面的超參數 h 就是heads數量。論文默認是8

下面是multi-head attention的結構圖:

multi-head attention_architecture
Figure 4: Multi-head attention architecture.

值得注意的是,上面所說的分紅 h是在 d_k、d_q、d_v 維度上面進行切分的。所以,進入到scaled dot-product attention的 d_k 實際上等於未進入以前的 D_K/h

Multi-head attention容許模型加入不一樣位置的表示子空間的信息。

Multi-head attention的公式以下:

\text{MultiHead}(Q,K,V) = \text{Concat}(\text{head}_ 1,\dots,\text{head}_ h)W^O

其中,

\text{head}_ i = \text{Attention}(QW_i^Q,KW_i^K,VW_i^V)

論文裏面,d_{model}=512h=8。因此在scaled dot-product attention裏面的

d_q = d_k = d_v = d_{model}/h = 512/8 = 64

Multi-head attention的實現

相信你們已經理清楚了multi-head attention,那麼咱們來實現它吧。代碼以下:

import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):

    def __init__(self, model_dim=512, num_heads=8, dropout=0.0):
        super(MultiHeadAttention, self).__init__()

        self.dim_per_head = model_dim // num_heads
        self.num_heads = num_heads
        self.linear_k = nn.Linear(model_dim, self.dim_per_head * num_heads)
        self.linear_v = nn.Linear(model_dim, self.dim_per_head * num_heads)
        self.linear_q = nn.Linear(model_dim, self.dim_per_head * num_heads)

        self.dot_product_attention = ScaledDotProductAttention(dropout)
        self.linear_final = nn.Linear(model_dim, model_dim)
        self.dropout = nn.Dropout(dropout)
		# multi-head attention以後須要作layer norm
        self.layer_norm = nn.LayerNorm(model_dim)

    def forward(self, key, value, query, attn_mask=None):
		# 殘差鏈接
        residual = query

        dim_per_head = self.dim_per_head
        num_heads = self.num_heads
        batch_size = key.size(0)

        # linear projection
        key = self.linear_k(key)
        value = self.linear_v(value)
        query = self.linear_q(query)

        # split by heads
        key = key.view(batch_size * num_heads, -1, dim_per_head)
        value = value.view(batch_size * num_heads, -1, dim_per_head)
        query = query.view(batch_size * num_heads, -1, dim_per_head)

        if attn_mask:
            attn_mask = attn_mask.repeat(num_heads, 1, 1)
        # scaled dot product attention
        scale = (key.size(-1) // num_heads) ** -0.5
        context, attention = self.dot_product_attention(
          query, key, value, scale, attn_mask)

        # concat heads
        context = context.view(batch_size, -1, dim_per_head * num_heads)

        # final linear projection
        output = self.linear_final(context)

        # dropout
        output = self.dropout(output)

        # add residual and norm layer
        output = self.layer_norm(residual + output)

        return output, attention

複製代碼

上面的代碼終於出現了Residual connectionLayer normalization。咱們如今來解釋它們。

Residual connection是什麼?

殘差鏈接其實很簡單!給你看一張示意圖你就明白了:

residual_conn
Figure 5. Residual connection.

假設網絡中某個層對輸入x做用後的輸出是F(x),那麼增長residual connection以後,就變成了:

F(x)+x

這個+x操做就是一個shortcut

那麼殘差結構有什麼好處呢?顯而易見:由於增長了一項x,那麼該層網絡對x求偏導的時候,多了一個常數項1!因此在反向傳播過程當中,梯度連乘,也不會形成梯度消失

因此,代碼實現residual connection很很是簡單:

def residual(sublayer_fn,x):
	return sublayer_fn(x)+x
複製代碼

文章開始的transformer架構圖中的Add & Norm中的Add也就是指的這個shortcut

至此,residual connection的問題理清楚了。更多關於殘差網絡的介紹能夠看文末的參考文獻。

Layer normalization是什麼?

GRADIENTS, BATCH NORMALIZATION AND LAYER NORMALIZATION一文對normalization有很好的解釋:

Normalization有不少種,可是它們都有一個共同的目的,那就是把輸入轉化成均值爲0方差爲1的數據。咱們在把數據送入激活函數以前進行normalization(歸一化),由於咱們不但願輸入數據落在激活函數的飽和區。

說到normalization,那就確定得提到Batch Normalization。BN在CNN等地方用得不少。

BN的主要思想就是:在每一層的每一批數據上進行歸一化。

咱們可能會對輸入數據進行歸一化,可是通過該網絡層的做用後,咱們的的數據已經再也不是歸一化的了。隨着這種狀況的發展,數據的誤差愈來愈大,個人反向傳播須要考慮到這些大的誤差,這就迫使咱們只能使用較小的學習率來防止梯度消失或者梯度爆炸。

BN的具體作法就是對每一小批數據,在批這個方向上作歸一化。以下圖所示:

batch_normalization
Figure 6. Batch normalization example.(From theneuralperspective.com)

能夠看到,右半邊求均值是沿着數據批量N的方向進行的

Batch normalization的計算公式以下:

BN(x_i)=\alpha\times\frac{x_i-u_B}{\sqrt{\sigma_B^2+\epsilon}}+\beta

具體的實現能夠查看上圖的連接文章。

說完Batch normalization,就該說說我們今天的主角Layer normalization

那麼什麼是Layer normalization呢?:它也是歸一化數據的一種方式,不過LN是在每個樣本上計算均值和方差,而不是BN那種在批方向計算均值和方差

下面是LN的示意圖:

layer_normalization
Figure 7. Layer normalization example.

和上面的BN示意圖一比較就能夠看出兩者的區別啦!

下面看一下LN的公式,也BN十分類似:

LN(x_i)=\alpha\times\frac{x_i-u_L}{\sqrt{\sigma_L^2+\epsilon}}+\beta

Layer normalization的實現

上述兩個參數\alpha\beta都是可學習參數。下面咱們本身來實現Layer normalization(PyTorch已經實現啦!)。代碼以下:

import torch
import torch.nn as nn


class LayerNorm(nn.Module):
    """實現LayerNorm。其實PyTorch已經實現啦,見nn.LayerNorm。"""

    def __init__(self, features, epsilon=1e-6):
        """Init. Args: features: 就是模型的維度。論文默認512 epsilon: 一個很小的數,防止數值計算的除0錯誤 """
        super(LayerNorm, self).__init__()
        # alpha
        self.gamma = nn.Parameter(torch.ones(features))
        # beta
        self.beta = nn.Parameter(torch.zeros(features))
        self.epsilon = epsilon

    def forward(self, x):
        """前向傳播. Args: x: 輸入序列張量,形狀爲[B, L, D] """
        # 根據公式進行歸一化
        # 在X的最後一個維度求均值,最後一個維度就是模型的維度
        mean = x.mean(-1, keepdim=True)
        # 在X的最後一個維度求方差,最後一個維度就是模型的維度
        std = x.std(-1, keepdim=True)
        return self.gamma * (x - mean) / (std + self.epsilon) + self.beta

複製代碼

順便提一句,Layer normalization多用於RNN這種結構。

Mask是什麼?

如今終於輪到講解mask了!mask顧名思義就是掩碼,在咱們這裏的意思大概就是對某些值進行掩蓋,使其不產生效果

須要說明的是,咱們的Transformer模型裏面涉及兩種mask。分別是padding masksequence mask。其中後者咱們已經在decoder的self-attention裏面見過啦!

其中,padding mask在全部的scaled dot-product attention裏面都須要用到,而sequence mask只有在decoder的self-attention裏面用到。

因此,咱們以前ScaledDotProductAttentionforward方法裏面的參數attn_mask在不一樣的地方會有不一樣的含義。這一點咱們會在後面說明。

Padding mask

什麼是padding mask呢?回想一下,咱們的每一個批次輸入序列長度是不同的!也就是說,咱們要對輸入序列進行對齊!具體來講,就是給在較短的序列後面填充0。由於這些填充的位置,實際上是沒什麼意義的,因此咱們的attention機制不該該把注意力放在這些位置上,因此咱們須要進行一些處理。

具體的作法是,把這些位置的值加上一個很是大的負數(能夠是負無窮),這樣的話,通過softmax,這些位置的機率就會接近0

而咱們的padding mask其實是一個張量,每一個值都是一個Boolen,值爲False的地方就是咱們要進行處理的地方。

下面是實現:

def padding_mask(seq_k, seq_q):
	# seq_k和seq_q的形狀都是[B,L]
    len_q = seq_q.size(1)
    # `PAD` is 0
    pad_mask = seq_k.eq(0)
    pad_mask = pad_mask.unsqueeze(1).expand(-1, len_q, -1)  # shape [B, L_q, L_k]
    return pad_mask
複製代碼

Sequence mask

文章前面也提到,sequence mask是爲了使得decoder不能看見將來的信息。也就是對於一個序列,在time_step爲t的時刻,咱們的解碼輸出應該只能依賴於t時刻以前的輸出,而不能依賴t以後的輸出。所以咱們須要想一個辦法,把t以後的信息給隱藏起來。

那麼具體怎麼作呢?也很簡單:產生一個上三角矩陣,上三角的值全爲1,下三角的值權威0,對角線也是0。把這個矩陣做用在每個序列上,就能夠達到咱們的目的啦。

具體的代碼實現以下:

def sequence_mask(seq):
    batch_size, seq_len = seq.size()
    mask = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.uint8),
                    diagonal=1)
    mask = mask.unsqueeze(0).expand(batch_size, -1, -1)  # [B, L, L]
    return mask
複製代碼

哈佛大學的文章The Annotated Transformer有一張效果圖:

sequence_mask
Figure 8. Sequence mask.

值得注意的是,原本mask只須要二維的矩陣便可,可是考慮到咱們的輸入序列都是批量的,因此咱們要把本來二維的矩陣擴張成3維的張量。上面的代碼能夠看出,咱們已經進行了處理。

回到本小結開始的問題,attn_mask參數有幾種狀況?分別是什麼意思?

  • 對於decoder的self-attention,裏面使用到的scaled dot-product attention,同時須要padding masksequence mask做爲attn_mask,具體實現就是兩個mask相加做爲attn_mask。
  • 其餘狀況,attn_mask一概等於padding mask

至此,mask相關的問題解決了。

Positional encoding是什麼?

好了,終於要解釋位置編碼了,那就是文字開始的結構圖提到的Positional encoding

就目前而言,咱們的Transformer架構彷佛少了點什麼東西。沒錯,就是它對序列的順序沒有約束!咱們知道序列的順序是一個很重要的信息,若是缺失了這個信息,可能咱們的結果就是:全部詞語都對了,可是沒法組成有意義的語句!

爲了解決這個問題。論文提出了Positional encoding。這是啥?一句話歸納就是:對序列中的詞語出現的位置進行編碼!若是對位置進行編碼,那麼咱們的模型就能夠捕捉順序信息!

那麼具體怎麼作呢?論文的實現頗有意思,使用正餘弦函數。公式以下:

PE(pos,2i) = sin(pos/10000^{2i/d_{model}})
PE(pos,2i+1) = cos(pos/10000^{2i/d_{model}})

其中,pos是指詞語在序列中的位置。能夠看出,在偶數位置,使用正弦編碼,在奇數位置,使用餘弦編碼

上面公式中的d_{model}是模型的維度,論文默認是512

這個編碼公式的意思就是:給定詞語的位置\text{pos},咱們能夠把它編碼成d_{model}維的向量!也就是說,位置編碼的每個維度對應正弦曲線,波長構成了從2\pi10000*2\pi的等比序列。

上面的位置編碼是絕對位置編碼。可是詞語的相對位置也很是重要。這就是論文爲何要使用三角函數的緣由!

正弦函數可以表達相對位置信息。,主要數學依據是如下兩個公式:

sin(\alpha+\beta) = sin\alpha cos\beta + cos\alpha sin\beta
cos(\alpha+\beta) = cos\alpha cos\beta - sin\alpha sin\beta

上面的公式說明,對於詞彙之間的位置偏移kPE(pos+k)能夠表示成PE(pos)PE(k)的組合形式,這就是表達相對位置的能力!

以上就是PE的全部祕密。說完了positional encoding,那麼咱們還有一個與之處於同一地位的word embedding

Word embedding你們都很熟悉了,它是對序列中的詞彙的編碼,把每個詞彙編碼成d_{model}維的向量!看到沒有,Postional encoding是對詞彙的位置編碼,word embedding是對詞彙自己編碼

因此,我更喜歡positional encoding的另一個名字Positional embedding

Positional encoding的實現

PE的實現也不難,按照論文的公式便可。代碼以下:

import torch
import torch.nn as nn


class PositionalEncoding(nn.Module):
    
    def __init__(self, d_model, max_seq_len):
        """初始化。 Args: d_model: 一個標量。模型的維度,論文默認是512 max_seq_len: 一個標量。文本序列的最大長度 """
        super(PositionalEncoding, self).__init__()
        
        # 根據論文給的公式,構造出PE矩陣
        position_encoding = np.array([
          [pos / np.pow(10000, 2.0 * (j // 2) / d_model) for j in range(d_model)]
          for pos in range(max_seq_len)])
        # 偶數列使用sin,奇數列使用cos
        position_encoding[:, 0::2] = np.sin(position_encoding[:, 0::2])
        position_encoding[:, 1::2] = np.cos(position_encoding[:, 1::2])

        # 在PE矩陣的第一行,加上一行全是0的向量,表明這`PAD`的positional encoding
        # 在word embedding中也常常會加上`UNK`,表明位置單詞的word embedding,二者十分相似
        # 那麼爲何須要這個額外的PAD的編碼呢?很簡單,由於文本序列的長度不一,咱們須要對齊,
        # 短的序列咱們使用0在結尾補全,咱們也須要這些補全位置的編碼,也就是`PAD`對應的位置編碼
        pad_row = torch.zeros([1, d_model])
        position_encoding = torch.cat((pad_row, position_encoding))
        
        # 嵌入操做,+1是由於增長了`PAD`這個補全位置的編碼,
        # Word embedding中若是詞典增長`UNK`,咱們也須要+1。看吧,二者十分類似
        self.position_encoding = nn.Embedding(max_seq_len + 1, d_model)
        self.position_encoding.weight = nn.Parameter(position_encoding,
                                                     requires_grad=False)
    def forward(self, input_len):
        """神經網絡的前向傳播。 Args: input_len: 一個張量,形狀爲[BATCH_SIZE, 1]。每個張量的值表明這一批文本序列中對應的長度。 Returns: 返回這一批序列的位置編碼,進行了對齊。 """
        
        # 找出這一批序列的最大長度
        max_len = torch.max(input_len)
        tensor = torch.cuda.LongTensor if input_len.is_cuda else torch.LongTensor
        # 對每個序列的位置進行對齊,在原序列位置的後面補上0
        # 這裏range從1開始也是由於要避開PAD(0)的位置
        input_pos = tensor(
          [list(range(1, len + 1)) + [0] * (max_len - len) for len in input_len])
        return self.position_encoding(input_pos)
    
複製代碼

Word embedding的實現

Word embedding應該是老生常談了,它實際上就是一個二維浮點矩陣,裏面的權重是可訓練參數,咱們只須要把這個矩陣構建出來就完成了word embedding的工做。

因此,具體的實現很簡單:

import torch.nn as nn


embedding = nn.Embedding(vocab_size, embedding_size, padding_idx=0)
# 得到輸入的詞嵌入編碼
seq_embedding = seq_embedding(inputs)*np.sqrt(d_model)
複製代碼

上面vocab_size就是詞典的大小,embedding_size就是詞嵌入的維度大小,論文裏面就是等於d_{model}=512。因此word embedding矩陣就是一個vocab_size*embedding_size的二維張量。

若是你想獲取更詳細的關於word embedding的信息,能夠看個人另一個文章word2vec的筆記和實現

Position-wise Feed-Forward network是什麼?

這就是一個全鏈接網絡,包含兩個線性變換和一個非線性函數(實際上就是ReLU)。公式以下:

FFN(x)=max(0,xW_1+b_1)W_2+b_2

這個線性變換在不一樣的位置都表現地同樣,而且在不一樣的層之間使用不一樣的參數。

論文提到,這個公式還能夠用兩個核大小爲1的一維卷積來解釋,卷積的輸入輸出都是d_{model}=512,中間層的維度是d_{ff}=2048

實現以下:

import torch
import torch.nn as nn


class PositionalWiseFeedForward(nn.Module):

    def __init__(self, model_dim=512, ffn_dim=2048, dropout=0.0):
        super(PositionalWiseFeedForward, self).__init__()
        self.w1 = nn.Conv1d(model_dim, ffn_dim, 1)
        self.w2 = nn.Conv1d(model_dim, ffn_dim, 1)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(model_dim)

    def forward(self, x):
        output = x.transpose(1, 2)
        output = self.w2(F.relu(self.w1(output)))
        output = self.dropout(output.transpose(1, 2))

        # add residual and norm layer
        output = self.layer_norm(x + output)
        return output
複製代碼

Transformer的實現

至此,全部的細節都已經解釋完了。如今來完成咱們Transformer模型的代碼。

首先,咱們須要實現6層的encoder和decoder。

encoder代碼實現以下:

import torch
import torch.nn as nn


class EncoderLayer(nn.Module):
	"""Encoder的一層。"""

    def __init__(self, model_dim=512, num_heads=8, ffn_dim=2018, dropout=0.0):
        super(EncoderLayer, self).__init__()

        self.attention = MultiHeadAttention(model_dim, num_heads, dropout)
        self.feed_forward = PositionalWiseFeedForward(model_dim, ffn_dim, dropout)

    def forward(self, inputs, attn_mask=None):

        # self attention
        context, attention = self.attention(inputs, inputs, inputs, padding_mask)

        # feed forward network
        output = self.feed_forward(context)

        return output, attention


class Encoder(nn.Module):
	"""多層EncoderLayer組成Encoder。"""

    def __init__(self, vocab_size, max_seq_len, num_layers=6, model_dim=512, num_heads=8, ffn_dim=2048, dropout=0.0):
        super(Encoder, self).__init__()

        self.encoder_layers = nn.ModuleList(
          [EncoderLayer(model_dim, num_heads, ffn_dim, dropout) for _ in
           range(num_layers)])

        self.seq_embedding = nn.Embedding(vocab_size + 1, model_dim, padding_idx=0)
        self.pos_embedding = PositionalEncoding(model_dim, max_seq_len)

    def forward(self, inputs, inputs_len):
        output = self.seq_embedding(inputs)
        output += self.pos_embedding(inputs_len)

        self_attention_mask = padding_mask(inputs, inputs)

        attentions = []
        for encoder in self.encoder_layers:
            output, attention = encoder(output, self_attention_mask)
            attentions.append(attention)

        return output, attentions

複製代碼

經過文章前面的分析,代碼不須要更多解釋了。一樣的,咱們的decoder代碼以下:

import torch
import torch.nn as nn


class DecoderLayer(nn.Module):

    def __init__(self, model_dim, num_heads=8, ffn_dim=2048, dropout=0.0):
        super(DecoderLayer, self).__init__()

        self.attention = MultiHeadAttention(model_dim, num_heads, dropout)
        self.feed_forward = PositionalWiseFeedForward(model_dim, ffn_dim, dropout)

    def forward(self, dec_inputs, enc_outputs, self_attn_mask=None, context_attn_mask=None):
        # self attention, all inputs are decoder inputs
        dec_output, self_attention = self.attention(
          dec_inputs, dec_inputs, dec_inputs, self_attn_mask)

        # context attention
        # query is decoder's outputs, key and value are encoder's inputs
        dec_output, context_attention = self.attention(
          enc_outputs, enc_outputs, dec_output, context_attn_mask)

        # decoder's output, or context
        dec_output = self.feed_forward(dec_output)

        return dec_output, self_attention, context_attention


class Decoder(nn.Module):

    def __init__(self, vocab_size, max_seq_len, num_layers=6, model_dim=512, num_heads=8, ffn_dim=2048, dropout=0.0):
        super(Decoder, self).__init__()

        self.num_layers = num_layers

        self.decoder_layers = nn.ModuleList(
          [DecoderLayer(model_dim, num_heads, ffn_dim, dropout) for _ in
           range(num_layers)])

        self.seq_embedding = nn.Embedding(vocab_size + 1, model_dim, padding_idx=0)
        self.pos_embedding = PositionalEncoding(model_dim, max_seq_len)

    def forward(self, inputs, inputs_len, enc_output, context_attn_mask=None):
        output = self.seq_embedding(inputs)
        output += self.pos_embedding(inputs_len)

        self_attention_padding_mask = padding_mask(inputs, inputs)
        seq_mask = sequence_mask(inputs)
        self_attn_mask = torch.gt((self_attention_padding_mask + seq_mask), 0)

        self_attentions = []
        context_attentions = []
        for decoder in self.decoder_layers:
            output, self_attn, context_attn = decoder(
            output, enc_output, self_attn_mask, context_attn_mask)
            self_attentions.append(self_attn)
            context_attentions.append(context_attn)

        return output, self_attentions, context_attentions
複製代碼

最後,咱們把encoder和decoder組成Transformer模型!

代碼以下:

import torch
import torch.nn as nn


class Transformer(nn.Module):

    def __init__(self, src_vocab_size, src_max_len, tgt_vocab_size, tgt_max_len, num_layers=6, model_dim=512, num_heads=8, ffn_dim=2048, dropout=0.2):
        super(Transformer, self).__init__()

        self.encoder = Encoder(src_vocab_size, src_max_len, num_layers, model_dim,
                               num_heads, ffn_dim, dropout)
        self.decoder = Decoder(tgt_vocab_size, tgt_max_len, num_layers, model_dim,
                               num_heads, ffn_dim, dropout)

        self.linear = nn.Linear(model_dim, tgt_vocab_size, bias=False)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, src_seq, src_len, tgt_seq, tgt_len):
        context_attn_mask = padding_mask(tgt_seq, src_seq)

        output, enc_self_attn = self.encoder(src_seq, src_len)

        output, dec_self_attn, ctx_attn = self.decoder(
          tgt_seq, tgt_len, output, context_attn_mask)

        output = self.linear(output)
        output = self.softmax(output)

        return output, enc_self_attn, dec_self_attn, ctx_attn

複製代碼

至此,Transformer模型已經實現了!

參考文章

1.爲何ResNet和DenseNet能夠這麼深?一文詳解殘差塊爲什麼有助於解決梯度彌散問題
2.GRADIENTS, BATCH NORMALIZATION AND LAYER NORMALIZATION
3.The Annotated Transformer
4.Building the Mighty Transformer for Sequence Tagging in PyTorch : Part I
5.Building the Mighty Transformer for Sequence Tagging in PyTorch : Part II
6.Attention?Attention!

參考代碼

1.jadore801120/attention-is-all-you-need-pytorch
2.JayParks/transformer

聯繫我

相關文章
相關標籤/搜索