摘要: 照片承載了不少人在某個時刻的記憶,尤爲是一些老舊的黑白照片,塵封於腦海之中,隨着時間的流逝,記憶中對當時顏色的印象也會慢慢消散,這確實有些惋惜。技術的發展會解決一些現有的難題,深度學習剛好可以解決這個問題。
人工智能和深度學習技術逐漸在各行各業中發揮着做用,尤爲是在計算機視覺領域,深度學習就像繼承了某些上帝的功能,無所不能,使人歎爲觀止。照片承載了不少人在某個時刻的記憶,尤爲是一些老舊的黑白照片,塵封於腦海之中,隨着時間的流逝,記憶中對當時顏色的印象也會慢慢消散,這確實有些惋惜。但隨着科技的發展,這些已再也不是比較難的問題。在這篇文章中,將帶領你們領略一番深度學習的強大能力——將灰度圖像轉換爲彩色圖像。文章使用PyTorch從頭開始構建一個機器學習模型,自動將灰度圖像轉換爲彩色圖像,而且給出了相應代碼及圖像效果圖。整篇文章都是經過iPython Notebook中實現,對性能的要求不高,讀者們能夠自行動手實踐一下在各自的計算機上運行下,親身體驗下深度學習神奇的效果吧。算法
PS:不只可以對舊圖像進行着色,還能夠對視頻(每次對視頻進行一幀處理)進行着色哦!閒話少敘,下面直接進入正題吧。網絡
簡介dom
在圖像着色任務中,咱們的目標是在給定灰度輸入圖像的狀況下生成彩色圖像。這個問題是具備必定的挑戰性,由於它是多模式的——單個灰度圖像可能對應許多合理的彩色圖像。所以,傳統模型一般依賴於重要的用戶輸入以及輸入的灰度圖像內容。 機器學習
最近,深層神經網絡在自動圖像着色方面取得了顯着的成功——從灰度到彩色,無需額外的人工輸入。這種成功的部分緣由在於深層神經網絡可以捕捉和使用語義信息(即圖像的實際內容),儘管目前還不可以肯定這些類型的模型表現如此出色的緣由,由於深度學習相似於黑匣子,暫時沒法弄清算法是如何自動學習,後續會朝着可解釋性研究方向發展。 ide
在解釋模型以前,首先以更精確地方式闡述咱們所面臨的問題。函數
問題 工具
咱們的目的是要從灰度圖像中推斷出每一個像素(亮度、飽和度和色調)具備3個值的全色圖像,對於灰度圖而言,每一個像素僅具備1個值(僅亮度)。爲簡單起見,咱們只能處理大小爲256 x 256的圖像,因此咱們的輸入圖像大小爲256 x 256 x 1(亮度通道),輸出的圖像大小爲256 x 256 x 2(另兩個通道)。性能
正如人們一般所作的那樣,咱們不是用RGB格式的圖像進行處理,而是使用LAB色彩空間(亮度,A和B)。該色彩空間包含與RGB徹底相同的信息,但它將使咱們可以更容易地將亮度通道與其餘兩個(咱們稱之爲A和B)分開。在稍後會構造一個輔助函數來完成這個轉換過程。學習
此外將嘗試直接預測輸入圖像的顏色值(即迴歸)。還有其餘更有趣的分類方法,但目前堅持使用迴歸方法,由於它很簡單且效果很好。優化
數據
着色數據無處不在,這是因爲咱們能夠從任何一張彩色圖像中提取出灰度通道。對於本文項目,咱們將使用MIT地點數據集中的一個子集,該子數據集包含地點、景觀和建築物。
# Download and unzip (2.2GB) !wget http://data.csail.mit.edu/places/places205/testSetPlaces205_resize.tar.gz !tar -xzf testSetPlaces205_resize.tar.gz
# Move data into training and validation directories import os os.makedirs('images/train/class/', exist_ok=True) # 40,000 images os.makedirs('images/val/class/', exist_ok=True) # 1,000 images for i, file in enumerate(os.listdir('testSet_resize')): if i < 1000: # first 1000 will be val os.rename('testSet_resize/' + file, 'images/val/class/' + file) else: # others will be val os.rename('testSet_resize/' + file, 'images/train/class/' + file)
# Make sure the images are there from IPython.display import Image, display display(Image(filename='images/val/class/84b3ccd8209a4db1835988d28adfed4c.jpg'))
工具
本文使用PyTorch構建和訓練搭建的模型。此外,咱們還了使用torchvision工具,該工具在PyTorch中處理圖像和視頻時頗有用,以及使用了scikit-learn工具,用於在RGB和LAB顏色空間之間進行轉換。
# Download and import libraries !pip install torch torchvision matplotlib numpy scikit-image pillow==4.1.1
# For plotting import numpy as np import matplotlib.pyplot as plt %matplotlib inline # For conversion from skimage.color import lab2rgb, rgb2lab, rgb2gray from skimage import io # For everything import torch import torch.nn as nn import torch.nn.functional as F # For our model import torchvision.models as models from torchvision import datasets, transforms # For utilities import os, shutil, time
# Check if GPU is available use_gpu = torch.cuda.is_available()
模型
模型採用卷積神經網絡構建而成,與傳統的卷積神經網絡模型相似,首先應用一些卷積層從圖像中提取特徵,而後將反捲積層應用於高級(增長空間分辨率)特徵。
具體來講,模型採用的是遷移學習的方法,基礎是ResNet-18模型,ResNet-18網絡具備18層結構以及剩餘鏈接的圖像分類網絡層。咱們修改了該網絡的第一層,以便它接受灰度輸入而不是彩色輸入,而且切斷了第六層後面的網絡結構:
如今,在代碼中定義後續的網絡模型,將從網絡的後半部分開始,即上採樣層:
class ColorizationNet(nn.Module): def __init__(self, input_size=128): super(ColorizationNet, self).__init__() MIDLEVEL_FEATURE_SIZE = 128 ## First half: ResNet resnet = models.resnet18(num_classes=365) # Change first conv layer to accept single-channel (grayscale) input resnet.conv1.weight = nn.Parameter(resnet.conv1.weight.sum(dim=1).unsqueeze(1)) # Extract midlevel features from ResNet-gray self.midlevel_resnet = nn.Sequential(*list(resnet.children())[0:6]) ## Second half: Upsampling self.upsample = nn.Sequential( nn.Conv2d(MIDLEVEL_FEATURE_SIZE, 128, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(128), nn.ReLU(), nn.Upsample(scale_factor=2), nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(64), nn.ReLU(), nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(64), nn.ReLU(), nn.Upsample(scale_factor=2), nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(32), nn.ReLU(), nn.Conv2d(32, 2, kernel_size=3, stride=1, padding=1), nn.Upsample(scale_factor=2) ) def forward(self, input): # Pass input through ResNet-gray to extract features midlevel_features = self.midlevel_resnet(input) # Upsample to get colors output = self.upsample(midlevel_features) return output
如今經過下面的代碼建立整個模型:
model = ColorizationNet()
訓練
損失函數
因爲使用的是迴歸方法,因此使用的仍然是均方偏差損失函數:嘗試最小化預測的顏色值與真實(實際值)顏色值之間的平方距離。
criterion = nn.MSELoss()
因爲問題的多形式性,上述損失函數對於着色有一點小的問題。例如,若是一件灰色的衣服多是紅色或藍色,而模型若選擇錯誤的顏色時,則會受到嚴厲的懲罰。所以,構建的模型一般會選擇與飽和度鮮豔的顏色相比不太可能「很是錯誤」的不飽和顏色。關於這個問題已經有了重要的研究(參見Zhang等人),可是本文將堅持這種損失函數,就是這麼任性。
優化
使用Adam優化器優化選定的損失函數(標準)。
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=0.0)
加載數據
使用torchtext來加載數據,因爲咱們須要LAB空間中的圖像,因此首先必須定義一個自定義數據加載器(dataloader)來轉換圖像。
class GrayscaleImageFolder(datasets.ImageFolder): '''Custom images folder, which converts images to grayscale before loading''' def __getitem__(self, index): path, target = self.imgs[index] img = self.loader(path) if self.transform is not None: img_original = self.transform(img) img_original = np.asarray(img_original) img_lab = rgb2lab(img_original) img_lab = (img_lab + 128) / 255 img_ab = img_lab[:, :, 1:3] img_ab = torch.from_numpy(img_ab.transpose((2, 0, 1))).float() img_original = rgb2gray(img_original) img_original = torch.from_numpy(img_original).unsqueeze(0).float() if self.target_transform is not None: target = self.target_transform(target) return img_original, img_ab, target
接下來,對訓練數據和驗證數據定義變換。
# Training train_transforms = transforms.Compose([transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip()]) train_imagefolder = GrayscaleImageFolder('images/train', train_transforms) train_loader = torch.utils.data.DataLoader(train_imagefolder, batch_size=64, shuffle=True) # Validation val_transforms = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224)]) val_imagefolder = GrayscaleImageFolder('images/val' , val_transforms) val_loader = torch.utils.data.DataLoader(val_imagefolder, batch_size=64, shuffle=False)
輔助函數
在進行訓練以前,定義了輔助函數來跟蹤訓練損失並將圖像轉換回RGB圖像。
class AverageMeter(object): '''A handy class from the PyTorch ImageNet tutorial''' def __init__(self): self.reset() def reset(self): self.val, self.avg, self.sum, self.count = 0, 0, 0, 0 def update(self, val, n=1): self.val = val self.sum += val * n self.count += n self.avg = self.sum / self.count def to_rgb(grayscale_input, ab_input, save_path=None, save_name=None): '''Show/save rgb image from grayscale and ab channels Input save_path in the form {'grayscale': '/path/', 'colorized': '/path/'}''' plt.clf() # clear matplotlib color_image = torch.cat((grayscale_input, ab_input), 0).numpy() # combine channels color_image = color_image.transpose((1, 2, 0)) # rescale for matplotlib color_image[:, :, 0:1] = color_image[:, :, 0:1] * 100 color_image[:, :, 1:3] = color_image[:, :, 1:3] * 255 - 128 color_image = lab2rgb(color_image.astype(np.float64)) grayscale_input = grayscale_input.squeeze().numpy() if save_path is not None and save_name is not None: plt.imsave(arr=grayscale_input, fname='{}{}'.format(save_path['grayscale'], save_name), cmap='gray') plt.imsave(arr=color_image, fname='{}{}'.format(save_path['colorized'], save_name))
驗證
在驗證過程當中,使用torch.no_grad()函數簡單地運行下沒有反向傳播的模型。
def validate(val_loader, model, criterion, save_images, epoch): model.eval() # Prepare value counters and timers batch_time, data_time, losses = AverageMeter(), AverageMeter(), AverageMeter() end = time.time() already_saved_images = False for i, (input_gray, input_ab, target) in enumerate(val_loader): data_time.update(time.time() - end) # Use GPU if use_gpu: input_gray, input_ab, target = input_gray.cuda(), input_ab.cuda(), target.cuda() # Run model and record loss output_ab = model(input_gray) # throw away class predictions loss = criterion(output_ab, input_ab) losses.update(loss.item(), input_gray.size(0)) # Save images to file if save_images and not already_saved_images: already_saved_images = True for j in range(min(len(output_ab), 10)): # save at most 5 images save_path = {'grayscale': 'outputs/gray/', 'colorized': 'outputs/color/'} save_name = 'img-{}-epoch-{}.jpg'.format(i * val_loader.batch_size + j, epoch) to_rgb(input_gray[j].cpu(), ab_input=output_ab[j].detach().cpu(), save_path=save_path, save_name=save_name) # Record time to do forward passes and save images batch_time.update(time.time() - end) end = time.time() # Print model accuracy -- in the code below, val refers to both value and validation if i % 25 == 0: print('Validate: [{0}/{1}]\t' 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format( i, len(val_loader), batch_time=batch_time, loss=losses)) print('Finished validation.') return losses.avg
訓練
在訓練過程當中,使用loss.backward()運行模型並進行反向傳播過程。咱們首先定義了一個訓練一個epoch的函數:
def train(train_loader, model, criterion, optimizer, epoch): print('Starting training epoch {}'.format(epoch)) model.train() # Prepare value counters and timers batch_time, data_time, losses = AverageMeter(), AverageMeter(), AverageMeter() end = time.time() for i, (input_gray, input_ab, target) in enumerate(train_loader): # Use GPU if available if use_gpu: input_gray, input_ab, target = input_gray.cuda(), input_ab.cuda(), target.cuda() # Record time to load data (above) data_time.update(time.time() - end) # Run forward pass output_ab = model(input_gray) loss = criterion(output_ab, input_ab) losses.update(loss.item(), input_gray.size(0)) # Compute gradient and optimize optimizer.zero_grad() loss.backward() optimizer.step() # Record time to do forward and backward passes batch_time.update(time.time() - end) end = time.time() # Print model accuracy -- in the code below, val refers to value, not validation if i % 25 == 0: print('Epoch: [{0}][{1}/{2}]\t' 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format( epoch, i, len(train_loader), batch_time=batch_time, data_time=data_time, loss=losses)) print('Finished training epoch {}'.format(epoch))
接下來,咱們定義一個循環訓練函數,即訓練100個epoch:
# Move model and loss function to GPU if use_gpu: criterion = criterion.cuda() model = model.cuda()
# Make folders and set parameters os.makedirs('outputs/color', exist_ok=True) os.makedirs('outputs/gray', exist_ok=True) os.makedirs('checkpoints', exist_ok=True) save_images = True best_losses = 1e10 epochs = 100
# Train model for epoch in range(epochs): # Train for one epoch, then validate train(train_loader, model, criterion, optimizer, epoch) with torch.no_grad(): losses = validate(val_loader, model, criterion, save_images, epoch) # Save checkpoint and replace old best model if current model is better if losses < best_losses: best_losses = losses torch.save(model.state_dict(), 'checkpoints/model-epoch-{}-losses-{:.3f}.pth'.format(epoch+1,losses))
Starting training epoch 0 ...
預訓練模型
若是你想運用預訓練模型而不想從頭開始訓練的話,我已經爲你訓練了好一個模型。該模型在少許時間內接受相對少許的數據訓練,而且可以工做正常。能夠從下面的連接下載並使用它:
# Download pretrained model !wget https://www.dropbox.com/s/kz76e7gv2ivmu8p/model-epoch-93.pth #https://www.dropbox.com/s/9j9rvaw2fo1osyj/model-epoch-67.pth
# Load model pretrained = torch.load('model-epoch-93.pth', map_location=lambda storage, loc: storage) model.load_state_dict(pretrained)
# Validate save_images = True with torch.no_grad(): validate(val_loader, model, criterion, save_images, 0)
Validate: [0/16] Time 10.628 (10.628) Loss 0.0030 (0.0030) Validate: [16/16] Time 0.328 ( 0.523) Loss 0.0029 (0.0029)
結果
有趣的內容到了,讓咱們看看深度學習技術實現的效果吧!
# Show images import matplotlib.image as mpimg image_pairs = [('outputs/color/img-2-epoch-0.jpg', 'outputs/gray/img-2-epoch-0.jpg'), ('outputs/color/img-7-epoch-0.jpg', 'outputs/gray/img-7-epoch-0.jpg')] for c, g in image_pairs: color = mpimg.imread(c) gray = mpimg.imread(g) f, axarr = plt.subplots(1, 2) f.set_size_inches(15, 15) axarr[0].imshow(gray, cmap='gray') axarr[1].imshow(color) axarr[0].axis('off'), axarr[1].axis('off') plt.show()
結論
在這篇文章中,使用PyTorch工具從頭建立了一個簡單的自動圖像着色器,沒有太複雜的代碼,只須要簡單的準備好數據並設計好合理的模型便可獲得使人使人興奮的結果,此外,這僅僅只是起步,後續還有不少地方能夠進行改進優化並進行推廣。
本文做者:【方向】
閱讀原文本文爲雲棲社區原創內容,未經容許不得轉載。