Reading the XLNet Source Code Together

A rough read-through of the XLNet source code, written as I go. If you spot any problems, I'd be glad to discuss them.

1. Overview

1.1 File structure

  • xxx_utils.py: helper functions for data preprocessing, model loading, and so on
  • modeling.py: the Transformer-XL and two-stream attention implementation
  • xlnet.py: the higher-level XLNetModel class, which wraps the transformer in modeling.py
  • function_builder.py: the various loss functions used for pretraining and finetuning
  • train_xxx.py: XLNet pretraining
  • run_xxx.py: XLNet finetuning and evaluation

The overall dependency order is:

  • Pretrain: train_xxx.py -> function_builder.py -> modeling.py -> xxx_utils.py
  • Finetune: run_xxx.py -> function_builder.py -> modeling.py -> xxx_utils.py

The most essential and hardest part to chew through is modeling.py; the rest you can get by with a quick skim. The main goal here is to read this file together, and the other files will be added bit by bit later.

2. Close reading

2.1 modeling.py

Let's start with the most important function, transformer_xl. There is too much code to paste all of it, so I'll pick out a few key pieces.

  • Input arguments
    • mems: caches the information from the previous mem_len steps; the estimator updates it once per batch and keeps it in the TrainSpec
    • perm_mask: [i, j, k] says whether, in the k-th example of the batch, i attends to j (0) or does not (1); because the preceding mems are also brought into the computation, the extra k dimension keeps the mask aligned with each example in the batch
    • target_mapping: since the tokens are in principle permuted, position 4 may be predicted before position 2, so when predicting i=0 (the first target, which is really position 4) the actual position 4 has to be masked out. The docstring's "in batch k" feels slightly off to me; this should only apply within the current batch, and k should mean the k-th example in the batch
    • inp_q: if I understand correctly, tokens marked 1 play the same role as BERT's [MASK]; if it is None, the PLM objective is skipped
    • untie_r: whether the biases in the attention computation are untied (i.e., not shared) across layers. BERT's multi-head projections were plain dense layers, whereas here the projection matrices and the bias vectors are kept separate, and when untie_r=False all layers share the same biases
    • clamp_len: caps the relative position distance
  • bias: there are three of them, called head-specific bias vectors in the paper; I think they are there to strengthen the model's fitting capacity. r_w_bias for content attention, r_r_bias for positional attention, and r_s_bias for segment attention. This is easiest to see in the rel_attn_core function:
def rel_attn_core(q_head, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat,
                  r_w_bias, r_r_bias, r_s_bias, attn_mask, dropatt, is_training,
                  scale):
  """Core relative positional attention operations."""

  # content based attention score
  ac = tf.einsum('ibnd,jbnd->ijbn', q_head + r_w_bias, k_head_h)

  # position based attention score
  bd = tf.einsum('ibnd,jbnd->ijbn', q_head + r_r_bias, k_head_r)
  bd = rel_shift(bd, klen=tf.shape(ac)[1])

  # segment based attention score
  if seg_mat is None:
    ef = 0
  else:
    ef = tf.einsum('ibnd,snd->ibns', q_head + r_s_bias, seg_embed)
    ef = tf.einsum('ijbs,ibns->ijbn', seg_mat, ef)

  # merge attention scores and perform masking
  attn_score = (ac + bd + ef) * scale
  # more ...
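The bd term above is computed against the full stack of relative-position embeddings and then rearranged by rel_shift so that column j of row i lines up with the key at position j. Here is a minimal NumPy sketch of that reshape-and-slice trick (my own 2-D reconstruction for illustration; the real rel_shift in modeling.py operates on 4-D [qlen, plen, bsz, n_head] tensors):

import numpy as np

def rel_shift_np(x, klen):
  """2-D toy version of the reshape-and-slice trick (hypothetical helper)."""
  qlen, plen = x.shape
  x = x.reshape(plen, qlen)   # reinterpret the same row-major buffer with swapped dims
  x = x[1:]                   # drop one row; this staggers every query row by one step
  x = x.reshape(qlen, plen - 1)
  return x[:, :klen]          # keep only the columns that line up with the klen keys

qlen, mlen = 3, 2
klen = qlen + mlen            # keys = mems + current segment
plen = klen + qlen            # number of relative-position embeddings fed in as k_head_r

# x[i, j] = j, so the output shows which embedding index each (query, key) pair picks up
x = np.tile(np.arange(plen), (qlen, 1))
print(rel_shift_np(x, klen))
# [[3 4 5 6 7]
#  [2 3 4 5 6]
#  [1 2 3 4 5]]
# each row is staggered by one column relative to the next, which is exactly
# the "embedding for relative distance (i + mlen) - j" alignment the score needs

In other words, instead of gathering the right embedding separately for every (query, key) pair, the code scores against all plen embeddings at once and lets a cheap reshape do the alignment.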
  • attn_mask: kept consistent with the attention score by expanding it to 4 dimensions:
if data_mask is not None:  # [1, len, bsz] + [len, len, bsz] = [qlen, qlen, bsz]
  # all mems can be attended to
  mems_mask = tf.zeros([tf.shape(data_mask)[0], mlen, bsz],
                       dtype=tf_float)  # [qlen, mlen, bsz]
  data_mask = tf.concat([mems_mask, data_mask], 1)  # [qlen, mlen+qlen, bsz]
  if attn_mask is None:
    attn_mask = data_mask[:, :, :, None]  # [qlen, mlen+qlen, bsz, 1]
  else:
    attn_mask += data_mask[:, :, :, None]
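To make the shapes concrete, here is a small NumPy sketch (toy sizes and a made-up perm_mask of my own, not the library's data pipeline) of how a [qlen, qlen, bsz] mask over the current segment gets the all-visible mems columns prepended and a trailing axis added so it can broadcast against the [qlen, klen, bsz, n_head] attention scores:

import numpy as np

qlen, mlen, bsz = 3, 2, 1
klen = mlen + qlen

# made-up perm_mask: in example 0, position 1 is not allowed to attend to position 2
perm_mask = np.zeros((qlen, qlen, bsz), dtype=np.float32)
perm_mask[1, 2, 0] = 1.0

data_mask = perm_mask                                       # ignoring input_mask here
mems_mask = np.zeros((qlen, mlen, bsz), np.float32)         # mems are always visible
data_mask = np.concatenate([mems_mask, data_mask], axis=1)  # [qlen, mlen+qlen, bsz]
attn_mask = data_mask[:, :, :, None]                        # [qlen, mlen+qlen, bsz, 1]

print(attn_mask.shape)        # (3, 5, 1, 1)
print(attn_mask[..., 0, 0])   # only the (query 1, key 4) entry is blocked
# [[0. 0. 0. 0. 0.]
#  [0. 0. 0. 0. 1.]
#  [0. 0. 0. 0. 0.]]

The trailing None axis is what lets the same mask broadcast over every attention head.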
  • non_tgt_mask: the mask used by the content stream; compared with attn_mask, the -tf.eye term below additionally lets every position attend to itself:
if attn_mask is not None:
  non_tgt_mask = -tf.eye(qlen, dtype=tf_float)  # [qlen, qlen], negative identity matrix
  non_tgt_mask = tf.concat([tf.zeros([qlen, mlen], dtype=tf_float),
                            non_tgt_mask], axis=-1)  # [qlen, mlen+qlen]
  non_tgt_mask = tf.cast((attn_mask + non_tgt_mask[:, :, None, None]) > 0,
                         dtype=tf_float)  # [qlen, mlen+qlen, bsz, 1]
else:
  non_tgt_mask = None
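A quick way to see what the -tf.eye term buys: it subtracts 1 on the diagonal of the current-segment block before the > 0 cast, so a position that was masked from itself in attn_mask becomes visible to itself in non_tgt_mask. The content stream uses this relaxed mask, while the query stream keeps using attn_mask and never sees its own token. A minimal NumPy sketch with the same toy sizes (my own illustration, using an attn_mask that simply blocks the diagonal):

import numpy as np

qlen, mlen = 3, 2
# pretend attn_mask blocks each query from its own token (diagonal of the segment block)
attn_mask = np.zeros((qlen, mlen + qlen, 1, 1), np.float32)
attn_mask[np.arange(qlen), mlen + np.arange(qlen), 0, 0] = 1.0

non_tgt_mask = -np.eye(qlen, dtype=np.float32)                     # [qlen, qlen]
non_tgt_mask = np.concatenate(
    [np.zeros((qlen, mlen), np.float32), non_tgt_mask], axis=-1)   # [qlen, mlen+qlen]
non_tgt_mask = ((attn_mask + non_tgt_mask[:, :, None, None]) > 0).astype(np.float32)

print(attn_mask[..., 0, 0])      # segment diagonal is 1: blocked for the query stream
# [[0. 0. 1. 0. 0.]
#  [0. 0. 0. 1. 0.]
#  [0. 0. 0. 0. 1.]]
print(non_tgt_mask[..., 0, 0])   # the -eye cancels the diagonal: content stream sees itself
# [[0. 0. 0. 0. 0.]
#  [0. 0. 0. 0. 0.]
#  [0. 0. 0. 0. 0.]]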