自監督對比損失和監督對比損失的對比

做者|Samrat Saha
編譯|VK
來源|Towards Datas Sciencepython

Supervised Contrastive Learning這篇論文在有監督學習、交叉熵損失與有監督對比損失之間進行了大量的討論,以更好地實現圖像表示和分類任務。讓咱們深刻了解一下這篇論文的內容。git

論文指出能夠在image net數據集有1%的改進。github

就架構而言,它是一個很是簡單的網絡resnet 50,具備128維的頭部。若是你想,你也能夠多加幾層。網絡

Code

self.encoder = resnet50()

self.head = nn.Linear(2048, 128)

def forward(self, x):
 feat = self.encoder(x)
 #須要對128向量進行標準化
 feat = F.normalize(self.head(feat), dim=1)
 return feat

如圖所示,訓練分兩個階段進行。架構

  • 使用對比損失的訓練集(兩種變化)
  • 凍結參數,而後使用softmax損失在線性層上學習分類器。(來自論文的作法)

以上是不言自明的。dom

本文的主要內容是瞭解自監督的對比損失和監督的對比損失。機器學習

從上面的SCL(監督對比損失)圖中能夠看出,貓與任何非貓進行對比。這意味着全部的貓都屬於同一個標籤,都是正數對,任何非貓都是負的。這與三元組數據以及triplet loss的工做原理很是類似。函數

每一張貓的圖片都會被放大,因此即便是從一張貓的圖片中,咱們也會有不少貓。學習

監督對比損失的損失函數,雖然看起來很可怕,但其實很簡單。google

稍後咱們將看到一些代碼,但首先是很是簡單的解釋。每一個z是標準化的128維向量。

也就是說||z||=1

重申一下線性代數中的事實,若是u和v兩個向量正規化,意味着u.v=cos(u和v之間的夾角)

這意味着若是兩個標準化向量相同,它們之間的點乘=1

#嘗試理解下面的代碼

import numpy as np
v = np.random.randn(128)
v = v/np.linalg.norm(v)
print(np.dot(v,v))
print(np.linalg.norm(v))

損失函數假設每幅圖像都有一個加強版本,每批有N幅圖像,生成的batch大小= 2*N

在i!=j,yi=yj時,分子exp(zi.zj)/tau表示一批中全部的貓。將i個第128個dim向量zi與全部的j個第128個dim向量點積。

分母是i個貓的圖像點乘其餘不是貓的圖像。取zi和zk的點,使i!=k表示它點乘除它本身之外的全部圖像。

最後,咱們取對數機率,並將其與批處理中除自身外的全部貓圖像相加,而後除以2*N-1

全部圖像的總損失和

咱們使用一些torch代碼能夠理解上面的內容。

假設咱們的批量大小是4,讓咱們看看如何計算單個批次的損失。

若是批量大小爲4,你在網絡上的輸入將是8x3x224x224,在這裏圖像的寬度和高度爲224。

8=4x2的緣由是咱們對每一個圖像老是有一個對比度,所以須要相應地編寫一個數據加載程序。

對比損失resnet將輸出8x128維的矩陣,你能夠分割這些維度以計算批量損失。

#batch大小
bs = 4

這個部分能夠計算分子

temperature = 0.07

anchor_feature = contrast_feature

anchor_dot_contrast = torch.div(
    torch.matmul(anchor_feature, contrast_feature.T),
    temperature)

咱們的特徵形狀是8x128,讓咱們採起3x128矩陣和轉置,下面是可視化後的圖片。

anchor_feature=3x128和contrast_feature=128x3,結果爲3x3,以下所示

若是你注意到全部的對角線元素都是點自己,這實際上咱們不想要,咱們將刪除他們。

線性代數有個性質:若是u和v是兩個向量,那麼當u=v時,u.v是最大的。所以,在每一行中,若是咱們取錨點對比度的最大值,而且取相同值,則全部對角線將變爲0。

讓咱們把維度從128降到2

#bs 1 和 dim 2 意味着 2*1x2 
features = torch.randn(2, 2)

temperature = 0.07 
contrast_feature  = features
anchor_feature = contrast_feature
anchor_dot_contrast = torch.div(
    torch.matmul(anchor_feature, contrast_feature.T),
    temperature)
print('anchor_dot_contrast=\n{}'.format(anchor_dot_contrast))

logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
print('logits_max = {}'.format(logits_max))
logits = anchor_dot_contrast - logits_max.detach()
print(' logits = {}'.format(logits))

#輸出看看對角線發生了什麼

anchor_dot_contrast=
tensor([[128.8697, -12.0467],
        [-12.0467,  50.5816]])
 logits_max = tensor([[128.8697],
        [ 50.5816]])
 logits = tensor([[   0.0000, -140.9164],
        [ -62.6283,    0.0000]])

建立人工標籤和建立適當的掩碼進行對比計算。這段代碼有點複雜,因此要仔細檢查輸出。

bs = 4
print('batch size', bs)
temperature = 0.07
labels = torch.randint(4, (1,4))
print('labels', labels)
mask = torch.eq(labels, labels.T).float()
print('mask = \n{}'.format(logits_mask))

#對它進行硬編碼,以使其更容易理解
contrast_count = 2
anchor_count = contrast_count

mask = mask.repeat(anchor_count, contrast_count)

#屏蔽self-contrast的狀況
logits_mask = torch.scatter(
    torch.ones_like(mask),
    1,
    torch.arange(bs * anchor_count).view(-1, 1),
    0
)
mask = mask * logits_mask
print('mask * logits_mask = \n{}'.format(mask))

讓咱們理解輸出。

batch size 4
labels tensor([[3, 0, 2, 3]])

#以上的意思是在這批4個品種的葡萄中,咱們有3,0,2,3個標籤。以防大家忘了咱們在這裏只作了一次對比因此咱們會有3_c 0_c 2_c 3_c做爲輸入批處理中的對比。

mask = 
tensor([[0., 1., 1., 1., 1., 1., 1., 1.],
        [1., 0., 1., 1., 1., 1., 1., 1.],
        [1., 1., 0., 1., 1., 1., 1., 1.],
        [1., 1., 1., 0., 1., 1., 1., 1.],
        [1., 1., 1., 1., 0., 1., 1., 1.],
        [1., 1., 1., 1., 1., 0., 1., 1.],
        [1., 1., 1., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 1., 1., 0.]])
        
#這是很是重要的,因此咱們建立了mask = mask * logits_mask,它告訴咱們在第0個圖像表示中,它應該與哪一個圖像進行對比。

# 因此咱們的標籤就是標籤張量([[3,0,2,3]])
# 我從新命名它們是爲了更好地理解張量([[3_1,0_1,2_1,3_2]])

mask * logits_mask = 
tensor([[0., 0., 0., 1., 1., 0., 0., 1.],
        [0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0.],
        [1., 0., 0., 0., 1., 0., 0., 1.],
        [1., 0., 0., 1., 0., 0., 0., 1.],
        [0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 1., 1., 0., 0., 0.]])

錨點對比代碼

logits = anchor_dot_contrast — logits_max.detach()

損失函數

數學回顧

咱們已經有了第一部分的點積除以tau做爲logits。

#上述等式的第二部分等於torch.log(exp_logits.sum(1, keepdim=True))

exp_logits = torch.exp(logits) * logits_mask
log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

# 計算對數似然的均值
mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)

# 損失
loss = - mean_log_prob_pos

loss = loss.view(anchor_count, 4).mean()
print('19. loss {}'.format(loss))

我認爲這是監督下的對比損失。我認爲如今很容易理解自監督的對比損失,由於它比這更簡單。

根據本文的研究結果,contrast_count越大,模型越清晰。須要修改contrast_count爲2以上,但願你能在上述說明的幫助下嘗試。

參考引用

  • [1] : Supervised Contrastive Learning
  • [2] : Florian Schroff, Dmitry Kalenichenko, and James Philbin. Facenet: A unified embedding for face recognition and clustering. In Proceedings of the IEEE conference on computer vision and pattern recognition, pages 815–823, 2015.
  • [3] : A Simple Framework for Contrastive Learning of Visual Representations, Ting Chen, Simon Kornblith Mohammad Norouzi, Geoffrey Hinton
  • [4] : https://github.com/google-res...

原文連接:https://towardsdatascience.co...

歡迎關注磐創AI博客站:
http://panchuang.net/

sklearn機器學習中文官方文檔:
http://sklearn123.com/

歡迎關注磐創博客資源彙總站:
http://docs.panchuang.net/

相關文章
相關標籤/搜索