如何利用PyTorch中的Moco-V2減小計算約束

做者|GUEST
編譯|VK
來源|Analytics Vidhyapython

介紹

SimCLR論文(http://cse.iitkgp.ac.in/~aras...,而且若是有足夠的計算能力,能夠產生與監督模型相似的結果。git

可是這些需求使得框架的計算量至關大。若是咱們能夠擁有這個框架的簡單性和強大功能,而且有更少的計算需求,這樣每一個人均可以訪問它,這不是很好嗎?Moco-v2前來救援。github

注意:在以前的一篇博文中,咱們在PyTorch中實現了SimCLR框架,它是在一個包含5個類別的簡單數據集上實現的,總共只有1250個訓練圖像。算法

數據集

此次咱們將在Pytorch中在更大的數據集上實現Moco-v2,並在Google Colab上訓練咱們的模型。此次咱們將使用Imagenette和Imagewoof數據集網絡

來自Imagenette數據集的一些圖像數據結構

這些數據集的快速摘要(更多信息在這裏:https://github.com/fastai/ima...):架構

  • Imagenette由Imagenet的10個容易分類的類組成,總共有9479個訓練圖像和3935個驗證集圖像。
  • Imagewoof是一個由Imagenet提供的10個難分類組成的數據集,由於全部的類都是狗的品種。總共有9035個訓練圖像,3939個驗證集圖像。

對比學習

對比學習在自我監督學習中的做用是基於這樣一個理念:咱們但願同一類別中不一樣的圖像觀具備類似的表徵。可是,因爲咱們不知道哪些圖像屬於同一類別,一般所作的是將同一圖像的不一樣外觀的表示拉近。咱們把這些不一樣的外觀稱爲正對(positive pairs)。app

另外,咱們但願不一樣類別的圖像有不一樣的外觀,使它們的表徵彼此遠離。不一樣圖像的不一樣外觀的呈現與類別無關,會被彼此推開。咱們把這些不一樣的外觀稱爲負對(negative pairs)。框架

在這種狀況下,一個圖像的前景是什麼?前景能夠被認爲是以一種通過修改的方式看待圖像的某些部分,它本質上是圖像的一種變換。機器學習

根據手頭的任務,有些轉換能夠比其餘轉換工做得更好。SimCLR代表,應用隨機裁剪和顏色抖動能夠很好地完成各類任務,包括圖像分類。這本質上來自於網格搜索,從旋轉、裁剪、剪切、噪聲、模糊、Sobel濾波等選項中選擇一對變換。

從外觀到表示空間的映射是經過神經網絡完成的,一般,resnet用於此目的。下面是從圖像到表示的管道

負對是如何產生的?

在同一幅圖像中,因爲隨機裁剪,咱們能夠獲得多個表示。這樣,咱們就能夠產生正對。

可是如何生成負對呢?負對是來自不一樣圖像的表示。SimCLR論文在同一批中建立了這些。若是一個批包含N個圖像,那麼對於每一個圖像,咱們將獲得2個表示,這總共佔2*N個表示。對於一個特定的表示x,有一個表示與x造成正對(與x來自同一個圖像的表示),其他全部表示(正好是2*N–2)與x造成負對。

若是咱們手頭有大量的負樣本,這些表示就會獲得改善。可是,在SimCLR中,只有當批量較大時,才能實現大量的負樣本,這致使了對計算能力的更高要求。MoCo-v2提供了生成負樣本的另外一種方法。讓咱們詳細瞭解一下。

動態詞典

咱們能夠用一種稍微不一樣的方式來看待對比學習方法,即將查詢與鍵進行匹配。咱們如今有兩個編碼器,一個用於查詢,另外一個用於鍵。此外,爲了獲得大量的負樣本,咱們須要一個大的鍵編碼字典。

此上下文中的正對錶示查詢與鍵匹配。若是查詢和鍵都來自同一個圖像,則它們匹配。編碼的查詢應該與其匹配的鍵類似,而與其餘查詢不一樣。

對於負對,咱們維護一個大字典,其中包含之前批處理的編碼鍵。它們做爲查詢的負樣本。咱們以隊列的形式維護字典。新的batch被入隊,較早的batch被出列。經過更改此隊列的大小,能夠更改負採樣數。

這種方法的挑戰

  • 隨着鍵編碼器的更改,在稍後時間點排隊的鍵可能與較早排隊的鍵不一致。爲了使用對比學習方法,與查詢進行比較的全部鍵必須來自相同或類似的編碼器,這樣比較纔會有意義且一致。
  • 另外一個挑戰是,使用反向傳播學習編碼器參數是不可行的,由於這將須要計算隊列中全部樣本的梯度(這將致使大的計算圖)。

爲了解決這兩個問題,MoCo將鍵編碼器實現爲基於動量的查詢編碼器的移動平均值[1]。這意味着它以這種方式更新關鍵編碼器參數:

其中m很是接近於1(例如,典型值爲0.999),這確保咱們在不一樣的時間從類似的編碼器得到編碼鍵。

損失函數-InfoNCE

咱們但願查詢接近其全部正樣本,遠離全部負樣本。InfoNC函數E會捕獲它。它表明信息噪聲對比估計。對於查詢q和鍵k,InfoNCE損失函數是:

咱們能夠重寫爲:

當q和k的類似性增大,q與負樣本的類似性減少時,損失值減少

如下是損失函數的代碼:

τ = 0.05

def loss_function(q, k, queue):

    # N是批量大小
    N = q.shape[0]
    
    # C是表示的維數
    C = q.shape[1]

    # bmm表明批處理矩陣乘法
    # 若是mat1是b×n×m張量,那麼mat2是b×m×p張量,
    # 而後輸出一個b×n×p張量。
    pos = torch.exp(torch.div(torch.bmm(q.view(N,1,C), k.view(N,C,1)).view(N, 1),τ))
    
    # 在查詢和隊列張量之間執行矩陣乘法
    neg = torch.sum(torch.exp(torch.div(torch.mm(q.view(N,C), torch.t(queue)),τ)), dim=1)
   
    # 求和
    denominator = neg + pos

    return torch.mean(-torch.log(torch.div(pos,denominator)))

讓咱們再看看這個損失函數,並將它與分類交叉熵損失函數進行比較。

這裏predᵢ是數據點在第i類中的機率值預測,trueᵢ是該點屬於第i類的實際機率值(能夠是模糊的,但大多數狀況下是一個one-hot)。

若是你不熟悉這個話題,你能夠看這個視頻來更好地理解交叉熵。另外,請注意,咱們常常經過softmax這樣的函數將分數轉換爲機率值:https://www.youtube.com/watch...

咱們能夠把信息損失函數看做交叉熵損失。數據樣本「q」的正確類是第r類,底層分類器基於softmax,它試圖在K+1類之間進行分類。

Info-NCE還與編碼表示之間的相互信息有關;關於這一點的更多細節見[4]。

MoCo-v2框架

如今,讓咱們把全部的東西放在一塊兒,看看整個Moco-v2算法是什麼樣子的。

步驟1:

咱們必須獲得查詢和鍵編碼器。最初,鍵編碼器具備與查詢編碼器相同的參數。它們是彼此的複製品。隨着訓練的進行,鍵編碼器將成爲查詢編碼器的移動平均值(在這一點上進展緩慢)。

因爲計算能力的限制,咱們使用Resnet-18體系結構來實現。在一般的resnet架構之上,咱們添加了一些密集的層,以使表示的維數降到25。這些層中的某些層稍後將充當投影。

# 定義咱們的深度學習架構
resnetq = resnet18(pretrained=False)

classifier = nn.Sequential(OrderedDict([
    ('fc1', nn.Linear(resnetq.fc.in_features, 100)),
    ('added_relu1', nn.ReLU(inplace=True)),
    ('fc2', nn.Linear(100, 50)),
    ('added_relu2', nn.ReLU(inplace=True)),
    ('fc3', nn.Linear(50, 25))
]))

resnetq.fc = classifier
resnetk = copy.deepcopy(resnetq)

# 將resnet架構遷移到設備
resnetq.to(device)
resnetk.to(device)
步驟2:

如今,咱們已經有了編碼器,而且假設咱們已經設置了其餘重要的數據結構,如今是時候開始訓練循環並理解管道了。

這一步是從訓練批中獲取編碼查詢和鍵。咱們用L2範數對錶示進行規範化。

只是一個約定警告,全部後續步驟中的代碼都將位於批處理和epoch循環中。咱們還將張量「k」從它的梯度中分離出來,由於咱們不須要計算圖中的鍵編碼器部分,由於動量更新方程會更新鍵編碼器。

# 梯度零化
optimizer.zero_grad()

# 檢索xq和xk這兩個圖像batch
xq = sample_batched['image1']
xk = sample_batched['image2']

# 把它們移到設備上
xq = xq.to(device)
xk = xk.to(device)

# 獲取他們的輸出
q = resnetq(xq)
k = resnetk(xk)
k = k.detach()

# 將輸出規範化,使它們成爲單位向量
q = torch.div(q,torch.norm(q,dim=1).reshape(-1,1))
k = torch.div(k,torch.norm(k,dim=1).reshape(-1,1))
步驟3:

如今,咱們將查詢、鍵和隊列傳遞給前面定義的loss函數,並將值存儲在一個列表中。而後,像往常同樣,對損失值調用backward函數並運行優化器。

# 得到損失值
loss = loss_function(q, k, queue)

# 把這個損失值放到epoch損失列表中
epoch_losses_train.append(loss.cpu().data.item())

# 反向傳播
loss.backward()

# 運行優化器
optimizer.step()
步驟4:

咱們將最新的batch加入咱們的隊列。若是咱們的隊列大小大於咱們定義的最大隊列大小(K),那麼咱們就從其中取出最老的batch。可使用torch.cat進行隊列操做。

# 更新隊列
queue = torch.cat((queue, k), 0) 

# 若是隊列大於最大隊列大小(k),則出列
# batch大小是256,能夠用變量替換
if queue.shape[0] > K:
    queue = queue[256:,:]
步驟5:

如今咱們進入訓練循環的最後一步,即更新鍵編碼器。咱們使用下面的for循環來實現這一點。

# 更新resnet
for θ_k, θ_q in zip(resnetk.parameters(), resnetq.parameters()):
    θ_k.data.copy_(momentum*θ_k.data + θ_q.data*(1.0 - momentum))
一些訓練細節

訓練resnet-18模型的Imagenette和Imagewoof數據集的GPU時間接近18小時。爲此,咱們使用了googlecolab的GPU(16GB)。咱們使用的batch大小爲256,tau值爲0.05,學習率爲0.001,最終下降到1e-5,權重衰減爲1e-6。咱們的隊列大小爲8192,鍵編碼器的動量值爲0.999。

結果

前3層(將relu視爲一層)定義了投影頭,咱們將其移除用於圖像分類的下游任務。在剩下的網絡上,咱們訓練了一個線性分類器。

咱們獲得了64.2%的正確率,而使用10%的標記訓練數據,使用MoCo-v2。相比之下,使用最早進的監督學習方法,其準確率接近95%。

對於Imagewoof,咱們對10%的標記數據獲得了38.6%的準確率。在這個數據集上進行對比學習的效果低於咱們的預期。咱們懷疑這是由於首先,數據集很是困難,由於全部類都是狗類。

其次,咱們認爲顏色是這些類的一個重要的區別特徵。應用顏色抖動可能會致使來自不一樣類的多個圖像彼此混合表示。相比之下,監督方法的準確率接近90%。

可以彌合自監督模型和監督模型之間差距的設計變動

  1. 使用更大更寬的模型。
  2. 經過使用更大的批量和字典大小。
  3. 使用更多的數據,若是能夠的話。同時引入全部未標記的數據。
  4. 在大量數據上訓練大型模型,而後提取它們。

一些有用的連接:

參考引用

  1. Momentum Contrast for Unsupervised Visual Representation Learning, Kaiming He, Haoqi Fan, Yuxin Wu, Saining Xie, and Ross Girshick(https://arxiv.org/pdf/1911.05...
  2. Improved Baselines with Momentum Contrastive Learning, Xinlei Chen, Haoqi Fan, Ross Girshick, and Kaiming He(https://arxiv.org/pdf/2003.04...
  3. A simple framework for contrastive learning of visual representations, Ting Chen, Simon Kornblith, Mohammad Norouzi, and Geoffrey E. Hinton.(https://arxiv.org/pdf/2002.05...
  4. Representation Learning with Contrastive Predictive Coding, Aaron van den Oord, Yazhe Li, and Oriol Vinyals(https://arxiv.org/pdf/1807.03...

原文連接:https://www.analyticsvidhya.c...

歡迎關注磐創AI博客站:
http://panchuang.net/

sklearn機器學習中文官方文檔:
http://sklearn123.com/

歡迎關注磐創博客資源彙總站:
http://docs.panchuang.net/

相關文章
相關標籤/搜索