PaddlePaddle Colorization — Colorizing Black-and-White Photos

Isn't colorizing a black-and-white photo a little bit magical?

This project walks you, step by step, through colorizing black-and-white photos, and even black-and-white video.

Download and installation commands

## Install command (CPU version)
pip install -f https://paddlepaddle.org.cn/pip/oschina/cpu paddlepaddle

## Install command (GPU version)
pip install -f https://paddlepaddle.org.cn/pip/oschina/gpu paddlepaddle-gpu

 

Colorizing Black-and-White Photos

We all know that many classic old photos, limited by the technology of their era, survive only in black and white. Black-and-white photos have a charm of their own, but color photos can sometimes be far more immersive. This project implements black-and-white photo colorization in an accessible way and achieves good results on a subset of photos. Colorization is a classic problem in computer vision; in recent years, with the widespread adoption of convolutional neural networks (CNNs), colorizing black-and-white photos with a CNN has become a novel and practical direction. The project is hosted on Baidu's learning and training community AI Studio; the implementation uses a ResNet residual network as the backbone and a composite loss function for training.

Let the colorization journey begin!

First, have a look at the results.

Feel free to fork and study this project. If you have any questions, leave a comment and let's discuss.

A small plug: my areas of interest include transfer learning and generative adversarial networks. Follow me on AI Studio and let's connect.

# Install the required dependencies
!pip install scikit-learn scikit-image

1 Project Overview

This project is based on PaddlePaddle. Using a residual network (ResNet) trained by supervised learning, it converts black-and-white images into color images.


1.1 Residual Networks (ResNet)

1.1.1 Background

ResNet (Residual Network) [15] won the 2015 ImageNet competitions for image classification, object localization, and object detection. To address the accuracy degradation that appears as networks get deeper, ResNet proposes residual learning to ease the training of deep networks. On top of established design ideas (batch normalization, small convolution kernels, fully convolutional networks), it introduces the residual module. Each residual module contains two paths: one is a direct shortcut for the input features, and the other applies two or three convolutions to compute a residual of those features. The features from the two paths are then added together.

The residual modules are shown in Figure 9. On the left is the basic block, composed of two 3x3 convolutions with the same number of output channels. On the right is the bottleneck block, so named because the upper 1x1 convolution reduces the dimensionality (256->64 in the figure), the lower 1x1 convolution restores it (64->256), and the 3x3 convolution in between therefore has small input and output channel counts (64->64).
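To see why the bottleneck shape matters, compare weight counts at 256 channels. A quick back-of-the-envelope check (weights only, ignoring biases and BN parameters):

```python
def conv_params(k, c_in, c_out):
    """Weight count of a k x k convolution (ignoring bias)."""
    return k * k * c_in * c_out

# Basic block at 256 channels: two 3x3 convolutions, 256 -> 256 each
basic = 2 * conv_params(3, 256, 256)

# Bottleneck block: 1x1 reduce (256 -> 64), 3x3 (64 -> 64), 1x1 expand (64 -> 256)
bottleneck = (conv_params(1, 256, 64)
              + conv_params(3, 64, 64)
              + conv_params(1, 64, 256))

print(basic)       # 1179648
print(bottleneck)  # 69632 -- roughly 17x fewer weights at the same channel width
```

This is why very deep ResNets (50 layers and up) switch to bottleneck blocks, while this project, with a shallower stack, sticks to the basic block.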

(Figure 9: residual modules — basic block on the left, bottleneck block on the right. The original image link is broken.)

1.2 Design Approach and the Main Problem

  • Design approach: train the network on a large number of samples so that it learns an empirical distribution (e.g., the sky is always blue, grass is always green), then use that distribution to infer plausible colors for each part of a black-and-white image.
  • Main problem: the colors of many objects are not fixed; object color is multimodal (an apple can be red, but also green or yellow). With a plain mean-squared-error loss, objects with multimodal colors tend toward an "average" color (typically a desaturated yellowish tone), so the colorized images end up with low saturation.
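The averaging effect is easy to verify: the single prediction that minimizes mean squared error over several equally likely target colors is their mean. A small illustration (hypothetical RGB triples, not the project's Lab encoding):

```python
import numpy as np

# Two equally likely "true" colors for the same object: red and green apples
targets = np.array([[1.0, 0.0, 0.0],   # red
                    [0.0, 1.0, 0.0]])  # green

def mse(pred):
    """Mean squared error of a single prediction against all targets."""
    return np.mean((targets - pred) ** 2)

best = targets.mean(axis=0)  # the MSE-optimal single prediction
print(best)                  # [0.5 0.5 0. ] -- a dull olive, neither red nor green
assert mse(best) < mse(np.array([1.0, 0.0, 0.0]))  # beats committing to red...
assert mse(best) < mse(np.array([0.0, 1.0, 0.0]))  # ...and to green
```

So a network trained with plain MSE is rewarded for hedging toward the mean color, which is exactly the desaturation the composite loss below tries to counteract.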

1.3 Key Choices in This Project

  • Set the Adam optimizer's beta1 to 0.8 (see the original paper for details)
  • Set the BatchNorm momentum to 0.5
  • Use the basic residual block
  • To suppress the multimodality problem, redesign the loss function on top of mean squared error

The loss function is:

$$Out = \frac{1}{n}\sum{(input - label)^{2}} + \frac{16.7}{n\sum{(input - \overline{input})^{2}}}$$
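As a sketch of how this loss behaves, here is a NumPy version. It follows the training code later in the article, which divides 16.7 by the mean squared deviation of the prediction (rather than by n times the sum as in the formula); the epsilon is an added safeguard, not part of the original:

```python
import numpy as np

def composite_loss(pred, truth, lam=16.7, eps=1e-8):
    mse = np.mean((pred - truth) ** 2)           # fit term: match the ground truth
    spread = np.mean((pred - pred.mean()) ** 2)  # dispersion of the prediction
    return mse + lam / (spread + eps)            # low dispersion is heavily penalized

truth = np.array([0.0, 1.0, 0.0, 1.0])
flat = np.full(4, 0.5)                  # the "average color" answer: tiny spread
varied = np.array([0.1, 0.9, 0.1, 0.9]) # commits to distinct values

# The flat prediction has the lower MSE but a near-zero spread,
# so the composite loss strongly prefers the varied prediction.
assert composite_loss(flat, truth) > composite_loss(varied, truth)
```

The second term thus pushes the network away from the desaturated "mean" colorings that plain MSE favors, at the cost of a hyperparameter (16.7) that balances fidelity against saturation.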


1.4 The Dataset (ImageNet)

The ImageNet project is a large visual database for visual object recognition research. More than 14 million image URLs have been hand-annotated by ImageNet to indicate the objects they contain; for at least one million images, bounding boxes are also provided. ImageNet covers more than 20,000 categories; [2] a typical category such as "balloon" or "strawberry" contains several hundred images. The annotation database of third-party image URLs is freely available directly from ImageNet, but the images themselves are not owned by ImageNet. Since 2010 the project has run an annual software competition, the ImageNet Large Scale Visual Recognition Challenge (ILSVRC), in which programs compete to correctly classify and detect objects and scenes. The challenge uses a "trimmed" list of 1,000 non-overlapping classes. A major breakthrough on the challenge came in 2012 and is widely regarded as the start of the deep learning revolution of the 2010s. (Source: Baidu Baike)

ImageNet 2012 at a glance:

  • Training images (Tasks 1 & 2): 138 GB (about 1.2 million high-resolution images in 1,000 categories)
  • Validation images (all tasks): 6.3 GB
  • Training bounding box annotations (Tasks 1 & 2 only): 20 MB

1.5 The Lab Color Space

The Lab color model is built on the international standard for color measurement established by the Commission Internationale de l'Eclairage (CIE) in 1931; it was revised and given its current name in 1976. Lab makes up for shortcomings of the RGB and CMYK models: it is a device-independent color model based on human visual physiology. [1] It has three components: L is lightness, while a and b are two color channels. The a channel runs from deep green (low values) through gray (mid values) to bright magenta (high values); the b channel runs from bright blue (low values) through gray (mid values) to yellow (high values). Mixing these channels produces vivid colors. (Source: Baidu Baike)
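As a quick sanity check of the channel meanings (using scikit-image, which this project already installs): a pure sRGB red maps to a strongly positive a (the magenta side) and a positive b (the yellow side):

```python
import numpy as np
from skimage import color

rgb = np.zeros((1, 1, 3))
rgb[..., 0] = 1.0                      # a single pure-red pixel
lab = color.rgb2lab(rgb)
L, a, b = lab[0, 0]
print(round(L, 1), round(a, 1), round(b, 1))  # roughly 53.2 80.1 67.2
assert 0.0 <= L <= 100.0               # L (lightness) lives in [0, 100]
assert a > 0 and b > 0                 # red sits on the magenta and yellow sides
```

This split is exactly why the model below takes L as input and predicts only the two-channel ab image: the grayscale photo already supplies the lightness.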

2. Initial Dataset Processing with Shell Commands (runtime: about 20 min)

mkdir -p work/test work/train
tar xf data/data9244/ILSVRC2012_img_val.tar -C work/test/ 
cd ./work/train/; ls ../data/tar/*.tar | xargs -n1 tar xf 
# Count the images in work/train 
find work/train -type f | wc -l

mkdir: cannot create directory ‘work/train’: File exists
mkdir: cannot create directory ‘work/test’: File exists

3. Preprocessing

3.1 Preprocessing — Delete Single-Channel Images from the Training Set Using Multiple Threads (runtime: about 20 min)

import os
import imghdr
import threading
import numpy as np
from PIL import Image

'''Use multiple threads to delete single-channel (grayscale) images from the dataset'''

def cutArray(l, num):
    '''Split the list l into num roughly equal parts.'''
    avg = len(l) / float(num)
    o = []
    last = 0.0
    while last < len(l):
        o.append(l[int(last):int(last + avg)])
        last += avg
    return o

def deleteErrorImage(path, image_dir):
    '''Delete images that are not JPEG or that have only one channel.'''
    count = 0
    for file in image_dir:
        try:
            image = os.path.join(path, file)
            if imghdr.what(image) != 'jpeg':
                os.remove(image)
                count += 1
                continue  # the file is gone; don't try to open it again
            img = np.array(Image.open(image))
            if len(img.shape) == 2:  # no channel axis: grayscale image
                os.remove(image)
                count += 1
        except Exception as e:
            print(e)
    print('done!')
    print('Deleted: ' + str(count))

if __name__ == '__main__':
    path = './work/train/'
    files = cutArray(os.listdir(path), 8)
    threadList = [threading.Thread(target=deleteErrorImage, args=(path, part))
                  for part in files]
    for t in threadList:
        t.setDaemon(True)
        t.start()
    for t in threadList:  # join only after all threads have started, so they run in parallel
        t.join()

done!
Deleted: 470
done!
Deleted: 432
done!
Deleted: 426
done!
Deleted: 483
[Errno 2] No such file or directory: './work/train/n02105855_2933.JPEG'
done!
Deleted: 490
done!
Deleted: 454
done!
Deleted: 467
done!
Deleted: 482

3.2 Preprocessing — Scale and Center-Crop Images to 256x256 Using Multiple Threads (runtime: about 40 min)

import os
import os.path
import threading
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

'''Use multiple threads to resize images and then center-crop them to 256x256'''
w = 256
h = 256

def cutArray(l, num):
    '''Split the list l into num roughly equal parts.'''
    avg = len(l) / float(num)
    o = []
    last = 0.0
    while last < len(l):
        o.append(l[int(last):int(last + avg)])
        last += avg
    return o

def convertjpg(jpgfile, outdir, width=w, height=h):
    '''Resize so the short side equals width, then center-crop to width x height.'''
    img = Image.open(jpgfile)
    (l, h) = img.size
    rate = min(l, h) / width
    try:
        img = img.resize((int(l // rate), int(h // rate)), Image.BILINEAR)
        (l, h) = img.size
        lstart = (l - width) // 2
        hstart = (h - height) // 2
        img = img.crop((lstart, hstart, lstart + width, hstart + height))
        img.save(os.path.join(outdir, os.path.basename(jpgfile)))
    except Exception as e:
        print(e)

class thread(threading.Thread):
    def __init__(self, threadID, inpath, outpath, files):
        threading.Thread.__init__(self)
        self.threadID = threadID
        self.inpath = inpath
        self.outpath = outpath
        self.files = files

    def run(self):
        count = 0
        try:
            for file in self.files:
                convertjpg(self.inpath + file, self.outpath)
                count = count + 1
        except Exception as e:
            print(e)
        print('Images processed: ' + str(count))

if __name__ == '__main__':
    inpath = './work/train/'   # input and output are the same directory:
    outpath = './work/train/'  # images are overwritten in place
    files = cutArray(os.listdir(inpath), 8)
    threads = [thread(i + 1, inpath, outpath, files[i]) for i in range(8)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

 

Images processed: 58782
Images processed: 58783
Images processed: 58782
Images processed: 58782
Images processed: 58782
Images processed: 58782
Images processed: 58782
Images processed: 58782

4. Import the Libraries Used by This Project

import os
import sys
import six
import cv2
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.dataset as dataset
import sklearn.neighbors as neighbors
from skimage import io, color, transform
import matplotlib.pyplot as plt

5. Define the Data Preprocessing Tool — DataReader

'''Prepare the data and define the reader()s'''

PATH = 'work/train/'
TEST = 'work/train/'
Q = np.load('work/Q.npy')
Weight = np.load('work/Weight.npy')

class DataGenerater:
    def __init__(self):
        self.datalist = os.listdir(PATH)
        self.testlist = os.listdir(TEST)

    def load(self, image):
        '''Read an image, convert it to Lab, and split out the L and ab channels.'''
        img = io.imread(image)
        lab = np.array(color.rgb2lab(img)).transpose()  # (H, W, 3) -> (3, W, H)
        l = lab[:1, :, :].astype('float32')
        ab = lab[1:, :, :].astype('float32')
        return l, ab

    def create_train_reader(self):
        '''Define the reader for the training set.'''
        def reader():
            for img in self.datalist:
                try:
                    l, ab = self.load(PATH + img)
                    yield l.astype('float32'), ab.astype('float32')
                except Exception as e:
                    print(e)
        return reader

    def create_test_reader(self):
        '''Define the reader for the test set.'''
        def reader():
            for img in self.testlist:
                l, ab = self.load(TEST + img)
                yield l.astype('float32'), ab.astype('float32')
        return reader

def train(batch_sizes=32):
    return DataGenerater().create_train_reader()

def test():
    return DataGenerater().create_test_reader()

6. Define the Network Building Blocks and the Network

The network is built from groups of basic residual blocks for downsampling, followed by transposed-convolution (deconvolution) layers for upsampling.

import IPython.display as display
import warnings
warnings.filterwarnings('ignore')

Q = np.load('work/Q.npy')
weight = np.load('work/Weight.npy')
Params_dirname = "work/model/gray2color.inference.model"

'''Custom loss function'''
def createLoss(predict, truth):
    '''Mean squared error (the dispersion term below is currently commented out).'''
    loss1 = fluid.layers.square_error_cost(predict, truth)
    # loss2 = fluid.layers.square_error_cost(
    #     predict,
    #     fluid.layers.fill_constant(shape=[BATCH_SIZE, 2, 256, 256],
    #                                value=fluid.layers.mean(predict),
    #                                dtype='float32'))
    cost = fluid.layers.mean(loss1)  # + 16.7 / fluid.layers.mean(loss2)
    return cost

def conv_bn_layer(input,
                  ch_out,
                  filter_size,
                  stride,
                  padding,
                  act='relu',
                  bias_attr=True):
    tmp = fluid.layers.conv2d(
        input=input,
        filter_size=filter_size,
        num_filters=ch_out,
        stride=stride,
        padding=padding,
        act=None,
        bias_attr=bias_attr)
    return fluid.layers.batch_norm(input=tmp, act=act, momentum=0.5)

def shortcut(input, ch_in, ch_out, stride):
    if ch_in != ch_out:
        return conv_bn_layer(input, ch_out, 1, stride, 0, None)
    else:
        return input

def basicblock(input, ch_in, ch_out, stride):
    tmp = conv_bn_layer(input, ch_out, 3, stride, 1)
    tmp = conv_bn_layer(tmp, ch_out, 3, 1, 1, act=None, bias_attr=True)
    short = shortcut(input, ch_in, ch_out, stride)
    return fluid.layers.elementwise_add(x=tmp, y=short, act='relu')

def layer_warp(block_func, input, ch_in, ch_out, count, stride):
    tmp = block_func(input, ch_in, ch_out, stride)
    for i in range(1, count):
        tmp = block_func(tmp, ch_out, ch_out, 1)
    return tmp

### Transposed-convolution (deconvolution) layer
def deconv(x, num_filters, filter_size=5, stride=2, dilation=1, padding=2,
           output_size=None, act=None):
    return fluid.layers.conv2d_transpose(
        input=x,
        num_filters=num_filters,    # number of filters
        output_size=output_size,    # output image size
        filter_size=filter_size,    # filter size
        stride=stride,              # stride
        dilation=dilation,          # dilation rate
        padding=padding,
        use_cudnn=True,             # use the cuDNN kernel
        act=act)                    # activation function

def bn(x, name=None, act=None, momentum=0.5):
    return fluid.layers.batch_norm(
        x,
        bias_attr=None,                   # bias attribute
        moving_mean_name=name + '3',      # name of moving_mean
        moving_variance_name=name + '4',  # name of moving_variance
        name=name,
        act=act,
        momentum=momentum)

def resnetImagenet(input):
    # Downsampling path; the comments track the feature-map size: 128
    x = layer_warp(basicblock, input, 64, 128, 1, 2)
    # 64
    x = layer_warp(basicblock, x, 128, 256, 1, 2)
    # 32
    x = layer_warp(basicblock, x, 256, 512, 1, 2)
    # 16
    x = layer_warp(basicblock, x, 512, 1024, 1, 2)
    # 8
    x = layer_warp(basicblock, x, 1024, 2048, 1, 2)
    # Upsampling path: 16
    x = deconv(x, num_filters=1024, filter_size=4, stride=2, padding=1)
    x = bn(x, name='bn_1', act='relu', momentum=0.5)
    # 32
    x = deconv(x, num_filters=512, filter_size=4, stride=2, padding=1)
    x = bn(x, name='bn_2', act='relu', momentum=0.5)
    # 64
    x = deconv(x, num_filters=256, filter_size=4, stride=2, padding=1)
    x = bn(x, name='bn_3', act='relu', momentum=0.5)
    # 128
    x = deconv(x, num_filters=128, filter_size=4, stride=2, padding=1)
    x = bn(x, name='bn_4', act='relu', momentum=0.5)
    # 256
    x = deconv(x, num_filters=64, filter_size=4, stride=2, padding=1)
    x = bn(x, name='bn_5', act='relu', momentum=0.5)
    # Two output channels: the predicted a and b
    x = deconv(x, num_filters=2, filter_size=3, stride=1, padding=1)
    return x

7. Train the Network

Hyperparameters:

  • Learning rate: 2e-5
  • Epochs: 30
  • Mini-batch size: 10
  • Input tensor: [-1, 1, 256, 256]

(Note that the code below actually sets BATCH_SIZE = 30 and EPOCH_NUM = 300.)

The pretrained inference model is stored at work/model/gray2color.inference.model.

BATCH_SIZE = 30
EPOCH_NUM = 300

def ResNettrain():
    gray = fluid.layers.data(name='gray', shape=[1, 256, 256], dtype='float32')
    truth = fluid.layers.data(name='truth', shape=[2, 256, 256], dtype='float32')
    predict = resnetImagenet(gray)
    cost = createLoss(predict=predict, truth=truth)
    return predict, cost

'''Optimizer'''
def optimizer_program():
    return fluid.optimizer.Adam(learning_rate=2e-5, beta1=0.8)

train_reader = paddle.batch(paddle.reader.shuffle(
    reader=train(), buf_size=7500 * 3
), batch_size=BATCH_SIZE)
test_reader = paddle.batch(reader=test(), batch_size=10)

use_cuda = True
if not use_cuda:
    os.environ['CPU_NUM'] = str(6)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()

main_program = fluid.default_main_program()
star_program = fluid.default_startup_program()

'''Build the network'''
predict, cost = ResNettrain()

'''Optimizer'''
optimizer = optimizer_program()
optimizer.minimize(cost)

exe = fluid.Executor(place)

def train_loop():
    feeder = fluid.DataFeeder(feed_list=['gray', 'truth'], place=place)
    exe.run(star_program)

    # Incremental training: resume from previously saved parameters
    fluid.io.load_persistables(exe, 'work/model/incremental/', main_program)

    for pass_id in range(EPOCH_NUM):
        step = 0
        for data in train_reader():
            loss = exe.run(main_program, feed=feeder.feed(data), fetch_list=[cost])
            step += 1
            if step % 1000 == 0:
                try:
                    generated_img = exe.run(main_program, feed=feeder.feed(data),
                                            fetch_list=[predict])
                    plt.figure(figsize=(15, 6))
                    plt.grid(False)
                    for i in range(10):
                        ab = generated_img[0][i]
                        l = data[i][0][0]
                        a = ab[0]
                        b = ab[1]
                        l = l[:, :, np.newaxis]
                        a = a[:, :, np.newaxis].astype('float64')
                        b = b[:, :, np.newaxis].astype('float64')
                        lab = np.concatenate((l, a, b), axis=2)
                        img = color.lab2rgb(lab)
                        # Undo the transpose done in DataGenerater.load()
                        img = transform.rotate(img, 270)
                        img = np.fliplr(img)
                        plt.subplot(2, 5, i + 1)
                        plt.imshow(img)
                        plt.axis('off')
                        plt.xticks([])
                        plt.yticks([])
                    msg = 'Epoch ID={0} Batch ID={1} Loss={2}'.format(
                        pass_id, step, loss[0][0])
                    plt.suptitle(msg, fontsize=20)
                    plt.draw()
                    plt.savefig('{}/{:04d}_{:04d}.png'.format(
                        'work/output_img', pass_id, step), bbox_inches='tight')
                    plt.pause(0.01)
                    display.clear_output(wait=True)
                except IOError as e:
                    print(e)

        fluid.io.save_persistables(exe, 'work/model/incremental/', main_program)
        fluid.io.save_inference_model(Params_dirname, ["gray"], [predict], exe)

train_loop()

8. Project Summary

This article has walked through the project step by step.
As for the results: although suppressing the averaging effect (by rewarding dispersion) improved the saturation of the colorized output, the results still fall noticeably short of reality. The model does well only on some scenes and remains weak on man-made scenes (supermarket interiors, for example).
The next step is to further redesign the loss function so that the colorized results can fool human "intuition", rather than merely approximating the ground-truth scene.


This article is also published on the blog "Redflashing" (CSDN).
