做者|GUEST
編譯|VK
來源|Analytics Vidhyapython
SimCLR論文(http://cse.iitkgp.ac.in/~aras...,而且若是有足夠的計算能力,能夠產生與監督模型相似的結果。git
可是這些需求使得框架的計算量至關大。若是咱們能夠擁有這個框架的簡單性和強大功能,而且有更少的計算需求,這樣每一個人均可以訪問它,這不是很好嗎?Moco-v2前來救援。github
注意:在以前的一篇博文中,咱們在PyTorch中實現了SimCLR框架,它是在一個包含5個類別的簡單數據集上實現的,總共只有1250個訓練圖像。算法
此次咱們將在Pytorch中在更大的數據集上實現Moco-v2,並在Google Colab上訓練咱們的模型。此次咱們將使用Imagenette和Imagewoof數據集網絡
來自Imagenette數據集的一些圖像數據結構
這些數據集的快速摘要(更多信息在這裏:https://github.com/fastai/ima...):架構
對比學習在自我監督學習中的做用是基於這樣一個理念:咱們但願同一類別中不一樣的圖像觀具備類似的表徵。可是,因爲咱們不知道哪些圖像屬於同一類別,一般所作的是將同一圖像的不一樣外觀的表示拉近。咱們把這些不一樣的外觀稱爲正對(positive pairs)。app
另外,咱們但願不一樣類別的圖像有不一樣的外觀,使它們的表徵彼此遠離。不一樣圖像的不一樣外觀的呈現與類別無關,會被彼此推開。咱們把這些不一樣的外觀稱爲負對(negative pairs)。框架
在這種狀況下,一個圖像的前景是什麼?前景能夠被認爲是以一種通過修改的方式看待圖像的某些部分,它本質上是圖像的一種變換。機器學習
根據手頭的任務,有些轉換能夠比其餘轉換工做得更好。SimCLR代表,應用隨機裁剪和顏色抖動能夠很好地完成各類任務,包括圖像分類。這本質上來自於網格搜索,從旋轉、裁剪、剪切、噪聲、模糊、Sobel濾波等選項中選擇一對變換。
從外觀到表示空間的映射是經過神經網絡完成的,一般,resnet用於此目的。下面是從圖像到表示的管道
在同一幅圖像中,因爲隨機裁剪,咱們能夠獲得多個表示。這樣,咱們就能夠產生正對。
可是如何生成負對呢?負對是來自不一樣圖像的表示。SimCLR論文在同一批中建立了這些。若是一個批包含N個圖像,那麼對於每一個圖像,咱們將獲得2個表示,這總共佔2*N個表示。對於一個特定的表示x,有一個表示與x造成正對(與x來自同一個圖像的表示),其他全部表示(正好是2*N–2)與x造成負對。
若是咱們手頭有大量的負樣本,這些表示就會獲得改善。可是,在SimCLR中,只有當批量較大時,才能實現大量的負樣本,這致使了對計算能力的更高要求。MoCo-v2提供了生成負樣本的另外一種方法。讓咱們詳細瞭解一下。
咱們能夠用一種稍微不一樣的方式來看待對比學習方法,即將查詢與鍵進行匹配。咱們如今有兩個編碼器,一個用於查詢,另外一個用於鍵。此外,爲了獲得大量的負樣本,咱們須要一個大的鍵編碼字典。
此上下文中的正對錶示查詢與鍵匹配。若是查詢和鍵都來自同一個圖像,則它們匹配。編碼的查詢應該與其匹配的鍵類似,而與其餘查詢不一樣。
對於負對,咱們維護一個大字典,其中包含之前批處理的編碼鍵。它們做爲查詢的負樣本。咱們以隊列的形式維護字典。新的batch被入隊,較早的batch被出列。經過更改此隊列的大小,能夠更改負採樣數。
爲了解決這兩個問題,MoCo將鍵編碼器實現爲基於動量的查詢編碼器的移動平均值[1]。這意味着它以這種方式更新關鍵編碼器參數:
其中m很是接近於1(例如,典型值爲0.999),這確保咱們在不一樣的時間從類似的編碼器得到編碼鍵。
咱們但願查詢接近其全部正樣本,遠離全部負樣本。InfoNC函數E會捕獲它。它表明信息噪聲對比估計。對於查詢q和鍵k,InfoNCE損失函數是:
咱們能夠重寫爲:
當q和k的類似性增大,q與負樣本的類似性減少時,損失值減少
如下是損失函數的代碼:
τ = 0.05 def loss_function(q, k, queue): # N是批量大小 N = q.shape[0] # C是表示的維數 C = q.shape[1] # bmm表明批處理矩陣乘法 # 若是mat1是b×n×m張量,那麼mat2是b×m×p張量, # 而後輸出一個b×n×p張量。 pos = torch.exp(torch.div(torch.bmm(q.view(N,1,C), k.view(N,C,1)).view(N, 1),τ)) # 在查詢和隊列張量之間執行矩陣乘法 neg = torch.sum(torch.exp(torch.div(torch.mm(q.view(N,C), torch.t(queue)),τ)), dim=1) # 求和 denominator = neg + pos return torch.mean(-torch.log(torch.div(pos,denominator)))
讓咱們再看看這個損失函數,並將它與分類交叉熵損失函數進行比較。
這裏predᵢ是數據點在第i類中的機率值預測,trueᵢ是該點屬於第i類的實際機率值(能夠是模糊的,但大多數狀況下是一個one-hot)。
若是你不熟悉這個話題,你能夠看這個視頻來更好地理解交叉熵。另外,請注意,咱們常常經過softmax這樣的函數將分數轉換爲機率值:https://www.youtube.com/watch...
咱們能夠把信息損失函數看做交叉熵損失。數據樣本「q」的正確類是第r類,底層分類器基於softmax,它試圖在K+1類之間進行分類。
Info-NCE還與編碼表示之間的相互信息有關;關於這一點的更多細節見[4]。
如今,讓咱們把全部的東西放在一塊兒,看看整個Moco-v2算法是什麼樣子的。
咱們必須獲得查詢和鍵編碼器。最初,鍵編碼器具備與查詢編碼器相同的參數。它們是彼此的複製品。隨着訓練的進行,鍵編碼器將成爲查詢編碼器的移動平均值(在這一點上進展緩慢)。
因爲計算能力的限制,咱們使用Resnet-18體系結構來實現。在一般的resnet架構之上,咱們添加了一些密集的層,以使表示的維數降到25。這些層中的某些層稍後將充當投影。
# 定義咱們的深度學習架構 resnetq = resnet18(pretrained=False) classifier = nn.Sequential(OrderedDict([ ('fc1', nn.Linear(resnetq.fc.in_features, 100)), ('added_relu1', nn.ReLU(inplace=True)), ('fc2', nn.Linear(100, 50)), ('added_relu2', nn.ReLU(inplace=True)), ('fc3', nn.Linear(50, 25)) ])) resnetq.fc = classifier resnetk = copy.deepcopy(resnetq) # 將resnet架構遷移到設備 resnetq.to(device) resnetk.to(device)
如今,咱們已經有了編碼器,而且假設咱們已經設置了其餘重要的數據結構,如今是時候開始訓練循環並理解管道了。
這一步是從訓練批中獲取編碼查詢和鍵。咱們用L2範數對錶示進行規範化。
只是一個約定警告,全部後續步驟中的代碼都將位於批處理和epoch循環中。咱們還將張量「k」從它的梯度中分離出來,由於咱們不須要計算圖中的鍵編碼器部分,由於動量更新方程會更新鍵編碼器。
# 梯度零化 optimizer.zero_grad() # 檢索xq和xk這兩個圖像batch xq = sample_batched['image1'] xk = sample_batched['image2'] # 把它們移到設備上 xq = xq.to(device) xk = xk.to(device) # 獲取他們的輸出 q = resnetq(xq) k = resnetk(xk) k = k.detach() # 將輸出規範化,使它們成爲單位向量 q = torch.div(q,torch.norm(q,dim=1).reshape(-1,1)) k = torch.div(k,torch.norm(k,dim=1).reshape(-1,1))
如今,咱們將查詢、鍵和隊列傳遞給前面定義的loss函數,並將值存儲在一個列表中。而後,像往常同樣,對損失值調用backward函數並運行優化器。
# 得到損失值 loss = loss_function(q, k, queue) # 把這個損失值放到epoch損失列表中 epoch_losses_train.append(loss.cpu().data.item()) # 反向傳播 loss.backward() # 運行優化器 optimizer.step()
咱們將最新的batch加入咱們的隊列。若是咱們的隊列大小大於咱們定義的最大隊列大小(K),那麼咱們就從其中取出最老的batch。可使用torch.cat進行隊列操做。
# 更新隊列 queue = torch.cat((queue, k), 0) # 若是隊列大於最大隊列大小(k),則出列 # batch大小是256,能夠用變量替換 if queue.shape[0] > K: queue = queue[256:,:]
如今咱們進入訓練循環的最後一步,即更新鍵編碼器。咱們使用下面的for循環來實現這一點。
# 更新resnet for θ_k, θ_q in zip(resnetk.parameters(), resnetq.parameters()): θ_k.data.copy_(momentum*θ_k.data + θ_q.data*(1.0 - momentum))
訓練resnet-18模型的Imagenette和Imagewoof數據集的GPU時間接近18小時。爲此,咱們使用了googlecolab的GPU(16GB)。咱們使用的batch大小爲256,tau值爲0.05,學習率爲0.001,最終下降到1e-5,權重衰減爲1e-6。咱們的隊列大小爲8192,鍵編碼器的動量值爲0.999。
前3層(將relu視爲一層)定義了投影頭,咱們將其移除用於圖像分類的下游任務。在剩下的網絡上,咱們訓練了一個線性分類器。
咱們獲得了64.2%的正確率,而使用10%的標記訓練數據,使用MoCo-v2。相比之下,使用最早進的監督學習方法,其準確率接近95%。
對於Imagewoof,咱們對10%的標記數據獲得了38.6%的準確率。在這個數據集上進行對比學習的效果低於咱們的預期。咱們懷疑這是由於首先,數據集很是困難,由於全部類都是狗類。
其次,咱們認爲顏色是這些類的一個重要的區別特徵。應用顏色抖動可能會致使來自不一樣類的多個圖像彼此混合表示。相比之下,監督方法的準確率接近90%。
可以彌合自監督模型和監督模型之間差距的設計變動:
一些有用的連接:
原文連接:https://www.analyticsvidhya.c...
歡迎關注磐創AI博客站:
http://panchuang.net/
sklearn機器學習中文官方文檔:
http://sklearn123.com/
歡迎關注磐創博客資源彙總站:
http://docs.panchuang.net/