一種基於均值不等式的Listwise損失函數

一種基於均值不等式的Listwise損失函數

1 前言

1.1 Learning to Rank 簡介

Learning to Rank (LTR) , 也被叫作排序學習, 是搜索中的重要技術, 其目的是根據候選文檔和查詢語句的相關性對候選文檔進行排序, 或者選取topk文檔. 好比在搜索引擎中, 須要根據用戶問題選取最相關的搜索結果展現到首頁. 下圖是搜索引擎的搜索結果
search_result.jpghtml

1.2 LTR算法分類

根據損失函數可把LTR分爲三種:python

  1. Pointwise, 該類型算法將LTR任務做爲迴歸任務來訓練, 即嘗試訓練一個爲文檔和查詢語句的打分器, 而後根據打分進行排序.
  2. Pairwise, 該類型算法的損失函數考慮了兩個候選文檔, 學習目標是把相關性高的文檔排在前面, triplet loss 就屬於Pairwise, 它的損失函數是

\[loss = max(0, score_{neg}-score_{pos}+margin) \]

能夠看出該損失函數一次考慮兩個候選文檔.
3. Listwise, 該類型算法的損失函數會考慮多個候選文檔, 這是本文的重點, 下面會詳細介紹.算法

1.3 本文主要內容

本文主要介紹了本人在學習研究過程當中發明的一種新的Listwise損失函數, 以及該損失函數的使用效果. 若是讀者對LTR任務及其算法還不夠熟悉, 建議先去學習LTR相關知識, 同時本人博文天然語言處理中的負樣本挖掘 (分類與排序任務中如何選擇負樣本) 也和本文關係較大, 能夠先進行閱讀.網絡

2 預備知識

2.1 數學符號定義

\(q\)表明用戶搜索問題, 好比"如何成爲宇航員", \(D\)表明候選文檔集合,\(d^+\)表明和\(q\)相關的文檔,\(d^-\)表明和\(q\)不相關的文檔, \(d^+_i\)表明第\(i\)個和\(q\)相關的文檔, LTR的目標就是根據\(q\)找到最相關的文檔\(d\)函數

2.2 學習目標

本次學習目標是訓練一個打分器 scorer, 它能夠衡量q和d的相關性, \(scorer(q, d)\)就是相關性分數,分值越大越相關. 當前主流方法下, scorer通常選用深度神經網絡模型.學習

2.3訓練數據分類

損失函數不一樣, 構造訓練數據的方法也會不一樣:優化

-Pointwise, 能夠構造迴歸數據集, 相關的數據設爲1, 不相關設爲0.
-Pairwise, 可構造triplet類型的數據集, 形如(\(q,d^+, d^-\))
-Listwise, 可構造這種類型的訓練集: (\(q,d^+_1,d^+_2..., d^+_n , d^-_1, d^-_2, ..., d^-_{n+m}\)), 一個正例仍是多個正例也會影響到損失函數的構造, 本文提出的損失函數是針對多正例多負例的狀況.搜索引擎

3 基於均值不等式的Listwise損失函數

3.1 損失函數推導過程

在上一小結咱們能夠知道,訓練集是以下形式 (\(q,d^+_1,d^+_2..., d^+_n , d^-_1, d^-_2, ..., d^-_{n+m}\)), 對於一個q, 有n個相關的文檔和m個不相關的文檔, 那麼咱們一共能夠獲取m+n個分值:\((score_1,score_2,...,score_n,...,score_{n+m})\), 咱們但願打分器對相關文檔打分趨近於正無窮, 對不相關文檔打分趨近於負無窮.spa

對m+n個分值作一個softmax獲得\(p_1,p_2,...,p_n,...,p_{n+m}\), 此時\(p_i\)能夠看做是第i個候選文檔與q相關的機率, 顯然咱們但願\(p_1,p_2,...,p_n\)越大越好, \(p_{n+1},...,p_{m+n}\)越小越好, 即趨近於0. 所以咱們暫時的優化目標是\(\sum_{i=1}^{n}{p_i} \rightarrow 1\).code

