KDD 2019論文解讀:多分類下的模型可解釋性

前言

模型可解釋性是機器學習研究中的一個重要課題。這裏咱們研究的對象是廣義加性模型(Generalized Additive Models,簡稱GAMs)。GAM在醫療等對解釋性要求較高的場景下已經有了普遍的應用 [1]。
GAM做爲一個徹底白盒化的模型提供了比(廣義)線性模型(GLMs)更好的模型表達能力:GAM能對單特徵和雙特徵交叉(pairwise interaction)作非線性的變換。帶pairwiseinteraction的GAM每每被稱爲GA2M。如下是GA2
M模型的數學表達:算法

其中g是linkfunction,fi和fij被稱爲shape function,分別爲模型所須要學習的特徵變換函數。因爲fi和fij都是低緯度的函數,模型中每個函數均可以被可視化出來,從而方便建模人員瞭解每一個特徵是如何影響最終預測的。例如在[1]中,年齡對肺炎致死率的影響就能夠用一張圖來表示。安全

因爲GAM對特徵作了非線性變換,這使得GAM每每能提供比線性模型更強大的建模能力。在一些研究中GAM的效果每每能逼近Boosted Trees或者Random Forests [1, 2, 3]。dom

可視化圖像與模型的預測機制之間的矛盾

本文首先討論了在多分類問題的下,傳統可解釋性算法(例如邏輯迴歸,SVM)的可視化圖像與模型的預測機制之間存在的矛盾。若是直接經過這些未經加工的可視化圖像理解模型預測機制,有可能形成建模人員對模型預測機制的錯誤解讀。如圖1所示,左邊是在一個多分類GAM下age的shape function。粗看之下這張圖表示了Diabetes I的風險隨年齡增加而增長。然而當咱們看實際的預測機率(右圖),Diabetes I的風險其實應該是隨着年齡的增長而下降的。機器學習

爲了解決這一問題,本文提出了一種後期處理方法(AdditivePost-Processing for Interpretability, API),可以對用任意算法訓練的GAM進行處理,使得在不改變模型預測的前提下,處理後模型的可視化圖像與模型的預測機制相符,由此讓建模人員能夠安全的經過傳統的可視化方法來觀察和理解模型的預測機制,而不會被錯誤的視覺信息誤導。函數

多分類下的模型可解釋性

API的設計理念來源於兩個在長期使用GAM的過程當中獲得的可解釋性定理(Axioms of Interpretability)。咱們但願一個GAM模型具有以下兩個性質:學習

任意一個shape function fik (對應feature i和class k)的形狀,必需要和真實的預測機率Pk​的形狀相符,即咱們不但願看到一個shape function是遞增的,但實際上預測機率是遞減的狀況。優化

Shape function應該避免任何沒必要要的不平滑。不平滑的shape function會讓建模人員難以理解模型的預測趨勢。spa

如今咱們知道咱們想要的模型須要知足什麼性質,那麼如何找到這樣的模型,而不改變原模型的預測呢?這裏就要用到一個重要的softmax函數的性質。設計

對於一個softmax函數,若是在每個輸入項中加上同一個函數,由此得來的模型是和原模型徹底等價的。也就是說,這兩個模型在任何狀況下的預測結果都相同。基於這樣的性質,咱們就能夠設計一個g 函數,讓加入g函數以後的模型知足咱們想要的性質。rest

咱們在文章中從數學上證實,以上這個優化問題永遠有惟一的全局最優解,而且咱們給出了這個解的解析形式。咱們基於此設計的後期處理方法幾乎不消耗任何計算資源,卻能夠把具備誤導性的GAM模型轉化成能夠放心觀察的可解釋模型。

在一個預測嬰兒死因的數據上(12類分類問題),咱們採用API對shapefunction作了處理,從而使得他們能真實地反應預測機率變化的趨勢。這裏能夠看到,在採用API以前,模型可視化提供的信息是全部死因都和嬰兒體重和Apgar值成負相關趨勢。可是在採用API以後咱們發現,實際上不一樣的死因與嬰兒體重和Apgar值的關係

是不同的:其中一些死因是正相關,一些死因是負相關,另一些在嬰兒體重和Apgar值達到某個中間值得時候死亡率達到最高。API使得醫療人員可以經過模型獲得更準確的預測信息。

總結

在不少mission-critical的場景下(醫療,金融等),模型可解釋性每每比模型自身的準確性更重要。廣義加性模型做爲一個高精確度又徹底白盒化的模型,預期能在更多的應用場景上落地。

原文連接 本文爲雲棲社區原創內容,未經容許不得轉載。

相關文章
相關標籤/搜索