In NLP models, the input is usually a sentence, such as "I went to New York last week.", and a sentence consists of many tokens. The traditional approach is to split on whitespace, e.g. ['i', 'went', 'to', 'New', 'York', 'last', 'week']. This approach has notable problems: for instance, the model cannot transfer the relationship it learns among old, older, oldest to smart, smarter, smartest. If we instead split a token into multiple subtokens, such problems can be solved nicely. This article describes BPE (Byte-Pair Encoding), currently one of the most widely used subtoken algorithms.
Today's best-performing NLP models, such as GPT, BERT, and RoBERTa, all segment text into subwords during data preprocessing; BERT uses WordPiece, whose vocabulary is built much like BPE's, while GPT and RoBERTa use BPE (Byte-Pair Encoding) directly. Concretely, consider the three words ['loved', 'loving', 'loves']. All three essentially mean "love", but if we tokenize at the word level they count as different words. English has a great many words that differ only in their suffix, so a word-level vocabulary becomes very large, training gets slower, and results suffer. Through training, BPE can split these three words into the pieces ["lov", "ed", "ing", "es"], separating a word's core meaning from its inflection and effectively shrinking the vocabulary. The algorithm proceeds as follows:
First, split every word in the corpus into a sequence of characters, append the stop token </w> at the end, and record how often the word occurs. For example, if the word "low" occurs 5 times, it is represented as {'l o w </w>': 5}.
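A minimal sketch of this preprocessing step (to_symbols is a hypothetical helper name, not part of the reference implementation below):

```python
def to_symbols(word):
    # Hypothetical helper: represent a word as space-separated
    # characters followed by the </w> stop token
    return ' '.join(word) + ' </w>'

print(to_symbols('low'))  # -> 'l o w </w>'
```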
For example:
```python
{'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w e s t </w>': 6, 'w i d e s t </w>': 3}
```
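Each merge step picks the adjacent symbol pair with the highest total frequency. A minimal sketch of that counting, using only the standard library (variable names are my own):

```python
import collections

vocab = {'l o w </w>': 5, 'l o w e r </w>': 2,
         'n e w e s t </w>': 6, 'w i d e s t </w>': 3}

pairs = collections.defaultdict(int)
for word, freq in vocab.items():
    symbols = word.split()
    for a, b in zip(symbols, symbols[1:]):  # every adjacent symbol pair
        pairs[a, b] += freq

best = max(pairs, key=pairs.get)
print(best, pairs[best])  # -> ('e', 's') 9 (ties broken by first occurrence)
```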
The most frequent byte pair is **e** and **s**, occurring 6 + 3 = 9 times in total, so we merge them:
```python
{'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w es t </w>': 6, 'w i d es t </w>': 3}
```
The most frequent byte pair is now **es** and **t**, occurring 6 + 3 = 9 times, so we merge them:
```python
{'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w est </w>': 6, 'w i d est </w>': 3}
```
The most frequent byte pair is now **est** and **</w>**, occurring 6 + 3 = 9 times, so we merge them:
```python
{'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w est</w>': 6, 'w i d est</w>': 3}
```
The most frequent byte pair is now **l** and **o**, occurring 5 + 2 = 7 times, so we merge them:
```python
{'lo w </w>': 5, 'lo w e r </w>': 2, 'n e w est</w>': 6, 'w i d est</w>': 3}
```
The most frequent byte pair is now **lo** and **w**, occurring 5 + 2 = 7 times, so we merge them:
```python
{'low </w>': 5, 'low e r </w>': 2, 'n e w est</w>': 6, 'w i d est</w>': 3}
```
...and so on, iterating until the preset subword vocabulary size is reached or the next most frequent byte pair occurs only once. This yields a better-suited vocabulary; it may contain entries that are not words themselves, but they are still meaningful units in their own right.
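A sketch of those two stopping conditions, assuming a vocab dict in the format above and the get_stats, merge_vocab, and get_tokens functions from the reference implementation below (target_vocab_size is a hypothetical knob):

```python
target_vocab_size = 50  # hypothetical target number of subword types

while True:
    pairs = get_stats(vocab)
    if not pairs:
        break
    best = max(pairs, key=pairs.get)
    if pairs[best] == 1:  # next most frequent pair occurs only once
        break
    vocab = merge_vocab(best, vocab)
    if len(get_tokens(vocab)) >= target_vocab_size:  # vocabulary large enough
        break
```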
The point of the stop token </w> is to mark a subword as a word suffix. For example, st without </w> can appear at the beginning of a word, as in st ar; with </w>, it indicates that the subword sits at the end of a word, as in wide st</w>. The two have quite different meanings.
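A tiny sketch of how that convention can be checked (is_suffix is a hypothetical helper, not part of the implementation below):

```python
def is_suffix(token):
    # Tokens carrying the </w> marker may only occur word-finally
    return token.endswith('</w>')

print(is_suffix('st'))      # False: may start a word, as in "st ar"
print(is_suffix('st</w>'))  # True: word-final, as in "wide st</w>"
```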
```python
import re, collections

def get_vocab(filename):
    # Build the initial vocabulary: each word becomes space-separated
    # characters with the stop token </w> appended, mapped to its frequency
    vocab = collections.defaultdict(int)
    with open(filename, 'r', encoding='utf-8') as fhand:
        for line in fhand:
            words = line.strip().split()
            for word in words:
                vocab[' '.join(list(word)) + ' </w>'] += 1
    return vocab

def get_stats(vocab):
    # Count the frequency of every pair of adjacent symbols
    pairs = collections.defaultdict(int)
    for word, freq in vocab.items():
        symbols = word.split()
        for i in range(len(symbols) - 1):
            pairs[symbols[i], symbols[i + 1]] += freq
    return pairs

def merge_vocab(pair, v_in):
    # Merge every occurrence of the given pair into a single symbol
    v_out = {}
    bigram = re.escape(' '.join(pair))
    p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
    for word in v_in:
        w_out = p.sub(''.join(pair), word)
        v_out[w_out] = v_in[word]
    return v_out

def get_tokens(vocab):
    # Collect the current set of tokens and their frequencies
    tokens = collections.defaultdict(int)
    for word, freq in vocab.items():
        word_tokens = word.split()
        for token in word_tokens:
            tokens[token] += freq
    return tokens

vocab = {'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w e s t </w>': 6, 'w i d e s t </w>': 3}

# Get free book from Gutenberg
# wget http://www.gutenberg.org/cache/epub/16457/pg16457.txt
# vocab = get_vocab('pg16457.txt')

print('==========')
print('Tokens Before BPE')
tokens = get_tokens(vocab)
print('Tokens: {}'.format(tokens))
print('Number of tokens: {}'.format(len(tokens)))
print('==========')

num_merges = 5
for i in range(num_merges):
    pairs = get_stats(vocab)
    if not pairs:
        break
    best = max(pairs, key=pairs.get)
    vocab = merge_vocab(best, vocab)
    print('Iter: {}'.format(i))
    print('Best pair: {}'.format(best))
    tokens = get_tokens(vocab)
    print('Tokens: {}'.format(tokens))
    print('Number of tokens: {}'.format(len(tokens)))
    print('==========')
```
The output is as follows:
```
==========
Tokens Before BPE
Tokens: defaultdict(<class 'int'>, {'l': 7, 'o': 7, 'w': 16, '</w>': 16, 'e': 17, 'r': 2, 'n': 6, 's': 9, 't': 9, 'i': 3, 'd': 3})
Number of tokens: 11
==========
Iter: 0
Best pair: ('e', 's')
Tokens: defaultdict(<class 'int'>, {'l': 7, 'o': 7, 'w': 16, '</w>': 16, 'e': 8, 'r': 2, 'n': 6, 'es': 9, 't': 9, 'i': 3, 'd': 3})
Number of tokens: 11
==========
Iter: 1
Best pair: ('es', 't')
Tokens: defaultdict(<class 'int'>, {'l': 7, 'o': 7, 'w': 16, '</w>': 16, 'e': 8, 'r': 2, 'n': 6, 'est': 9, 'i': 3, 'd': 3})
Number of tokens: 10
==========
Iter: 2
Best pair: ('est', '</w>')
Tokens: defaultdict(<class 'int'>, {'l': 7, 'o': 7, 'w': 16, '</w>': 7, 'e': 8, 'r': 2, 'n': 6, 'est</w>': 9, 'i': 3, 'd': 3})
Number of tokens: 10
==========
Iter: 3
Best pair: ('l', 'o')
Tokens: defaultdict(<class 'int'>, {'lo': 7, 'w': 16, '</w>': 7, 'e': 8, 'r': 2, 'n': 6, 'est</w>': 9, 'i': 3, 'd': 3})
Number of tokens: 9
==========
Iter: 4
Best pair: ('lo', 'w')
Tokens: defaultdict(<class 'int'>, {'low': 7, '</w>': 7, 'e': 8, 'r': 2, 'n': 6, 'w': 9, 'est</w>': 9, 'i': 3, 'd': 3})
Number of tokens: 9
==========
```
The algorithm above gives us a subword vocabulary; sort it by token length in descending order. When encoding, for each word we walk through the sorted subword list looking for tokens that are substrings of the current word; each match becomes one of the tokens representing that word.

We iterate from the longest token to the shortest, trying to replace substrings of each word with tokens, until every token has been tried. If some substrings remain unmatched after all tokens have been exhausted, the leftover pieces are replaced with a special token such as <unk>.
For example:
```python
# Given word sequence
["the</w>", "highest</w>", "mountain</w>"]

# Sorted subword vocabulary
# length:   6          5        4       4         4       4          2
["errrr</w>", "tain</w>", "moun", "est</w>", "high", "the</w>", "a</w>"]

# Iteration results
"the</w>"      -> ["the</w>"]
"highest</w>"  -> ["high", "est</w>"]
"mountain</w>" -> ["moun", "tain</w>"]
```
Decoding simply concatenates all the tokens back together, for example:
```python
# Encoded sequence
["the</w>", "high", "est</w>", "moun", "tain</w>"]

# Decoded sequence
"the</w> highest</w> mountain</w>"
```
```python
import re, collections

def get_vocab(filename):
    # Build the initial vocabulary: space-separated characters plus </w>
    vocab = collections.defaultdict(int)
    with open(filename, 'r', encoding='utf-8') as fhand:
        for line in fhand:
            words = line.strip().split()
            for word in words:
                vocab[' '.join(list(word)) + ' </w>'] += 1
    return vocab

def get_stats(vocab):
    # Count the frequency of every pair of adjacent symbols
    pairs = collections.defaultdict(int)
    for word, freq in vocab.items():
        symbols = word.split()
        for i in range(len(symbols) - 1):
            pairs[symbols[i], symbols[i + 1]] += freq
    return pairs

def merge_vocab(pair, v_in):
    # Merge every occurrence of the given pair into a single symbol
    v_out = {}
    bigram = re.escape(' '.join(pair))
    p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
    for word in v_in:
        w_out = p.sub(''.join(pair), word)
        v_out[w_out] = v_in[word]
    return v_out

def get_tokens_from_vocab(vocab):
    # Token frequencies plus a word -> tokenization lookup table
    tokens_frequencies = collections.defaultdict(int)
    vocab_tokenization = {}
    for word, freq in vocab.items():
        word_tokens = word.split()
        for token in word_tokens:
            tokens_frequencies[token] += freq
        vocab_tokenization[''.join(word_tokens)] = word_tokens
    return tokens_frequencies, vocab_tokenization

def measure_token_length(token):
    # Treat the 4-character marker </w> as a single symbol of length 1
    if token[-4:] == '</w>':
        return len(token[:-4]) + 1
    else:
        return len(token)

def tokenize_word(string, sorted_tokens, unknown_token='</u>'):
    # Greedily match tokens from longest to shortest; anything that
    # no token covers becomes unknown_token
    if string == '':
        return []
    if sorted_tokens == []:
        return [unknown_token]
    string_tokens = []
    for i in range(len(sorted_tokens)):
        token = sorted_tokens[i]
        token_reg = re.escape(token)
        matched_positions = [(m.start(0), m.end(0)) for m in re.finditer(token_reg, string)]
        if len(matched_positions) == 0:
            continue
        # Start positions of the matches delimit the unmatched substrings
        substring_end_positions = [matched_position[0] for matched_position in matched_positions]
        substring_start_position = 0
        for substring_end_position in substring_end_positions:
            # Recurse on the gap before this match with the remaining tokens
            substring = string[substring_start_position:substring_end_position]
            string_tokens += tokenize_word(string=substring, sorted_tokens=sorted_tokens[i+1:], unknown_token=unknown_token)
            string_tokens += [token]
            substring_start_position = substring_end_position + len(token)
        # Recurse on whatever is left after the final match
        remaining_substring = string[substring_start_position:]
        string_tokens += tokenize_word(string=remaining_substring, sorted_tokens=sorted_tokens[i+1:], unknown_token=unknown_token)
        break
    return string_tokens

# vocab = {'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w e s t </w>': 6, 'w i d e s t </w>': 3}
vocab = get_vocab('pg16457.txt')

print('==========')
print('Tokens Before BPE')
tokens_frequencies, vocab_tokenization = get_tokens_from_vocab(vocab)
print('All tokens: {}'.format(tokens_frequencies.keys()))
print('Number of tokens: {}'.format(len(tokens_frequencies.keys())))
print('==========')

num_merges = 10000
for i in range(num_merges):
    pairs = get_stats(vocab)
    if not pairs:
        break
    best = max(pairs, key=pairs.get)
    vocab = merge_vocab(best, vocab)
    print('Iter: {}'.format(i))
    print('Best pair: {}'.format(best))
    tokens_frequencies, vocab_tokenization = get_tokens_from_vocab(vocab)
    print('All tokens: {}'.format(tokens_frequencies.keys()))
    print('Number of tokens: {}'.format(len(tokens_frequencies.keys())))
    print('==========')

# Let's check how tokenization will be for a known word
word_given_known = 'mountains</w>'
word_given_unknown = 'Ilikeeatingapples!</w>'

# Sort tokens by length (longest first), breaking ties by frequency
sorted_tokens_tuple = sorted(tokens_frequencies.items(), key=lambda item: (measure_token_length(item[0]), item[1]), reverse=True)
sorted_tokens = [token for (token, freq) in sorted_tokens_tuple]
print(sorted_tokens)

word_given = word_given_known
print('Tokenizing word: {}...'.format(word_given))
if word_given in vocab_tokenization:
    print('Tokenization of the known word:')
    print(vocab_tokenization[word_given])
    print('Tokenization treating the known word as unknown:')
    print(tokenize_word(string=word_given, sorted_tokens=sorted_tokens, unknown_token='</u>'))
else:
    print('Tokenization of the unknown word:')
    print(tokenize_word(string=word_given, sorted_tokens=sorted_tokens, unknown_token='</u>'))

word_given = word_given_unknown
print('Tokenizing word: {}...'.format(word_given))
if word_given in vocab_tokenization:
    print('Tokenization of the known word:')
    print(vocab_tokenization[word_given])
    print('Tokenization treating the known word as unknown:')
    print(tokenize_word(string=word_given, sorted_tokens=sorted_tokens, unknown_token='</u>'))
else:
    print('Tokenization of the unknown word:')
    print(tokenize_word(string=word_given, sorted_tokens=sorted_tokens, unknown_token='</u>'))
```
The output is as follows:
```
Tokenizing word: mountains</w>...
Tokenization of the known word:
['mountains</w>']
Tokenization treating the known word as unknown:
['mountains</w>']
Tokenizing word: Ilikeeatingapples!</w>...
Tokenization of the unknown word:
['I', 'like', 'ea', 'ting', 'app', 'l', 'es!</w>']
```