A TensorFlow implementation of skip-gram

The word2vec model comes in two flavors: skip-gram and CBOW. Skip-gram predicts the context from the center word (target), while CBOW predicts the center word (target) from the context.
This article focuses on implementing the skip-gram model with negative sampling in TensorFlow. The main code is as follows:

import tensorflow as tf


def build_model(BATCH_SIZE, VOCAB_SIZE, EMBED_SIZE, NUM_SAMPLED, NUM_TRUE=1):
    '''
    Build the model (i.e. the computational graph) and return the placeholders
    (input and output) and the loss.
    '''
    with tf.name_scope('data'):
        target_node = tf.placeholder(tf.int32, shape=[BATCH_SIZE], name='target_node')
        context_node = tf.placeholder(tf.int32, shape=[BATCH_SIZE, NUM_TRUE], name='context_node')
        # pre-sampled negatives, fed to tf.nn.nce_loss as sampled_values:
        # (sampled_candidates, true_expected_count, sampled_expected_count)
        negative_samples = (tf.placeholder(tf.int32, shape=[NUM_SAMPLED], name='negative_samples'),
            tf.placeholder(tf.float32, shape=[BATCH_SIZE, NUM_TRUE], name='true_expected_count'),
            tf.placeholder(tf.float32, shape=[NUM_SAMPLED], name='sampled_expected_count'))
    with tf.name_scope('target_embedding_matrix'):
        target_embed_matrix = tf.Variable(tf.random_uniform([VOCAB_SIZE, EMBED_SIZE], -1.0, 1.0), 
                            name='target_embed_matrix')
    # define the inference
    with tf.name_scope('loss'):
        target_embed = tf.nn.embedding_lookup(target_embed_matrix, target_node, name='embed')
        # nce_weight: context_embed
        nce_weight = tf.Variable(tf.truncated_normal([VOCAB_SIZE, EMBED_SIZE],
                                                    stddev=1.0 / (EMBED_SIZE ** 0.5)), 
                                                    name='nce_weight')
        nce_bias = tf.Variable(tf.zeros([VOCAB_SIZE]), name='nce_bias')
        loss = tf.reduce_mean(tf.nn.nce_loss(weights=nce_weight, 
                                            biases=nce_bias, 
                                            labels=context_node, 
                                            inputs=target_embed,
                                            sampled_values = negative_samples, 
                                            num_sampled=NUM_SAMPLED, 
                                            num_classes=VOCAB_SIZE), name='loss')

        loss_summary = tf.summary.scalar("loss_summary", loss)

    return target_node, context_node, negative_samples, loss
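
For reference, here is a minimal sketch of how this graph could be trained (assuming TensorFlow 1.x; the hyperparameter values are arbitrary, and the random indices stand in for a real pipeline that generates (target, context) pairs from a corpus). The negative_samples placeholders are fed from a small side graph that runs tf.nn.log_uniform_candidate_sampler explicitly:

import numpy as np
import tensorflow as tf

BATCH_SIZE, VOCAB_SIZE, EMBED_SIZE, NUM_SAMPLED, NUM_STEPS = 128, 50000, 128, 64, 10000

target_node, context_node, negative_samples, loss = build_model(
    BATCH_SIZE, VOCAB_SIZE, EMBED_SIZE, NUM_SAMPLED)
optimizer = tf.train.GradientDescentOptimizer(1.0).minimize(loss)

# side graph that produces the (sampled_candidates, true_expected_count,
# sampled_expected_count) triple fed into the negative_samples placeholders
sampled_cand, true_exp, sampled_exp = tf.nn.log_uniform_candidate_sampler(
    true_classes=tf.cast(context_node, tf.int64), num_true=1,
    num_sampled=NUM_SAMPLED, unique=True, range_max=VOCAB_SIZE)
sampled_cand = tf.cast(sampled_cand, tf.int32)  # the placeholder above expects int32

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(NUM_STEPS):
        # random indices stand in for a real (target, context) pair generator
        targets = np.random.randint(0, VOCAB_SIZE, size=BATCH_SIZE)
        contexts = np.random.randint(0, VOCAB_SIZE, size=(BATCH_SIZE, 1))
        neg = sess.run([sampled_cand, true_exp, sampled_exp],
                       feed_dict={context_node: contexts})
        _, batch_loss = sess.run([optimizer, loss],
                                 feed_dict={target_node: targets,
                                            context_node: contexts,
                                            negative_samples[0]: neg[0],
                                            negative_samples[1]: neg[1],
                                            negative_samples[2]: neg[2]})
        if step % 1000 == 0:
            print(step, batch_loss)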

tf.nn.nce_loss

tf.nn.nce_loss computes and returns the NCE loss (noise-contrastive estimation training loss); we use it to implement negative sampling in the skip-gram model.
$$p(w_O|w_I)=\log\sigma({v^{'}}_{w_O}^\top v_{w_I})+\sum_{i=1}^k \mathbb{E}_{w_i\sim P_n(w)}[\log \sigma(-{v^{'}}_{w_i}^\top v_{w_I})]$$

def nce_loss(weights,
             biases,
             labels,
             inputs,
             num_sampled,
             num_classes,
             num_true=1,
             sampled_values=None,
             remove_accidental_hits=False,
             partition_strategy="mod",
             name="nce_loss"):
  logits, labels = _compute_sampled_logits(
      weights=weights,
      biases=biases,
      labels=labels,
      inputs=inputs,
      num_sampled=num_sampled,
      num_classes=num_classes,
      num_true=num_true,
      sampled_values=sampled_values,
      subtract_log_q=True,
      remove_accidental_hits=remove_accidental_hits,
      partition_strategy=partition_strategy,
      name=name)
  sampled_losses = sigmoid_cross_entropy_with_logits(
      labels=labels, logits=logits, name="sampled_losses")
  return _sum_rows(sampled_losses)

Here, weights is a tensor of shape [num_classes, dim], or a list of tensors whose concatenation has shape [num_classes, dim]; num_classes is the total number of classes (the vocabulary size in word2vec) and dim is the embedding dimension. weights corresponds to ${v^{'}}_w$ in the formula, i.e. the context embeddings;

biases is a tensor of shape [num_classes], the bias term;

labels is a tensor of shape [batch_size, num_true], where num_true is the number of positive (context) samples per training example; it defaults to 1 and is fixed at 1 in word2vec. labels holds the vocabulary indices of the context words for each target word;

inputs is a tensor of shape [batch_size, dim], corresponding to $v_w$ in the formula, i.e. the target embeddings;

num_sampled is an int giving the number of classes sampled as negatives for each batch (the sampled negatives are shared across the examples in the batch);

sampled_values is a 3-tuple (sampled_candidates, true_expected_count, sampled_expected_count); if sampled_values=None, the 3-tuple returned by tf.nn.log_uniform_candidate_sampler is used by default.
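
To make the argument shapes concrete, here is a minimal sketch (TensorFlow 1.x; the sizes and variable names are made up for illustration) that builds the 3-tuple explicitly with tf.nn.log_uniform_candidate_sampler and passes it as sampled_values; passing sampled_values=None would have the same effect, since that sampler with unique=True is the default:

import tensorflow as tf

VOCAB_SIZE, EMBED_SIZE, BATCH_SIZE, NUM_SAMPLED = 1000, 128, 64, 16

target = tf.placeholder(tf.int32, shape=[BATCH_SIZE])      # center word indices
context = tf.placeholder(tf.int64, shape=[BATCH_SIZE, 1])  # labels, num_true = 1

embed_matrix = tf.Variable(tf.random_uniform([VOCAB_SIZE, EMBED_SIZE], -1.0, 1.0))
inputs = tf.nn.embedding_lookup(embed_matrix, target)      # v_w, shape [batch_size, dim]

nce_weight = tf.Variable(tf.truncated_normal([VOCAB_SIZE, EMBED_SIZE], stddev=EMBED_SIZE ** -0.5))
nce_bias = tf.Variable(tf.zeros([VOCAB_SIZE]))

# build (sampled_candidates, true_expected_count, sampled_expected_count) explicitly;
# this matches what nce_loss does internally when sampled_values is None
sampled_values = tf.nn.log_uniform_candidate_sampler(
    true_classes=context, num_true=1, num_sampled=NUM_SAMPLED,
    unique=True, range_max=VOCAB_SIZE)

loss = tf.reduce_mean(tf.nn.nce_loss(weights=nce_weight,      # [num_classes, dim]
                                     biases=nce_bias,         # [num_classes]
                                     labels=context,          # [batch_size, num_true]
                                     inputs=inputs,           # [batch_size, dim]
                                     num_sampled=NUM_SAMPLED,
                                     num_classes=VOCAB_SIZE,
                                     sampled_values=sampled_values))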

tf.nn.log_uniform_candidate_sampler

def log_uniform_candidate_sampler(true_classes, num_true, num_sampled, unique,
                                  range_max, seed=None, name=None):
  seed1, seed2 = random_seed.get_seed(seed)
  return gen_candidate_sampling_ops.log_uniform_candidate_sampler(
      true_classes, num_true, num_sampled, unique, range_max, seed=seed1,
      seed2=seed2, name=name)

Here, true_classes has shape [batch_size, num_true]; in word2vec these are the vocabulary indices of the context words, with num_true=1;

num_sampled is an int giving the number of negative samples to draw;

unique is a bool indicating whether sampling within a batch is done without replacement (if True, all sampled classes in a batch are distinct);

range_max is an int giving the total number of classes, i.e. the vocabulary size in word2vec.

In the returned 3-tuple (sampled_candidates, true_expected_count, sampled_expected_count), sampled_candidates has shape [num_sampled] and holds the indices of the num_sampled negative samples (in word2vec, the vocabulary indices of the sampled words); true_expected_count has the same shape as true_classes and gives the expected count of each positive sample under the sampling distribution; sampled_expected_count has the same shape as sampled_candidates and gives the expected count of each negative sample under the sampling distribution.

Note that if the default tf.nn.log_uniform_candidate_sampler is used, the words in the vocabulary should be ordered from most to least frequent (high-frequency words get the smaller indices, i.e. the earlier rows of weights and of the embedding matrix behind inputs), because the default sampling distribution is
$$P(class_i) = \frac{\log(class_i + 2) - \log(class_i + 1)}{\log(range\_max + 1)}$$

Example
import tensorflow as tf

context = tf.placeholder(tf.int64, [5, 1], name="true_classes")
#   If `unique=True`, then these are post-rejection probabilities and we
#   compute them approximately.
(sampled_candidates, true_expected_count, sampled_expected_count) = tf.nn.log_uniform_candidate_sampler(
    true_classes=context,
    num_true=1,
    num_sampled=4,
    unique=False,
    range_max=10,
    seed=1234
)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    t1, t2, t3 = sess.run([sampled_candidates,true_expected_count,sampled_expected_count],
                          feed_dict={context: [[0], [0], [1], [2], [1]]})
    print(t1)
    print(t2)
    print(t3)
    
output:
[5 1 1 0]
[[ 1.1562593 ]
 [ 1.1562593 ]
 [ 0.67636836]
 [ 0.47989097]
 [ 0.67636836]]
[ 0.25714332  0.67636836  0.67636836  1.1562593 ]
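
These numbers follow directly from the distribution above: with unique=False, the expected count of a class is simply num_sampled * P(class). A quick check in plain Python:

import math

def p(class_id, range_max=10):
    # the default log-uniform (Zipfian) sampling probability
    return (math.log(class_id + 2) - math.log(class_id + 1)) / math.log(range_max + 1)

num_sampled = 4
for c in [0, 1, 2, 5]:
    print(c, num_sampled * p(c))
# prints 1.1562..., 0.6763..., 0.4798..., 0.2571... -- matching the expected
# counts printed above for classes 0, 1, 2 and the sampled class 5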

_compute_sampled_logits

Internally, tf.nn.nce_loss calls _compute_sampled_logits. The returned logits has shape [batch_size, num_true + num_sampled], i.e. [batch_size, 1 + num_sampled]; within each row, the entries correspond to ${v^{'}}_{w_O}^\top v_{w_I}$ and ${v^{'}}_{w_i}^\top v_{w_I}$ in the formula (plus the bias term, and with the log of the expected count subtracted because subtract_log_q=True).
The returned labels has the same shape as logits; each row consists of num_true ones followed by num_sampled zeros.
sigmoid_cross_entropy_with_logits is then applied to logits and labels to compute the loss.
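
To make the first step concrete, here is a conceptual sketch of what _compute_sampled_logits produces for num_true=1 (TensorFlow 1.x; it ignores accidental-hit removal and the partition strategy, and the function name is mine, not TensorFlow's):

import tensorflow as tf

def sampled_logits_sketch(weights, biases, labels, inputs, sampled_values):
    # rough equivalent of _compute_sampled_logits for num_true=1
    sampled, true_expected, sampled_expected = sampled_values
    labels_flat = tf.reshape(labels, [-1])                    # [batch_size]

    # v'_{w_O}, v'_{w_i}: context embeddings (and biases) of true and sampled classes
    true_w = tf.nn.embedding_lookup(weights, labels_flat)     # [batch_size, dim]
    sampled_w = tf.nn.embedding_lookup(weights, sampled)      # [num_sampled, dim]
    true_b = tf.nn.embedding_lookup(biases, labels_flat)      # [batch_size]
    sampled_b = tf.nn.embedding_lookup(biases, sampled)       # [num_sampled]

    # dot products v'^T v_{w_I}, plus bias, minus log expected count (subtract_log_q=True)
    true_logits = tf.reduce_sum(inputs * true_w, axis=1) + true_b \
                  - tf.log(tf.reshape(true_expected, [-1]))                 # [batch_size]
    sampled_logits = tf.matmul(inputs, sampled_w, transpose_b=True) \
                     + sampled_b - tf.log(sampled_expected)                 # [batch_size, num_sampled]

    logits = tf.concat([tf.expand_dims(true_logits, 1), sampled_logits], axis=1)
    labels_out = tf.concat([tf.ones_like(tf.expand_dims(true_logits, 1)),
                            tf.zeros_like(sampled_logits)], axis=1)
    return logits, labels_out                                 # both [batch_size, 1 + num_sampled]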

sigmoid_cross_entropy_with_logits

Letting x = logits and z = labels, the loss computed by sigmoid_cross_entropy_with_logits is
$$loss = z \cdot (-\log\sigma(x)) + (1 - z) \cdot (-\log(1 - \sigma(x)))$$
which, summed over each row by _sum_rows, is exactly $-p(w_O|w_I)$ from the formula above; maximizing $p(w_O|w_I)$ therefore amounts to minimizing the loss.
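
As a quick numeric sanity check of that identity (TensorFlow 1.x; the logit values are arbitrary):

import numpy as np
import tensorflow as tf

logits = tf.constant([[2.0, -1.0, 0.5]])   # one true logit followed by two sampled logits
labels = tf.constant([[1.0,  0.0, 0.0]])   # num_true ones, num_sampled zeros

tf_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)
manual = labels * -tf.log(tf.sigmoid(logits)) + (1 - labels) * -tf.log(1 - tf.sigmoid(logits))

with tf.Session() as sess:
    a, b = sess.run([tf_loss, manual])
    print(np.allclose(a, b))   # True (up to the numerically stable formulation TF uses)
    print(a.sum())             # the per-example loss that _sum_rows would return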
