Sentence-BERT詳解

簡述

BERT和RoBERTa在文本語義類似度(Semantic Textual Similarity)等句子對的迴歸任務上,已經達到了SOTA的結果。可是,它們都須要把兩個句子同時送入網絡,這樣會致使巨大的計算開銷:從10000個句子中找出最類似的句子對,大概須要5000萬( C 10000 2 = 49 , 995 , 000 C_{10000}^2=49,995,000 )個推理計算,在V100GPU上耗時約65個小時。這種結構使得BERT不適合語義類似度搜索,一樣也不適合無監督任務,例如聚類php

解決聚類和語義搜索的一種常見方法是將每一個句子映射到一個向量空間,使得語義類似的句子很接近。一般得到句子向量的方法有兩種:html

  1. 計算全部Token輸出向量的平均值
  2. 使用[CLS]位置輸出的向量

然而,UKP的研究員實驗發現,在文本類似度(STS)任務上,使用上述兩種方法獲得的效果卻並很差,即便是Glove向量也明顯優於樸素的BERT句子embeddings(見下圖前三行)git

Sentence-BERT(SBERT)的做者對預訓練的BERT進行修改:使用**Siamese and Triplet Network(孿生網絡和三胞胎網絡)**生成具備語義的句子Embedding向量。語義相近的句子,其Embedding向量距離就比較近,從而可使用餘弦類似度、曼哈頓距離、歐氏距離等找出語義類似的句子。SBERT在保證準確性的同時,可將上述提到BERT/RoBERTa的65小時下降到5秒(計算餘弦類似度大概0.01秒)。這樣SBERT能夠完成某些新的特定任務,好比聚類、基於語義的信息檢索等github

模型介紹

Pooling策略

SBERT在BERT/RoBERTa的輸出結果上增長了一個Pooling操做,從而生成一個固定維度的句子Embedding。實驗中採起了三種Pooling策略作對比:markdown

  1. CLS:直接用CLS位置的輸出向量做爲整個句子向量
  2. MEAN:計算全部Token輸出向量的平均值做爲整個句子向量
  3. MAX:取出全部Token輸出向量各個維度的最大值做爲整個句子向量

三種策略的實驗對比效果以下網絡

由結果可見,MEAN的效果是最好的,因此後面實驗默認採用的也是MEAN策略app

模型結構

爲了可以fine-tune BERT/RoBERTa,文章採用了孿生網絡和三胞胎網絡來更新參數,以達到生成的句子向量更具語義信息。該網絡結構取決於具體的訓練數據,文中實驗了下面幾種機構和目標函數函數

Classification Objective Function

針對分類問題,做者將向量 u , v , u v u,v,|u-v| 三個向量拼接在一塊兒,而後乘以一個權重參數 W t R 3 n × k W_t\in \mathbb{R}^{3n\times k} ,其中 n n 表示向量的維度, k k 表示label的數量oop

o = s o f t m a x ( W t [ u ; v ; u v ] ) o = softmax(W_t[u;v;|u-v|])

損失函數爲CrossEntropyLoss優化

注:原文公式爲 s o f t m a x ( W t ( u , v , u v ) ) softmax(W_t(u,v,|u-v|)) ,我我的比較喜歡用 [ ; ; ] [;;] 表示向量拼接的意思

Regression Objective Function

兩個句子embedding向量 u , v u,v 的餘弦類似度計算結構以下所示,損失函數爲MAE(mean squared error)

Triplet Objective Function

更多關於Triplet Network的內容能夠看個人這篇Siamese Network & Triplet NetWork。給定一個主句 a a ,一個正面句子 p p 和一個負面句子 n n ,三元組損失調整網絡,使得 a a p p 之間的距離儘量小, a a n n 之間的距離儘量大。數學上,咱們指望最小化如下損失函數:

m a x ( s a s p s a s n + ϵ , 0 ) max(||s_a-s_p||-||s_a-s_n||+\epsilon, 0)

其中, s x s_x 表示句子 x x 的embedding, ||·|| 表示距離,邊緣參數 ϵ \epsilon 表示 s a s_a s p s_p 的距離至少應比 s a s_a s n s_n 的距離近 ϵ \epsilon 。在實驗中,使用歐式距離做爲距離度量, ϵ \epsilon 設置爲1

模型訓練細節

做者訓練時結合了SNLI(Stanford Natural Language Inference)和Multi-Genre NLI兩種數據集。SNLI有570,000我的工標註的句子對,標籤分別爲矛盾,蘊含(eintailment),中立三種;MultiNLI是SNLI的升級版,格式和標籤都同樣,有430,000個句子對,主要是一系列口語和書面語文本

蘊含關係描述的是兩個文本之間的推理關係,其中一個文本做爲前提(Premise),另外一個文本做爲假設(Hypothesis),若是根據前提可以推理得出假設,那麼就說前提蘊含假設。參考樣例以下:

Sentence A (Premise) Sentence B (Hypothesis) Label
A soccer game with multiple males playing. Some men are playing a sport. entailment
An older and younger man smiling. Two men are smiling and laughing at the cats playing on the floor. neutral
A man inspects the uniform of a figure in some East Asian country. The man is sleeping. contradiction

實驗時,做者使用類別爲3的softmax分類目標函數對SBERT進行fine-tune,batch_size=16,Adam優化器,learning_rate=2e-5

消融研究

爲了對SBERT的不一樣方面進行消融研究,以便更好地瞭解它們的相對重要性,咱們在SNLI和Multi-NLI數據集上構建了分類模型,在STS benchmark數據集上構建了迴歸模型。在pooling策略上,對比了MEAN、MAX、CLS三種策略;在分類目標函數中,對比了不一樣的向量組合方式。結果以下

結果代表,Pooling策略影響較小,向量組合策略影響較大,而且 [ u ; v ; u v ] [u;v;|u-v|] 效果最好

Reference

相關文章
相關標籤/搜索