A rough read-through of the XLNet source code, written up as I read. If you spot any problems, I'd be happy to discuss them.
The overall dependency order between the files is:
The most essential and hardest-to-digest part is modeling.py; the rest is fine to just skim. The main goal here is to read through this file together, and the other files can be added bit by bit later.
Let's start with the most important function, transformer_xl. There is too much code to paste in full, so I'll pick out the key parts.
def rel_attn_core(q_head, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat,
                  r_w_bias, r_r_bias, r_s_bias, attn_mask, dropatt, is_training,
                  scale):
  """Core relative positional attention operations."""

  # content based attention score
  ac = tf.einsum('ibnd,jbnd->ijbn', q_head + r_w_bias, k_head_h)

  # position based attention score
  bd = tf.einsum('ibnd,jbnd->ijbn', q_head + r_r_bias, k_head_r)
  bd = rel_shift(bd, klen=tf.shape(ac)[1])

  # segment based attention score
  if seg_mat is None:
    ef = 0
  else:
    ef = tf.einsum('ibnd,snd->ibns', q_head + r_s_bias, seg_embed)
    ef = tf.einsum('ijbs,ibns->ijbn', seg_mat, ef)

  # merge attention scores and perform masking
  attn_score = (ac + bd + ef) * scale
  # more ...
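The one call worth pausing on is rel_shift, which realigns the position-based scores bd so that column j of row i ends up holding the score for relative distance (i + mlen) - j. The snippet below is a minimal NumPy sketch of the reshape-and-slice trick behind it, not the TF code from modeling.py; rel_shift_np and the toy sizes are my own names for illustration.

import numpy as np

def rel_shift_np(x, klen):
  """Illustration of the relative-shift trick.

  x: [qlen, rlen, ...] scores with rlen = klen + qlen, where column m
  corresponds to relative distance klen - m (the position sequence runs
  from klen down to -qlen+1).
  Output: y[i, j] = x[i, qlen - i + j], so column j of row i now holds the
  score for key position j, i.e. relative distance (i + mlen) - j.
  """
  qlen, rlen = x.shape[0], x.shape[1]
  rest = list(x.shape[2:])
  y = x.reshape([rlen, qlen] + rest)      # reinterpret the flat buffer
  y = y[1:]                               # drop the first qlen entries
  y = y.reshape([qlen, rlen - 1] + rest)  # each row is now shifted one step further
  return y[:, :klen]                      # keep the klen valid key positions

# tiny check: qlen=2, mlen=1, so klen=3 and rlen=5
x = np.arange(10).reshape(2, 5)
print(rel_shift_np(x, klen=3))
# [[2 3 4]
#  [6 7 8]]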
if data_mask is not None:  # [1, len, bsz] + [len, len, bsz] = [qlen, qlen, bsz]
  # all mems can be attended to
  mems_mask = tf.zeros([tf.shape(data_mask)[0], mlen, bsz],
                       dtype=tf_float)                # [qlen, mlen, bsz]
  data_mask = tf.concat([mems_mask, data_mask], 1)    # [qlen, mlen+qlen, bsz]
  if attn_mask is None:
    attn_mask = data_mask[:, :, :, None]              # [qlen, mlen+qlen, bsz, 1]
  else:
    attn_mask += data_mask[:, :, :, None]
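To make the shape bookkeeping concrete, here is a minimal NumPy sketch of what this block does with toy sizes. build_attn_mask_np is a hypothetical helper of mine, not a function in modeling.py; in the convention used here, 1 means "cannot attend", and memory positions are always visible.

import numpy as np

def build_attn_mask_np(data_mask, mlen):
  """Toy version of the block above.

  data_mask: [qlen, qlen, bsz] with 1 = cannot attend.
  Returns [qlen, mlen+qlen, bsz, 1]; the trailing axis broadcasts over heads.
  """
  qlen, _, bsz = data_mask.shape
  mems_mask = np.zeros([qlen, mlen, bsz], dtype=data_mask.dtype)  # mems always visible
  full_mask = np.concatenate([mems_mask, data_mask], axis=1)      # [qlen, mlen+qlen, bsz]
  return full_mask[:, :, :, None]                                 # [qlen, mlen+qlen, bsz, 1]

# causal data_mask for qlen=3, bsz=1: query i may not attend to positions j > i
qlen, mlen, bsz = 3, 2, 1
causal = np.triu(np.ones([qlen, qlen]), k=1)[:, :, None]
print(build_attn_mask_np(causal, mlen)[:, :, 0, 0])
# [[0. 0. 0. 1. 1.]
#  [0. 0. 0. 0. 1.]
#  [0. 0. 0. 0. 0.]]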
if attn_mask is not None:
  non_tgt_mask = -tf.eye(qlen, dtype=tf_float)          # [qlen, qlen] negated identity matrix
  non_tgt_mask = tf.concat([tf.zeros([qlen, mlen], dtype=tf_float),
                            non_tgt_mask], axis=-1)     # [qlen, mlen+qlen]
  non_tgt_mask = tf.cast((attn_mask + non_tgt_mask[:, :, None, None]) > 0,
                         dtype=tf_float)                # broadcasts to [qlen, mlen+qlen, bsz, 1]
else:
  non_tgt_mask = None
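My reading of this block: subtracting the identity matrix and then thresholding with > 0 removes exactly the self-position from the blocked entries, so a token that was masked only because it is the prediction target itself becomes visible again. That is what the content stream needs (it may see its own token), while the query stream keeps the stricter attn_mask. The tiny NumPy check below is my own illustration of the arithmetic, not code from modeling.py.

import numpy as np

qlen = 3
# query-stream style mask: 1 = blocked; the diagonal is blocked because a
# query must not see its own content
attn_mask = np.triu(np.ones([qlen, qlen]), k=0)

# subtracting the identity pulls the diagonal down to 0, and (> 0) turns the
# result back into a 0/1 mask with the diagonal unblocked
non_tgt_mask = ((attn_mask - np.eye(qlen)) > 0).astype(float)

print(attn_mask)      # diagonal blocked   -> the stricter mask (query stream)
print(non_tgt_mask)   # diagonal unblocked -> the relaxed mask (content stream)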