可是這個優化目標是不合理的, 假設\(p_1=1\), 其餘值全爲0, 雖然知足了上面的要求, 但這並非咱們想要的. 由於咱們不只但願\(\sum_{i=1}^{n}{p_i} \rightarrow 1\), 還但願相關候選文檔的每個p值都要足夠大, 即咱們但願: n個候選文檔都與q相關的機率是最大的, 因此咱們真正的優化目標是:

\[\max(\prod_{i=1}^{n}{p_i} ) , \sum_{i=1}^{n}{p_i} = 1 \]

當前狀況下, 損失函數已經能夠經過代碼實現了, 可是咱們還能夠作一些化簡工做, \(\prod_{i=1}^{n}{p_i}\)是存在最大值的, 根據均值不等式可得:

\[\prod_{i=1}^{n}{p_i} \leq (\frac{\sum_{i=1}^{n}{p_i}}{n})^n \]

對兩邊取對數:

\[\sum_{i=1}^{n}{log(p_i)} \leq -nlog(n) \]

這樣是否是感受清爽多了, 而後咱們把它轉換成損失函數的形式:

\[loss = -nlog(n) - \sum_{i=1}^{n}{log(p_i)} \]

因此咱們的訓練目標就是\(\min{(loss)}\)

3.2 使用pytorch實現該損失函數

在獲取到最終的損失函數後, 咱們還須要用代碼來實現, 實現代碼以下:

# A simple example for my listwise loss function
# Assuming that n=3, m=4
# In[1]
# scores
scores = torch.tensor([[3,4.3,5.3,0.5,0.25,0.25,1]])
print(scores)
print(scores.shape)
'''
tensor([[0.3000, 0.3000, 0.3000, 0.0250, 0.0250, 0.0250, 0.0250]])
torch.Size([1, 7])
'''
# In[2]
# log softmax
log_prob = torch.nn.functional.log_softmax(scores,dim=1)
print(log_prob)
'''
tensor([[-2.7073, -1.4073, -0.4073, -5.2073, -5.4573, -5.4573, -4.7073]])
'''
# In[3]
# compute loss
n = 3.
mask = torch.tensor([[1,1,1,0,0,0,0]]) # number of 1 is n
loss = -1*n*torch.log(torch.tensor([[n]])) - torch.sum(log_prob*mask,dim=1,keepdim=True)
print(loss)
loss = loss.mean()
print(loss)
'''
tensor([[1.2261]])
tensor(1.2261)
'''

該示例代碼僅展示了batch_size爲1的狀況, 在batch_size大於1時, 每一條數據都有不一樣的m和n, 爲了能一塊兒送入模型計算分值, 須要靈活的使用mask. 本人在實際使用該損失函數時,一共使用了兩種mask, 分別mask每條數據全部候選文檔和每條數據的相關文檔, 供你們參考使用.

3.3 效果評估和使用經驗

因爲評測數據使用的是內部數據, 代碼和數據都沒法公開, 所以只能對使用效果作簡單總結:

  1. 效果優於PointwisePairwise, 但差距不是特別大
  2. 相比Pairwise收斂速度極快, 訓練一輪基本就能夠達到最佳效果

下面是我的使用經驗:

  1. 該損失函數比較佔用顯存, 實際的batch_size是batch_size*(m+n), 建議顯存在12G以上
  2. 負例數量越多,效果越好, 收斂也越快
  3. 用pytorch實現log_softmax時, 不要本身實現, 直接使用torch中的log_softmax函數, 它的效率更高些.
  4. 只有一個正例, 還能夠考慮轉爲分類問題,使用交叉熵作優化, 效果一樣較好

4 總結

該損失函數仍是比較簡單的, 只須要簡單的數學知識就能夠自行推導, 在實際使用中也取得了較好的效果, 但願也可以幫助到你們. 若是你們有更好的作法歡迎告訴我.

文章能夠轉載, 但請註明出處:

相關文章
相關標籤/搜索