圖網絡(GN)在深度學習短板即因果推理上擁有巨大潛力,很有可能成爲機器學習領域的下一個增長點,而圖神經網絡(GNN)正屬於圖網絡的子集。GNN近期在圖分類任務上得到了當前最佳的結果,但其存在平面化的侷限,因而不能將圖分層表徵。現實應用中,很多圖信息都是層級表徵的,例如地圖、概念圖、流程圖等,捕獲層級信息將能更加完整高效地表徵圖,應用價值很高。在本文中,來自斯坦福等大學的研究者通過在GNN中結合一種類似CNN中空間池化的操作——可微池化,實現了圖的分層表徵DIFFPOOL在深度GNN的每一層針對節點學習可微分的軟簇分配,將節點映射到一組簇中去,然後這些簇作爲粗化輸入,輸入到GNN下一層。
介紹
近年來人們開發圖神經網絡的興趣持續激增。圖神經網絡,即可以在比如社交網絡數據或分子結構數據等圖結構數據上運行的通用深度學習架構。GNN一般是將底層圖作爲計算圖,通過在圖上傳遞、轉換和聚合節點特徵信息學習神經網絡基元以生成單個節點嵌入。生成的節點嵌入可以作爲輸入,用於如節點分類或連接預測的任何可微預測層,完整的模型可以通過端到端的方式訓練。
然而,現有的GNN結構的主要限制在於太過平坦,因爲它們僅通過圖的邊傳播信息,無法以分層的方式推斷和聚合信息。例如,爲了成功編碼有機分子的圖結構,就要編碼局部分子結構(如單個的原子和與這些原子直接相連的鍵)和分子圖的粗粒結構(如在分子中表示功能單元的原子基團和鍵)。對圖分類任務而言缺少分層結構尤其成問題,因爲這類任務是要預測與整個圖相關的標籤。在圖分類任務中應用 GNN,標準的方法是針對圖中所有的節點生成嵌入,然後對這些節點嵌入進行全局池化,如簡單地求和或在數據集上運行神經網絡。這種全局池化方法忽略了可能出現在圖中的層級結構,也阻礙了研究人員針對完整圖的預測任務建立有效的GNN模型。
研究者在此提出了DIFFPOOL,這是一個可以分層和端到端的方式應用於不同圖神經網絡的可微圖池化模塊。DIFFPOOL允許開發可以學習在圖的層級表徵上運行的更深度的GNN模型。他們開發了一個和CNN中的空間池化操作相似的變體,空間池化可以讓深度CNN在一張表徵越來越粗糙的圖上迭代運行。與標準CNN相比,GNN的挑戰在於圖不包含空間局部性的自然概念,也就是說,不能將所有節點簡單地以
[m×m patch]
的方式池化在一張圖上,因爲圖複雜的拓撲結構排除了任何直接、決定性的
[patch]
的定義。此外,與圖像數據不同,圖數據集中包含的圖形節點數和邊數都不同,這使得定義通用的圖池化操作更具挑戰性。
爲了解決上述問題,我們需要一個可以學習如何聚合節點以在底層圖上建立多層級架構的模型。DIFFPOOL在深度GNN的每一層學習了可微分的軟分配,這種軟分配是基於學習到的嵌入,將節點映射爲一組聚類。以該方法爲框架,作者通過分層的方式「堆疊」了 GNN 層建立了深度 GNN:GNN 模塊中
l
層的輸入節點對應GNN模塊中
l−1
層學到的聚類簇。因此,DIFFPOOL的每一層都能使圖越來越粗糙,然後訓練後的DIFFPOOL就可以產生任何輸入圖的層級表徵。本研究展示了DIFFPOOL可以結合到不同的GNN方法中,這使準確率平均提高了7%,並且在五個基準圖分類任務中,有四個達到了當前最佳水平。最後,DIFFPOOL可以學習到與輸入圖中明確定義的集合相對應的可解釋的層級簇。
最近工作
作者的工作是基於最近GNN和圖分類的研究。
GNN的最近工作。最近幾年提出了許多GNN模型,有受CNN啓發的方法,還有RNN、遞歸網絡和循環置信傳播。大部分方法都屬於Gilmer提出的neural message passing框架,在這種觀點下GNN是一種message passing算法,節點特徵和鄰居關係通過GNN而迭代計算出節點表示。Hamilton總結了這個領域的最近進展,Bronstein概述了圖卷積的聯繫。
GNN實現的圖分類。GNN應用於許多任務如節點分類、鏈接預測、圖分類和信息化學。在圖分類背景下所面臨的一個問題是如何更好的通過GNN產生節點嵌入以表示出整個網絡的特徵。一般的方法有,在最後一層對節點嵌入進行簡單求和或平均、引入與所有節點連接的虛擬節點和用深度學習聚合節點嵌入。然而這些方法都有一個限制,不能學習層級表示(所有節點在一個單層進行全局池化),所以不能捕獲現實世界的自然結構。最近有一些方法,用CNN將所有節點嵌入串聯起來,但是這需要節點的拓撲排序。
最後,最近也有一些工作,將GNN和確定性聚類方法結合起來以學習層級圖表示。與這個不同的是,作者的方法是在端對端訓練的框架下自動學習層級結構表示,而不是依靠確定性聚類方法。
提出方法
DIFFPOOL的關鍵想法是在多層GNN結構中引入節點的可微層級池化。這一節概述DIFFPOOL模塊以及如何在GNN中應用。
圖
G
表示爲
(A,F)
,其中
A∈{0,1}n×n
是鄰接矩陣,
F∈Rn×d
是節點特徵矩陣,
d
是每個節點的特徵維數。給定一個帶標籤的圖數據
D={(G1,y1),(G2,y2),…}
,其中
yi
是圖
Gi∈G
的類標籤,任務目標是尋找映射
f:G→Y
。 相對於標準監督學習過程,這裏的困難主要在於如何更好的從輸入的圖中提取特徵,爲了應用深度學習等機器學習方法進行分類,我們需要將每個圖轉換成一個有限維向量。
GNN。在這個工作中,作者以端到端訓練的方式使用GNN學習提取用於圖分類的特徵。GNN使用message passing結構
:H(k)=M(A,H(k−1);θ(k))
,其中
H(k)∈Rn×d
是GNN迭代
k
次後的節點嵌入,
M
是由鄰接矩陣和參數
θ(k)
決定的message傳播函數,
H(k−1)
是上一步message passing產生的節點嵌入。輸入節點嵌入
H(0)
初始化爲節點輸入特徵,
H(0)=F
。
傳播函數
M
有多種實現方式。有一種流行的GNN變種GCN,M的實現方式是將線性變換和ReLU非線性激活結合起來
H(k)=M(A,H(k−1);W(k))=ReLU(D~−12A~D~−12H(k−1)W(k−1))
A~=A+I,D~=∑jA~ij
其中
W(k)∈Rd×d
是需要訓練的權重矩陣。作者提出的可微分池化層能應用到任意GNN模型中,不論
M
以何種方式實現。GNN迭代
K
次產生最終節點嵌入
Z=H(K)∈Rn×d
,其中
K
一般是
2−6
之間,以下論述中忽略GNN的內部結構,並簡單記爲
Z=GNN(A,X)
。
GNN和池化層的堆疊。GNN的實現內部是平面化的,信息只能通過邊傳播。作者的目標是提出一個通用的、端對端可微分的方法,將GNN模塊堆疊爲層級結構。給定原始圖的鄰接矩陣
A∈Rn×n
後可以產生GNN的輸出
Z=GNN(A,X)
,之後給出一個粗化的圖,粗化圖的節點數爲
m<n
,鄰接矩陣爲
A′∈Rm×m
,節點嵌入爲
Z′∈Rm×d
。這個粗化圖作爲下一層GNN的輸入,經過
L
次重複產生越來越粗化的圖,並分別由串聯的GNN進行處理。於是我們的目標是學習如何使用上一層GNN的輸出結果對節點進行聚類或池化,再把聚類或池化所輸出的粗化圖作爲下一層GNN的輸入。設計GNN的池化層是比較困難的,相比於一般的粗化圖任務,不是在一個圖上對節點進行聚類,而是在圖集合上進行層級池化,在推理時要對許多不同的圖結構進行自適應池化。
基於分配學習的可微分池化。上述提到的DIFFPOOL,難點在於使用GNN的輸出學習節點分配的聚類,將L個GNN堆疊起來,可微池化層利用上一個GNN產生的節點嵌入進行節點聚類,從而產生粗化圖,並以端對端方式進行訓練學習。於是GNN產生的節點嵌入,既用於圖分類,又用於層級池化,而這通過大量的圖樣本進行訓練學習。以下先描述,DIFFPOOL在有了節點分配矩陣後具體如何聚類池化,再描述在GNN架構下如何產生分配矩陣。
用分配矩陣進行池化。將
l
層的聚類分配矩陣記爲
S(l)∈Rnl×nl+1
。
S(l)
的每一行代表在
l
層中的
nl
個節點中的一個節點(或一個節點聚類簇),每一列代表
l+1
層中的
nl+1
節點聚類簇,
S(l)
提供了從
l
層的圖節點到
l+1
層的圖節點(聚類簇)的軟分配。
現在我們已經有了
l
層的節點分配矩陣,將這
l
層圖的鄰接矩陣記爲
A(l)
,
l
層圖的節點嵌入記爲
Z(l)
。DIFFPOOL可微池化層在此基礎上產生輸入圖的粗化圖,
(A(l+1),Xl+1)=DIFFPOOL(A(l),Z(l))
,
A(l+1)
是下一層粗化圖的鄰接矩陣,
X(l+1)
是下一層粗化圖的節點(聚類簇)輸入特徵。
X(l+1)=S(l)TZ(l)∈Rnl+1×d(3)
A(l+1)=S(l)TA(l)S(l)∈Rnl+1×nl+1(4)
公式
(3)
是根據分配矩陣,將上一層的節點嵌入
Xl
轉換成下一層的節點(聚類簇)嵌入
X(l+1)
,類似的,公式
(4)
是將上一層的鄰接矩陣
A(l)
轉換成下一層粗化圖的鄰接矩陣
Al+1
。
A(l+1)
是一個
l
層的全連接實值矩陣,而
Al+1ij
代表
l+1
層聚類簇
i
和聚類簇
j
之間的連接強度。類似的,
X(l+1)
的第
i
行代表第
i
個聚類簇的輸入特徵。最後,
A(l+1)
和
X(l+1)
作爲下一層GNN的輸入特徵。
學習併產生分配矩陣。現在描述DIFFPOOL如何產生分配矩陣
S(l)
。我們使用兩個獨立的GNN(嵌入GNN和池化GNN)產生兩個矩陣,
l
層中的嵌入GNN爲
Z(l)=GNNl,emded(A(l),X(l))(5)
將
l
層鄰接矩陣
A(l)
和輸入特徵
X(l)
作爲一個標準GNN的輸入,進而產生一個新的嵌入
Z(l)
。相比之下,池化GNN則使用
A(l)
和
X(l)
產生分配矩陣
S(l)=softmax(GNNl,pool(A(l),X(l)=softmax(GNNl,pool(A(l),X(l)))(6)
其中
softmax
應用於輸出矩陣的每一行。注意到這兩個GNN使用相同的輸入數據,但是具有不同