做者|Samrat Saha
編譯|VK
來源|Towards Datas Sciencepython
Supervised Contrastive Learning這篇論文在有監督學習、交叉熵損失與有監督對比損失之間進行了大量的討論,以更好地實現圖像表示和分類任務。讓咱們深刻了解一下這篇論文的內容。git
論文指出能夠在image net數據集有1%的改進。github
就架構而言,它是一個很是簡單的網絡resnet 50,具備128維的頭部。若是你想,你也能夠多加幾層。網絡
Code self.encoder = resnet50() self.head = nn.Linear(2048, 128) def forward(self, x): feat = self.encoder(x) #須要對128向量進行標準化 feat = F.normalize(self.head(feat), dim=1) return feat
如圖所示,訓練分兩個階段進行。架構
以上是不言自明的。dom
本文的主要內容是瞭解自監督的對比損失和監督的對比損失。機器學習
從上面的SCL(監督對比損失)圖中能夠看出,貓與任何非貓進行對比。這意味着全部的貓都屬於同一個標籤,都是正數對,任何非貓都是負的。這與三元組數據以及triplet loss的工做原理很是類似。函數
每一張貓的圖片都會被放大,因此即便是從一張貓的圖片中,咱們也會有不少貓。學習
監督對比損失的損失函數,雖然看起來很可怕,但其實很簡單。google
稍後咱們將看到一些代碼,但首先是很是簡單的解釋。每一個z是標準化的128維向量。
也就是說||z||=1
重申一下線性代數中的事實,若是u和v兩個向量正規化,意味着u.v=cos(u和v之間的夾角)
這意味着若是兩個標準化向量相同,它們之間的點乘=1
#嘗試理解下面的代碼 import numpy as np v = np.random.randn(128) v = v/np.linalg.norm(v) print(np.dot(v,v)) print(np.linalg.norm(v))
損失函數假設每幅圖像都有一個加強版本,每批有N幅圖像,生成的batch大小= 2*N
在i!=j,yi=yj時,分子exp(zi.zj)/tau表示一批中全部的貓。將i個第128個dim向量zi與全部的j個第128個dim向量點積。
分母是i個貓的圖像點乘其餘不是貓的圖像。取zi和zk的點,使i!=k表示它點乘除它本身之外的全部圖像。
最後,咱們取對數機率,並將其與批處理中除自身外的全部貓圖像相加,而後除以2*N-1
全部圖像的總損失和
咱們使用一些torch代碼能夠理解上面的內容。
假設咱們的批量大小是4,讓咱們看看如何計算單個批次的損失。
若是批量大小爲4,你在網絡上的輸入將是8x3x224x224,在這裏圖像的寬度和高度爲224。
8=4x2的緣由是咱們對每一個圖像老是有一個對比度,所以須要相應地編寫一個數據加載程序。
對比損失resnet將輸出8x128維的矩陣,你能夠分割這些維度以計算批量損失。
#batch大小 bs = 4
這個部分能夠計算分子
temperature = 0.07 anchor_feature = contrast_feature anchor_dot_contrast = torch.div( torch.matmul(anchor_feature, contrast_feature.T), temperature)
咱們的特徵形狀是8x128,讓咱們採起3x128矩陣和轉置,下面是可視化後的圖片。
anchor_feature=3x128和contrast_feature=128x3,結果爲3x3,以下所示
若是你注意到全部的對角線元素都是點自己,這實際上咱們不想要,咱們將刪除他們。
線性代數有個性質:若是u和v是兩個向量,那麼當u=v時,u.v是最大的。所以,在每一行中,若是咱們取錨點對比度的最大值,而且取相同值,則全部對角線將變爲0。
讓咱們把維度從128降到2
#bs 1 和 dim 2 意味着 2*1x2 features = torch.randn(2, 2) temperature = 0.07 contrast_feature = features anchor_feature = contrast_feature anchor_dot_contrast = torch.div( torch.matmul(anchor_feature, contrast_feature.T), temperature) print('anchor_dot_contrast=\n{}'.format(anchor_dot_contrast)) logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) print('logits_max = {}'.format(logits_max)) logits = anchor_dot_contrast - logits_max.detach() print(' logits = {}'.format(logits)) #輸出看看對角線發生了什麼 anchor_dot_contrast= tensor([[128.8697, -12.0467], [-12.0467, 50.5816]]) logits_max = tensor([[128.8697], [ 50.5816]]) logits = tensor([[ 0.0000, -140.9164], [ -62.6283, 0.0000]])
建立人工標籤和建立適當的掩碼進行對比計算。這段代碼有點複雜,因此要仔細檢查輸出。
bs = 4 print('batch size', bs) temperature = 0.07 labels = torch.randint(4, (1,4)) print('labels', labels) mask = torch.eq(labels, labels.T).float() print('mask = \n{}'.format(logits_mask)) #對它進行硬編碼,以使其更容易理解 contrast_count = 2 anchor_count = contrast_count mask = mask.repeat(anchor_count, contrast_count) #屏蔽self-contrast的狀況 logits_mask = torch.scatter( torch.ones_like(mask), 1, torch.arange(bs * anchor_count).view(-1, 1), 0 ) mask = mask * logits_mask print('mask * logits_mask = \n{}'.format(mask))
讓咱們理解輸出。
batch size 4 labels tensor([[3, 0, 2, 3]]) #以上的意思是在這批4個品種的葡萄中,咱們有3,0,2,3個標籤。以防大家忘了咱們在這裏只作了一次對比因此咱們會有3_c 0_c 2_c 3_c做爲輸入批處理中的對比。 mask = tensor([[0., 1., 1., 1., 1., 1., 1., 1.], [1., 0., 1., 1., 1., 1., 1., 1.], [1., 1., 0., 1., 1., 1., 1., 1.], [1., 1., 1., 0., 1., 1., 1., 1.], [1., 1., 1., 1., 0., 1., 1., 1.], [1., 1., 1., 1., 1., 0., 1., 1.], [1., 1., 1., 1., 1., 1., 0., 1.], [1., 1., 1., 1., 1., 1., 1., 0.]]) #這是很是重要的,因此咱們建立了mask = mask * logits_mask,它告訴咱們在第0個圖像表示中,它應該與哪一個圖像進行對比。 # 因此咱們的標籤就是標籤張量([[3,0,2,3]]) # 我從新命名它們是爲了更好地理解張量([[3_1,0_1,2_1,3_2]]) mask * logits_mask = tensor([[0., 0., 0., 1., 1., 0., 0., 1.], [0., 0., 0., 0., 0., 1., 0., 0.], [0., 0., 0., 0., 0., 0., 1., 0.], [1., 0., 0., 0., 1., 0., 0., 1.], [1., 0., 0., 1., 0., 0., 0., 1.], [0., 1., 0., 0., 0., 0., 0., 0.], [0., 0., 1., 0., 0., 0., 0., 0.], [1., 0., 0., 1., 1., 0., 0., 0.]])
錨點對比代碼
logits = anchor_dot_contrast — logits_max.detach()
損失函數
數學回顧
咱們已經有了第一部分的點積除以tau做爲logits。
#上述等式的第二部分等於torch.log(exp_logits.sum(1, keepdim=True)) exp_logits = torch.exp(logits) * logits_mask log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) # 計算對數似然的均值 mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1) # 損失 loss = - mean_log_prob_pos loss = loss.view(anchor_count, 4).mean() print('19. loss {}'.format(loss))
我認爲這是監督下的對比損失。我認爲如今很容易理解自監督的對比損失,由於它比這更簡單。
根據本文的研究結果,contrast_count越大,模型越清晰。須要修改contrast_count爲2以上,但願你能在上述說明的幫助下嘗試。
參考引用
原文連接:https://towardsdatascience.co...
歡迎關注磐創AI博客站:
http://panchuang.net/
sklearn機器學習中文官方文檔:
http://sklearn123.com/
歡迎關注磐創博客資源彙總站:
http://docs.panchuang.net/