Background and Challenges
In modern deep learning, the manual annotation of data is one of the main limitations. To train a good model, we usually need to prepare a large amount of labeled data. When the number of classes and samples is small, we can take a model pre-trained on a public labeled dataset and fine-tune only the last few layers on our own data.
However, when your data is large (for example, products in a store, or faces, ...), problems arise quickly, and it is hard for the model to learn with only a few trainable layers. Moreover, the amount of unlabeled data (e.g., document text, images on the Internet) is practically unlimited: labeling all of it for a task is nearly impossible, yet not using it at all is clearly a waste.
In such cases, a deep model has to be trained from scratch on the new dataset, which requires a great deal of time and effort to label the data. This is where self-supervised learning comes in. The idea behind it is simple and mainly involves two tasks:
Pretext task: the deep model learns generalizable representations from unlabeled data without annotations, then generates supervisory signals on its own from the implicit information.
Downstream task: the learned representations are transferred to the actual task of interest, for example classification with only a small amount of labeled data, where the pre-trained backbone is fine-tuned or evaluated with a simple classifier.
3. Loss Function
The learning objective is a binary classification problem over pairs of representations. We can therefore use the binary cross-entropy loss to maximize the Bernoulli log-likelihood, where the relation score y, obtained through a sigmoid activation, represents the estimated probability that the two members of a pair are related.
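For reference, the objective can be written out as follows (my notation, consistent with the description above; the paper relies on the standard binary cross-entropy):

\mathcal{L} = -\sum_{i=1}^{N}\big[\, t_i \log y_i + (1 - t_i)\log(1 - y_i) \,\big], \qquad y_i = \sigma\!\big(r_\phi(\mathbf{p}_i)\big)

where \mathbf{p}_i is an aggregated pair of representations, r_\phi is the relation head, \sigma is the sigmoid function, and the target t_i is 1 for positive pairs (two augmentations of the same image) and 0 for negative pairs.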
Finally, the paper [6] also reports relational reasoning results on standard datasets (CIFAR-10, CIFAR-100, CIFAR-100-20, STL-10, tiny-ImageNet, SlimageNet), with different backbones (shallow and deep) under the same learning schedule (epochs). For the full results and more information, please refer to the paper.
Experimental Evaluation
In this article, I want to reproduce the relational reasoning system on the public image dataset STL-10, which consists of 10 classes (airplane, bird, car, cat, deer, dog, horse, monkey, ship, truck) of 96x96-pixel color images.
First, we need to import some essential libraries.
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
import math
import time
from torch.utils.data import DataLoader
from time import sleep
from tqdm import tqdm
import numpy as np
from fastprogress.fastprogress import master_bar, progress_bar
from torchvision import models
import matplotlib.pyplot as plt
from torchvision.utils import make_grid

%config InlineBackend.figure_format = 'svg'
The STL-10 dataset contains 1,300 labeled images per class (500 for training and 800 for testing). It also includes 100,000 unlabeled images drawn from a similar but broader distribution; for example, besides the animals in the labeled set, it contains other types of animals (bears, rabbits, etc.) and vehicles (trains, buses, etc.).
Then, following the author's suggestions, we create the relational reasoning class.
class RelationalReasoning(torch.nn.Module):
    """Self-supervised relational reasoning.
    Basic implementation of the method, using the 'cat'
    aggregation function (the most effective one).
    It can be used with any backbone.
    """
    def __init__(self, backbone, feature_size=64):
        super(RelationalReasoning, self).__init__()
        self.backbone = backbone.to(device)
        self.relation_head = torch.nn.Sequential(
            torch.nn.Linear(feature_size*2, 256),
            torch.nn.BatchNorm1d(256),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(256, 1)).to(device)

    def aggregate(self, features, K):
        relation_pairs_list = list()
        targets_list = list()
        size = int(features.shape[0] / K)
        shifts_counter = 1
        for index_1 in range(0, size*K, size):
            for index_2 in range(index_1+size, size*K, size):
                # by default use the 'cat' aggregation function
                pos_pair = torch.cat([features[index_1:index_1+size],
                                      features[index_2:index_2+size]], 1)
                # shuffle without collisions by rolling the mini-batch (negatives)
                neg_pair = torch.cat([
                    features[index_1:index_1+size],
                    torch.roll(features[index_2:index_2+size],
                               shifts=shifts_counter, dims=0)], 1)
                relation_pairs_list.append(pos_pair)
                relation_pairs_list.append(neg_pair)
                targets_list.append(torch.ones(size, dtype=torch.float32))
                targets_list.append(torch.zeros(size, dtype=torch.float32))
                shifts_counter += 1
                if(shifts_counter >= size):
                    shifts_counter = 1  # avoid identity pairs
        relation_pairs = torch.cat(relation_pairs_list, 0)
        targets = torch.cat(targets_list, 0)
        return relation_pairs.to(device), targets.to(device)

    def train(self, tot_epochs, train_loader):
        optimizer = torch.optim.Adam([
            {'params': self.backbone.parameters()},
            {'params': self.relation_head.parameters()}])
        BCE = torch.nn.BCEWithLogitsLoss()
        self.backbone.train()
        self.relation_head.train()
        mb = master_bar(range(1, tot_epochs+1))
        for epoch in mb:
            # the actual targets are discarded (unsupervised)
            train_loss = 0
            accuracy_list = list()
            for data_augmented, _ in progress_bar(train_loader, parent=mb):
                K = len(data_augmented)  # tot augmentations
                x = torch.cat(data_augmented, 0).to(device)
                optimizer.zero_grad()
                # forward pass (backbone)
                features = self.backbone(x)
                # aggregation function
                relation_pairs, targets = self.aggregate(features, K)
                # forward pass (relation head)
                score = self.relation_head(relation_pairs).squeeze()
                # cross-entropy loss and backward pass
                loss = BCE(score, targets)
                loss.backward()
                optimizer.step()
                train_loss += loss.item() * K
                predicted = torch.round(torch.sigmoid(score))
                correct = predicted.eq(targets.view_as(predicted)).sum()
                accuracy = (correct / float(len(targets))).cpu().numpy()
                accuracy_list.append(accuracy)
            epoch_loss = train_loss / len(train_loader.sampler)
            epoch_accuracy = sum(accuracy_list) / len(accuracy_list) * 100
            mb.write(f"Epoch [{epoch}/{tot_epochs}] - Accuracy: {epoch_accuracy:.2f}% - Loss: {epoch_loss:.4f}")
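To make the 'cat' aggregation above more concrete, here is a tiny illustration of my own (not from the paper) of how positive and negative pairs are built when K=2 augmentations of a mini-batch of 3 images are stacked into a single feature tensor:

import torch

# fake 1-D features for readability: rows 0-2 are augmentation 1,
# rows 3-5 are augmentation 2 of the same three images
features = torch.arange(6).float().unsqueeze(1)
size = 3  # mini-batch size per augmentation

# positive pairs: two views of the same image -> target 1
pos = torch.cat([features[0:3], features[3:6]], 1)  # pairs (0,3), (1,4), (2,5)
# negative pairs: roll the second view so the images no longer match -> target 0
neg = torch.cat([features[0:3],
                 torch.roll(features[3:6], shifts=1, dims=0)], 1)  # pairs (0,5), (1,3), (2,4)
print(pos)
print(neg)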
To compare the performance of the relational reasoning method on shallow and deep models, we will create a shallow model (Conv4) as well as a deep architecture (Resnet34).
backbone = Conv4()                            # shallow model
backbone = models.resnet34(pretrained=False)  # deep model (keep only the line for the backbone you want to train)
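Note that the Conv4 class is not defined in this excerpt. A minimal sketch of a 4-block convolutional backbone producing a 64-dimensional feature vector (an assumption chosen to match feature_size=64 below, not necessarily the author's exact architecture) could look like this:

import torch

# Hypothetical Conv4 sketch: four conv blocks, global average pooling, flatten.
# The 64-dimensional output matches the feature_size used for the shallow backbone.
class Conv4(torch.nn.Module):
    def __init__(self):
        super(Conv4, self).__init__()
        self.feature_size = 64
        self.layers = torch.nn.Sequential(
            torch.nn.Conv2d(3, 8, kernel_size=3, padding=1),
            torch.nn.BatchNorm2d(8), torch.nn.ReLU(), torch.nn.MaxPool2d(2),
            torch.nn.Conv2d(8, 16, kernel_size=3, padding=1),
            torch.nn.BatchNorm2d(16), torch.nn.ReLU(), torch.nn.MaxPool2d(2),
            torch.nn.Conv2d(16, 32, kernel_size=3, padding=1),
            torch.nn.BatchNorm2d(32), torch.nn.ReLU(), torch.nn.MaxPool2d(2),
            torch.nn.Conv2d(32, 64, kernel_size=3, padding=1),
            torch.nn.BatchNorm2d(64), torch.nn.ReLU(),
            torch.nn.AdaptiveAvgPool2d(1),  # global average pooling
            torch.nn.Flatten())

    def forward(self, x):
        return self.layers(x)  # shape: (batch, 64)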
Following the author's suggestions, we set some hyperparameters and the augmentation strategy. We will train the backbone together with the relation head on the unlabeled STL-10 dataset.
# hyperparameters for this run
K = 16               # tot augmentations, K=32 in the paper
batch_size = 64      # 64 in the paper
tot_epochs = 10      # 200 in the paper
feature_size = 64    # number of units for the Conv4 backbone
feature_size = 1000  # number of units for the Resnet34 backbone (keep the one matching your backbone)

# augmentation strategy
normalize = transforms.Normalize(mean=[0.4406, 0.4273, 0.3858],
                                 std=[0.2687, 0.2613, 0.2685])
color_jitter = transforms.ColorJitter(brightness=0.8, contrast=0.8,
                                      saturation=0.8, hue=0.2)
rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8)
rnd_gray = transforms.RandomGrayscale(p=0.2)
rnd_rcrop = transforms.RandomResizedCrop(size=96, scale=(0.08, 1.0),
                                         interpolation=2)
rnd_hflip = transforms.RandomHorizontalFlip(p=0.5)
train_transform = transforms.Compose([rnd_rcrop, rnd_hflip,
                                      rnd_color_jitter, rnd_gray,
                                      transforms.ToTensor(), normalize])

# load into the data loader
torch.manual_seed(1)
torch.cuda.manual_seed(1)
train_set = MultiSTL10(K=K, root='data', split='unlabeled',
                       transform=train_transform, download=True)
train_loader = DataLoader(train_set, batch_size=batch_size,
                          shuffle=True, num_workers=2, pin_memory=True)
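The MultiSTL10 class is also not shown in this excerpt. Conceptually it wraps torchvision's STL10 dataset so that each item returns K differently augmented views of the same image, which is what the training loop above expects. A minimal sketch under that assumption (not the author's exact code) is:

import numpy as np
import torchvision
from PIL import Image

# Hypothetical sketch: an STL10 wrapper that returns K augmented views per image.
class MultiSTL10(torchvision.datasets.STL10):
    def __init__(self, K, **kwargs):
        super(MultiSTL10, self).__init__(**kwargs)
        self.K = K  # number of augmented views per image

    def __getitem__(self, index):
        # STL10 stores images as (C, H, W) uint8 arrays
        img = self.data[index]
        pil_img = Image.fromarray(np.transpose(img, (1, 2, 0)))
        # apply the (random) transform K times to get K different views
        views = [self.transform(pil_img) for _ in range(self.K)]
        return views, -1  # the target is discarded during self-supervised training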
So far we have created everything needed to train our model. We will now train the backbone and the relation head for 10 epochs with 16 augmented views per image (K). Using a single Tesla P100-PCIE-16GB GPU, this took about 4 hours for the shallow model (Conv4) and 6 hours for the deep model (Resnet34). (Feel free to change the number of epochs and the other hyperparameters to get better results.)
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
backbone.to(device)
model = RelationalReasoning(backbone, feature_size)
model.train(tot_epochs=tot_epochs, train_loader=train_loader)
torch.save(model.backbone.state_dict(), 'model.tar')
After training the backbone, we discard the relation head and use only the backbone for the downstream task. We use the labeled STL-10 training data (500 images per class) to train a classifier on top of the backbone and test the final model on the test split (800 images per class). The training and test sets are loaded into DataLoaders without augmentation.
# set random seed
torch.manual_seed(1)
torch.cuda.manual_seed(1)
# no augmentations used for linear evaluation
transform_lineval = transforms.Compose([transforms.ToTensor(), normalize])
# download STL10 labeled train and test dataset
train_set_lineval = torchvision.datasets.STL10('data', split='train', transform=transform_lineval)
test_set_lineval = torchvision.datasets.STL10('data', split='test', transform=transform_lineval)
# load dataset in data loader
train_loader_lineval = DataLoader(train_set_lineval, batch_size=128, shuffle=True)
test_loader_lineval = DataLoader(test_set_lineval, batch_size=128, shuffle=False)
We load the pre-trained backbone and attach a simple linear model that maps the output features to the number of classes in the dataset.
# linear model
linear_layer = torch.nn.Linear(64, 10)    # if backbone is Conv4
linear_layer = torch.nn.Linear(1000, 10)  # if backbone is Resnet34
# defining a raw backbone model
backbone_lineval = Conv4()                             # Conv4
backbone_lineval = models.resnet34(pretrained=False)   # Resnet34
# load model
checkpoint = torch.load('model.tar')  # name of the pretrained weights file
backbone_lineval.load_state_dict(checkpoint)
At this point, only the linear model is trained while the backbone is kept frozen. Let's first look at the results with the Conv4 backbone.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
optimizer = torch.optim.Adam(linear_layer.parameters())
CE = torch.nn.CrossEntropyLoss()
linear_layer.to(device)
linear_layer.train()
backbone_lineval.to(device)
backbone_lineval.eval()

print('Linear evaluation')
for epoch in range(20):
    accuracy_list = list()
    for i, (data, target) in enumerate(train_loader_lineval):
        optimizer.zero_grad()
        data = data.to(device)
        target = target.to(device)
        output = backbone_lineval(data).to(device).detach()
        output = linear_layer(output)
        loss = CE(output, target)
        loss.backward()
        optimizer.step()
        # estimate the accuracy
        prediction = output.argmax(-1)
        correct = prediction.eq(target.view_as(prediction)).sum()
        accuracy = (100.0 * correct / len(target))
        accuracy_list.append(accuracy.item())
    print('Epoch [{}] loss: {:.5f}; accuracy: {:.2f}%' \
          .format(epoch+1, loss.item(), sum(accuracy_list)/len(accuracy_list)))

Linear evaluation
Epoch [1] loss: 2.24857; accuracy: 14.77%
Epoch [2] loss: 2.23015; accuracy: 24.49%
Epoch [3] loss: 2.18529; accuracy: 32.46%
Epoch [4] loss: 2.24595; accuracy: 36.45%
Epoch [5] loss: 2.09482; accuracy: 42.46%
Epoch [6] loss: 2.11192; accuracy: 43.40%
Epoch [7] loss: 2.05064; accuracy: 47.29%
Epoch [8] loss: 2.03494; accuracy: 47.38%
Epoch [9] loss: 1.91709; accuracy: 47.46%
Epoch [10] loss: 1.99181; accuracy: 48.03%
Epoch [11] loss: 1.91527; accuracy: 48.28%
Epoch [12] loss: 1.93190; accuracy: 49.55%
Epoch [13] loss: 2.00492; accuracy: 49.71%
Epoch [14] loss: 1.85328; accuracy: 49.94%
Epoch [15] loss: 1.88910; accuracy: 49.86%
Epoch [16] loss: 1.88084; accuracy: 50.76%
Epoch [17] loss: 1.63443; accuracy: 50.74%
Epoch [18] loss: 1.76303; accuracy: 50.62%
Epoch [19] loss: 1.70486; accuracy: 51.46%
Epoch [20] loss: 1.61629; accuracy: 51.84%

Then we check the test set:

accuracy_list = list()
for i, (data, target) in enumerate(test_loader_lineval):
    data = data.to(device)
    target = target.to(device)
    output = backbone_lineval(data).detach()
    output = linear_layer(output)
    # estimate the accuracy
    prediction = output.argmax(-1)
    correct = prediction.eq(target.view_as(prediction)).sum()
    accuracy = (100.0 * correct / len(target))
    accuracy_list.append(accuracy.item())
print('Test accuracy: {:.2f}%'.format(sum(accuracy_list)/len(accuracy_list)))

Test accuracy: 49.98%

Conv4 reaches 49.98% accuracy on the test set, which means the backbone was able to learn useful features from the unlabeled dataset, and only a few epochs of linear evaluation are needed to obtain a good result. Now let's check the performance of the deep model (after re-creating linear_layer and backbone_lineval for Resnet34 and loading its pretrained weights, as above).

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
optimizer = torch.optim.Adam(linear_layer.parameters())
CE = torch.nn.CrossEntropyLoss()
linear_layer.to(device)
linear_layer.train()
backbone_lineval.to(device)
backbone_lineval.eval()

print('Linear evaluation')
for epoch in range(20):
    accuracy_list = list()
    for i, (data, target) in enumerate(train_loader_lineval):
        optimizer.zero_grad()
        data = data.to(device)
        target = target.to(device)
        output = backbone_lineval(data).to(device).detach()
        output = linear_layer(output)
        loss = CE(output, target)
        loss.backward()
        optimizer.step()
        # estimate the accuracy
        prediction = output.argmax(-1)
        correct = prediction.eq(target.view_as(prediction)).sum()
        accuracy = (100.0 * correct / len(target))
        accuracy_list.append(accuracy.item())
    print('Epoch [{}] loss: {:.5f}; accuracy: {:.2f}%' \
          .format(epoch+1, loss.item(), sum(accuracy_list)/len(accuracy_list)))
Linear evaluation
Epoch [1] loss: 2.68060; accuracy: 47.79%
Epoch [2] loss: 1.56714; accuracy: 58.34%
Epoch [3] loss: 1.18530; accuracy: 56.50%
Epoch [4] loss: 0.94784; accuracy: 57.91%
Epoch [5] loss: 1.48861; accuracy: 57.56%
Epoch [6] loss: 0.91673; accuracy: 57.87%
Epoch [7] loss: 0.90533; accuracy: 58.96%
Epoch [8] loss: 2.10333; accuracy: 57.40%
Epoch [9] loss: 1.58732; accuracy: 55.57%
Epoch [10] loss: 0.88780; accuracy: 57.79%
Epoch [11] loss: 0.93859; accuracy: 58.44%
Epoch [12] loss: 1.15898; accuracy: 57.32%
Epoch [13] loss: 1.25100; accuracy: 57.79%
Epoch [14] loss: 0.85337; accuracy: 59.06%
Epoch [15] loss: 1.62060; accuracy: 58.91%
Epoch [16] loss: 1.30841; accuracy: 58.95%
Epoch [17] loss: 0.27441; accuracy: 58.11%
Epoch [18] loss: 1.58133; accuracy: 58.73%
Epoch [19] loss: 0.76258; accuracy: 58.81%
Epoch [20] loss: 0.62280; accuracy: 58.50%
Then we evaluate performance on the test set.
accuracy_list = list()
for i, (data, target) in enumerate(test_loader_lineval):
    data = data.to(device)
    target = target.to(device)
    output = backbone_lineval(data).detach()
    output = linear_layer(output)
    # estimate the accuracy
    prediction = output.argmax(-1)
    correct = prediction.eq(target.view_as(prediction)).sum()
    accuracy = (100.0 * correct / len(target))
    accuracy_list.append(accuracy.item())
print('Test accuracy: {:.2f}%'.format(sum(accuracy_list)/len(accuracy_list)))
Test accuracy: 55.38%
The results show that we can obtain 55.38% accuracy on the test set. The main goal of this article was to reproduce and evaluate the relational reasoning methodology for teaching a model to recognize objects without labels, so these results are quite encouraging. If you are not satisfied with them, feel free to experiment with the hyperparameters, for example by increasing the number of augmentations or epochs, or by changing the model architecture.
Final Thoughts
Self-supervised relational reasoning is effective both quantitatively and qualitatively, and it works with backbones of different depths, from shallow to deep. The learned representations transfer easily from one domain to another; they are fine-grained and compact, which may be explained by the correlation between accuracy and the number of augmentations. In relational reasoning, according to the author's experiments, the number of augmentations has a major effect on the quality of the object clusters [4]. Self-supervised learning has strong potential to become, in many respects, the future of machine learning.
References
[1] Carl Doersch et al., Unsupervised Visual Representation Learning by Context Prediction, 2015.
[2] Mehdi Noroozi et al., Unsupervised Learning of Visual Representations by Solving Jigsaw Puzzles, 2017.
[3] Zhang et al., Colorful Image Colorization, 2016.
[4] Mehdi Noroozi et al., Representation Learning by Learning to Count, 2017.
[5] Ting Chen et al., A Simple Framework for Contrastive Learning of Visual Representations, 2020.
[6] Massimiliano Patacchiola et al., Self-Supervised Relational Reasoning for Representation Learning, 2020.
[7] Adam Santoro et al., Relational recurrent neural networks, 2018.