實現nlp文本生成中的beam search解碼器

天然語言處理任務,好比caption generation(圖片描述文本生成)、機器翻譯中,都須要進行詞或者字符序列的生成。常見於seq2seq模型或者RNNLM模型中。html

這篇博文主要介紹文本生成解碼過程當中用的greedy search 和beam search算法實現。其中,greedy search 比較簡單,着重介紹beam search算法的實現。算法

 

 咱們在文本生成解碼時,其實是想找對最有的文本序列,或者說是機率,可能性最大的文本序列。而要在全局搜索這個最有解空間,每每是不可能的(由於詞典太大),建設生成序列長度爲N,詞典大小爲V, 則複雜度爲 V^N次方。這其實是一個NP難題。退而求其次,咱們使用啓發式算法,來找到可能的最優解,或者說足夠好的解。spring

 

假設序列數據(假設每一個位置詞的機率都已經給出):app

data = [[0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1]]
data = array(data)

 

一、greedy search decoderspa

很是簡單,咱們用argmax就能夠實現翻譯

# greedy decoder
def greedy_decoder(data):
    # 每一行最大機率詞的索引
    return [argmax(s) for s in data]

完整代碼code

from numpy import array
from numpy import argmax

# greedy decoder
def greedy_decoder(data):
    # 每一行最大機率詞的索引
    return [argmax(s) for s in data]

# 定義一個句子,長度爲10,詞典大小爲5
data = [[0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1]]
data = array(data)
# 使用greedy search解碼
result = greedy_decoder(data)
print(result)

 

2. beam searchhtm

與greedy search不一樣,beam search返回多個最有可能的解碼結果(具體多少個,由參數k執行)。blog

greedy search每一步都都採用最大機率的詞,而beam search每一步都保留k個最有可能的結果,在每一步,基於以前的k個可能最優結果,繼續搜索下一步。(參考下面示意圖理解)排序

 

示例圖(設置返回解碼結果爲2個):

 

from math import log
from numpy import array
from numpy import argmax

# beam search
def beam_search_decoder(data, k):
    sequences = [[list(), 1.0]]
    for row in data:
        all_candidates = list()
        for i in range(len(sequences)):
            seq, score = sequences[i]
            for j in range(len(row)):
                candidate = [seq + [j], score * -log(row[j])]
                all_candidates.append(candidate)
        # 全部候選根據分值排序
        ordered = sorted(all_candidates, key=lambda tup:tup[1])
        # 選擇前k個
        sequences = ordered[:k]
    return sequences

# 定義一個句子,長度爲10,詞典大小爲5
data = [[0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1]]
data = array(data)
# 解碼
result = beam_search_decoder(data, 3)
# print result
for seq in result:
    print(seq)

 

 相關資料:

相關文章
相關標籤/搜索