PaddleColorization: Colorizing Black-and-White Photos in Python
Isn't colorizing a black-and-white photo a little magical?
This project walks you step by step through colorizing black-and-white photos, and even black-and-white video.
Download and installation commands:

```shell
# CPU version
pip install -f https://paddlepaddle.org.cn/pip/oschina/cpu paddlepaddle
# GPU version
pip install -f https://paddlepaddle.org.cn/pip/oschina/gpu paddlepaddle-gpu
```
Colorizing Black-and-White Photos

We all know that many classic old photos, limited by the technology of their era, survive only in black and white. Black-and-white photos have a charm of their own, but color sometimes gives a much stronger sense of immersion. This project implements black-and-white photo colorization in a simple, accessible way and achieves good results on some photos. Colorization is a classic problem in computer vision, and with the wide adoption of convolutional neural networks (CNNs) in recent years, colorizing black-and-white photos with a CNN has become a novel and practical approach. The project is hosted on Baidu's learning and training community AI Studio; the implementation uses a ResNet residual network as the backbone, trained with a composite loss function.

Let the colorization journey begin!

First, have a look at the finished results.

Feel free to fork and study this project; if you have any questions, leave a comment and let's discuss.

A little self-promotion: my interests include transfer learning and generative adversarial networks. Come follow me on AI Studio!
```python
# Install the required dependencies
!pip install sklearn scikit-image
```
1 Project Overview

This project is based on PaddlePaddle. Combining a residual network (ResNet) with supervised learning, it trains a model that converts black-and-white images into color images.
1.1 Residual Networks (ResNet)

1.1.1 Background

ResNet (Residual Network) [15] won the 2015 ImageNet competitions in image classification, object localization, and object detection. To address the drop in accuracy that comes with training ever-deeper networks, ResNet introduced residual learning to ease the training of deep models. On top of existing design ideas (batch normalization, small convolution kernels, fully convolutional networks), it adds the residual module. Each residual module contains two paths: one is a direct shortcut carrying the input features; the other applies two or three convolutions to those features to produce their residual. The features from the two paths are then added together.

The residual module is shown in Figure 9. On the left is the basic block, composed of two 3x3 convolutions with the same number of output channels. On the right is the bottleneck block, so called because the upper 1x1 convolution reduces the dimensionality (256->64 in the figure) and the lower 1x1 convolution restores it (64->256), so that the middle 3x3 convolution has small input and output channel counts (64->64 in the figure).
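As a toy illustration (plain NumPy, with a matrix product standing in for the 3x3 convolutions), the basic block's add-then-activate structure can be sketched as follows. Note that with zero residual weights the block reduces to the identity on non-negative inputs, which is the property that makes very deep stacks trainable:

```python
import numpy as np

def conv3x3(x, w):
    # toy stand-in for a 3x3 convolution: any linear map works for this sketch
    return x @ w

def basic_block(x, w1, w2):
    residual = np.maximum(conv3x3(x, w1), 0)  # first conv + ReLU
    residual = conv3x3(residual, w2)          # second conv, no activation yet
    return np.maximum(x + residual, 0)        # add the shortcut, then ReLU

x = np.ones((2, 2))
zero = np.zeros((2, 2))
# with zero residual weights the block simply passes x through unchanged
print(basic_block(x, zero, zero))
```

This is only a sketch of the connectivity pattern; the real blocks in the code below use `fluid.layers.conv2d` plus batch normalization.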
(Figure: the residual module, basic block on the left, bottleneck block on the right; the original image link is broken.)
1.2 Design Approach and the Main Problem

- Approach: by training on a large number of samples, the network learns an empirical distribution (e.g. the sky is always blue, grass is always green), and from that distribution it infers plausible colors for each region of a black-and-white image.
- Main problem: the colors of many objects are not fixed; object color is multimodal (e.g. an apple can be red, but also green or yellow). Using plain mean squared error as the loss pushes objects with multimodal colors toward an "average" color (typically a pale yellow), leaving the colorized image undersaturated.
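A tiny numerical illustration of this averaging effect (hypothetical values, not from the project): if the same object is one color in half the training images and another color in the other half, the prediction that minimizes mean squared error is their average, which corresponds to a washed-out in-between color:

```python
import numpy as np

# two equally likely ground-truth color values for the same object
targets = np.array([0.0, 1.0])

# evaluate every candidate prediction by its average MSE against both targets
candidates = np.linspace(0.0, 1.0, 101)
losses = np.array([np.mean((c - targets) ** 2) for c in candidates])

best = candidates[np.argmin(losses)]
print(best)  # ~0.5: MSE is minimized by the average, not by either true color
```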
1.3 Key Settings in This Project

- Set the Adam optimizer's beta1 parameter to 0.8 (see the original paper for details)
- Set the BatchNorm momentum parameter to 0.5
- Use the basic block connection style
- To suppress the multimodality problem, redesign the loss function on top of mean squared error
The loss function is:

$$Out = \frac{1}{n}\sum{(input-label)^{2}} + \frac{16.7}{n\sum{(input - \overline{input})^{2}}}$$
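A minimal NumPy sketch of this composite loss (taking the formula above literally; the variable names are illustrative). The second term shrinks as the spread of the predictions grows, so the network is rewarded for producing more dispersed, saturated colors:

```python
import numpy as np

def composite_loss(pred, label):
    n = pred.size
    # first term: ordinary mean squared error against the ground truth
    mse = np.mean((pred - label) ** 2)
    # second term: 16.7 / (n * sum of squared deviations from the prediction mean);
    # it is large when predictions huddle around their mean (low saturation)
    spread_penalty = 16.7 / (n * np.sum((pred - pred.mean()) ** 2))
    return mse + spread_penalty

pred = np.array([1.0, 3.0])
label = np.array([1.0, 1.0])
print(composite_loss(pred, label))  # 2.0 + 16.7 / (2 * 2) = 6.175
```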
1.4 The Dataset (ImageNet)

The ImageNet project is a large visual database for visual object recognition research. More than 14 million image URLs have been hand-annotated by ImageNet to indicate the objects they contain; bounding boxes are also provided for at least one million images. ImageNet contains more than 20,000 categories; [2] a typical category, such as "balloon" or "strawberry", contains several hundred images. The annotation database of third-party image URLs is freely available directly from ImageNet, but the actual images do not belong to ImageNet. Since 2010 the ImageNet project has run an annual software competition, the ImageNet Large Scale Visual Recognition Challenge (ILSVRC), in which programs compete to correctly classify and detect objects and scenes. The ImageNet challenge uses a "trimmed" list of 1,000 non-overlapping classes. The breakthrough on the ImageNet challenge in 2012 is widely regarded as the start of the 2010s deep learning revolution. (Source: Baidu Baike)

About ImageNet 2012:

- Training images (Task 1 & 2). 138GB. (about 1.2 million high-resolution images in 1,000 categories)
- Validation images (all tasks). 6.3GB.
- Training bounding box annotations (Task 1 & 2 only). 20MB.
1.5 The Lab Color Space

The Lab model is based on an international standard for measuring color established by the Commission Internationale de l'Eclairage (CIE) in 1931, refined and given its name in 1976. The Lab color model makes up for the deficiencies of the RGB and CMYK color models: it is a device-independent color model, and one grounded in human visual physiology. [1] The Lab model has three components: L for lightness, with a and b as two color channels. The a channel runs from deep green (low values) through gray (middle values) to bright pink (high values); the b channel runs from bright blue (low values) through gray (middle values) to yellow (high values). Colors mixed this way therefore appear bright. (Source: Baidu Baike)
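This split is exactly what the project exploits: the L channel serves as the grayscale input, and the network predicts the a and b channels. A quick sketch with scikit-image, using a hypothetical 2x2 image:

```python
import numpy as np
from skimage import color

# a hypothetical 2x2 RGB image: red, green, blue, white (values in [0, 1])
rgb = np.array([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
                [[0.0, 0.0, 1.0], [1.0, 1.0, 1.0]]])

lab = color.rgb2lab(rgb)  # shape (2, 2, 3)
L = lab[..., 0]           # lightness, roughly 0..100: the "black-and-white" input
ab = lab[..., 1:]         # the two color channels the network must predict

print(L[1, 1])            # white has L = 100 and near-zero a, b
restored = color.lab2rgb(lab)  # the round trip recovers the original image
```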
2. Initial Processing of the Dataset with Shell Commands (runtime: about 20 min)

```shell
# Unpack the validation images into work/test/
tar xf data/data9244/ILSVRC2012_img_val.tar -C work/test/
# Unpack every training archive into work/train/
cd ./work/train/; ls ../data/tar/*.tar | xargs -n1 tar xf
# Count the images in work/train
find work/train -type f | wc -l
```
mkdir: cannot create directory ‘work/train’: File exists
mkdir: cannot create directory ‘work/test’: File exists
3. Preprocessing

3.1 Preprocessing: Delete Single-Channel Images from the Training Set Using Multiple Threads (runtime: about 20 min)

```python
import os
import imghdr
import threading

import numpy as np
from PIL import Image

'''Delete single-channel (grayscale) images from the dataset using multiple threads'''

def cutArray(l, num):
    '''Split list l into num roughly equal chunks.'''
    avg = len(l) / float(num)
    o = []
    last = 0.0
    while last < len(l):
        o.append(l[int(last):int(last + avg)])
        last += avg
    return o

def deleteErrorImage(path, image_dir):
    count = 0
    for file in image_dir:
        try:
            image = os.path.join(path, file)
            # remove files that are not JPEG
            if imghdr.what(image) != 'jpeg':
                os.remove(image)
                count += 1
                continue
            # remove single-channel (grayscale) images
            img = np.array(Image.open(image))
            if len(img.shape) == 2:
                os.remove(image)
                count += 1
        except Exception as e:
            print(e)
    print('done!')
    print('Deleted: ' + str(count))

if __name__ == '__main__':
    path = './work/train/'
    files = cutArray(os.listdir(path), 8)
    threads = [threading.Thread(target=deleteErrorImage, args=(path, chunk))
               for chunk in files]
    for t in threads:
        t.setDaemon(True)
        t.start()
    for t in threads:
        t.join()
```
done!
Deleted: 470
done!
Deleted: 432
done!
Deleted: 426
done!
Deleted: 483
[Errno 2] No such file or directory: './work/train/n02105855_2933.JPEG'
done!
Deleted: 490
done!
Deleted: 454
done!
Deleted: 467
done!
Deleted: 482
3.2 Preprocessing: Resize and Center-Crop Images to 256x256 Using Multiple Threads (runtime: about 40 min)

```python
import os
import os.path
import threading

from PIL import Image, ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True

'''Resize then center-crop every image to 256x256 using multiple threads'''

w = 256
h = 256

def cutArray(l, num):
    '''Split list l into num roughly equal chunks.'''
    avg = len(l) / float(num)
    o = []
    last = 0.0
    while last < len(l):
        o.append(l[int(last):int(last + avg)])
        last += avg
    return o

def convertjpg(jpgfile, outdir, width=w, height=h):
    img = Image.open(jpgfile)
    (l, h) = img.size
    rate = min(l, h) / width
    try:
        # scale so the short side equals the target, then center-crop
        img = img.resize((int(l // rate), int(h // rate)), Image.BILINEAR)
        (l, h) = img.size
        lstart = (l - width) // 2
        hstart = (h - height) // 2
        img = img.crop((lstart, hstart, lstart + width, hstart + height))
        img.save(os.path.join(outdir, os.path.basename(jpgfile)))
    except Exception as e:
        print(e)

class thread(threading.Thread):
    def __init__(self, threadID, inpath, outpath, files):
        threading.Thread.__init__(self)
        self.threadID = threadID
        self.inpath = inpath
        self.outpath = outpath
        self.files = files

    def run(self):
        count = 0
        try:
            for file in self.files:
                convertjpg(self.inpath + file, self.outpath)
                count += 1
        except Exception as e:
            print(e)
        print('Processed: ' + str(count))

if __name__ == '__main__':
    inpath = './work/train/'
    outpath = './work/train/'
    files = cutArray(os.listdir(inpath), 8)
    threads = [thread(i + 1, inpath, outpath, files[i]) for i in range(8)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
```
Processed: 58782
Processed: 58783
Processed: 58782
Processed: 58782
Processed: 58782
Processed: 58782
Processed: 58782
Processed: 58782
4. Import the Libraries Used in This Project

```python
import os
import sys
import six
import cv2
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.dataset as dataset
import sklearn.neighbors as neighbors
from skimage import io, color, transform
import matplotlib.pyplot as plt
```
5. Define the Data Preprocessing Tool: DataReader

```python
'''Prepare the data and define Reader()'''
PATH = 'work/train/'
TEST = 'work/train/'
Q = np.load('work/Q.npy')
Weight = np.load('work/Weight.npy')

class DataGenerater:
    def __init__(self):
        self.datalist = os.listdir(PATH)
        self.testlist = os.listdir(TEST)

    def load(self, image):
        '''Read an image, convert it to Lab, and split out L and ab.'''
        img = io.imread(image)
        lab = np.array(color.rgb2lab(img)).transpose()
        l = lab[:1, :, :].astype('float32')
        ab = lab[1:, :, :].astype('float32')
        return l, ab

    def create_train_reader(self):
        '''Reader for the training set.'''
        def reader():
            for img in self.datalist:
                try:
                    l, ab = self.load(PATH + img)
                    yield l.astype('float32'), ab.astype('float32')
                except Exception as e:
                    print(e)
        return reader

    def create_test_reader(self):
        '''Reader for the test set.'''
        def reader():
            for img in self.testlist:
                l, ab = self.load(TEST + img)
                yield l.astype('float32'), ab.astype('float32')
        return reader

def train(batch_sizes=32):
    return DataGenerater().create_train_reader()

def test():
    return DataGenerater().create_test_reader()
```
6. Define the Network Modules and the Network

The network stacks groups of basic residual blocks that downsample the input, followed by transposed-convolution (deconvolution) layers that upsample back to 256x256 with two output channels (the predicted a and b).

```python
import IPython.display as display
import warnings
warnings.filterwarnings('ignore')

Q = np.load('work/Q.npy')
weight = np.load('work/Weight.npy')
Params_dirname = "work/model/gray2color.inference.model"

'''Custom loss function'''
def createLoss(predict, truth):
    '''Mean squared error (the variance penalty term is currently commented out)'''
    loss1 = fluid.layers.square_error_cost(predict, truth)
    #loss2 = fluid.layers.square_error_cost(predict, fluid.layers.fill_constant(shape=[BATCH_SIZE,2,256,256], value=fluid.layers.mean(predict), dtype='float32'))
    cost = fluid.layers.mean(loss1) #+ 16.7 / fluid.layers.mean(loss2)
    return cost

def conv_bn_layer(input, ch_out, filter_size, stride, padding, act='relu', bias_attr=True):
    tmp = fluid.layers.conv2d(
        input=input,
        filter_size=filter_size,
        num_filters=ch_out,
        stride=stride,
        padding=padding,
        act=None,
        bias_attr=bias_attr)
    return fluid.layers.batch_norm(input=tmp, act=act, momentum=0.5)

def shortcut(input, ch_in, ch_out, stride):
    # project with a 1x1 conv only when the channel counts differ
    if ch_in != ch_out:
        return conv_bn_layer(input, ch_out, 1, stride, 0, None)
    else:
        return input

def basicblock(input, ch_in, ch_out, stride):
    tmp = conv_bn_layer(input, ch_out, 3, stride, 1)
    tmp = conv_bn_layer(tmp, ch_out, 3, 1, 1, act=None, bias_attr=True)
    short = shortcut(input, ch_in, ch_out, stride)
    return fluid.layers.elementwise_add(x=tmp, y=short, act='relu')

def layer_warp(block_func, input, ch_in, ch_out, count, stride):
    tmp = block_func(input, ch_in, ch_out, stride)
    for i in range(1, count):
        tmp = block_func(tmp, ch_out, ch_out, 1)
    return tmp

### Transposed-convolution (deconvolution) layer
def deconv(x, num_filters, filter_size=5, stride=2, dilation=1, padding=2, output_size=None, act=None):
    return fluid.layers.conv2d_transpose(
        input=x,
        num_filters=num_filters,   # number of filters
        output_size=output_size,   # output image size
        filter_size=filter_size,   # filter size
        stride=stride,             # stride
        dilation=dilation,         # dilation rate
        padding=padding,
        use_cudnn=True,            # whether to use the cuDNN kernel
        act=act                    # activation function
    )

def bn(x, name=None, act=None, momentum=0.5):
    return fluid.layers.batch_norm(
        x,
        bias_attr=None,                   # attribute object for the bias
        moving_mean_name=name + '3',      # name of moving_mean
        moving_variance_name=name + '4',  # name of moving_variance
        name=name,
        act=act,
        momentum=momentum,
    )

def resnetImagenet(input):
    # 128
    x = layer_warp(basicblock, input, 64, 128, 1, 2)
    # 64
    x = layer_warp(basicblock, x, 128, 256, 1, 2)
    # 32
    x = layer_warp(basicblock, x, 256, 512, 1, 2)
    # 16
    x = layer_warp(basicblock, x, 512, 1024, 1, 2)
    # 8
    x = layer_warp(basicblock, x, 1024, 2048, 1, 2)
    # 16
    x = deconv(x, num_filters=1024, filter_size=4, stride=2, padding=1)
    x = bn(x, name='bn_1', act='relu', momentum=0.5)
    # 32
    x = deconv(x, num_filters=512, filter_size=4, stride=2, padding=1)
    x = bn(x, name='bn_2', act='relu', momentum=0.5)
    # 64
    x = deconv(x, num_filters=256, filter_size=4, stride=2, padding=1)
    x = bn(x, name='bn_3', act='relu', momentum=0.5)
    # 128
    x = deconv(x, num_filters=128, filter_size=4, stride=2, padding=1)
    x = bn(x, name='bn_4', act='relu', momentum=0.5)
    # 256
    x = deconv(x, num_filters=64, filter_size=4, stride=2, padding=1)
    x = bn(x, name='bn_5', act='relu', momentum=0.5)
    x = deconv(x, num_filters=2, filter_size=3, stride=1, padding=1)
    return x
```
7. Train the Network

The hyperparameters are:

- Learning rate: 2e-5
- Epochs: 30
- Mini-batch size: 10
- Input tensor: [-1, 1, 256, 256]

The pretrained inference model is stored at work/model/gray2color.inference.model
```python
BATCH_SIZE = 30
EPOCH_NUM = 300

def ResNettrain():
    gray = fluid.layers.data(name='gray', shape=[1, 256, 256], dtype='float32')
    truth = fluid.layers.data(name='truth', shape=[2, 256, 256], dtype='float32')
    predict = resnetImagenet(gray)
    cost = createLoss(predict=predict, truth=truth)
    return predict, cost

'''Optimizer function'''
def optimizer_program():
    return fluid.optimizer.Adam(learning_rate=2e-5, beta1=0.8)

train_reader = paddle.batch(
    paddle.reader.shuffle(reader=train(), buf_size=7500 * 3),
    batch_size=BATCH_SIZE)
test_reader = paddle.batch(reader=test(), batch_size=10)

use_cuda = True
if not use_cuda:
    os.environ['CPU_NUM'] = str(6)

place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
main_program = fluid.default_main_program()
star_program = fluid.default_startup_program()

'''Build the network'''
predict, cost = ResNettrain()

'''Optimizer'''
optimizer = optimizer_program()
optimizer.minimize(cost)

exe = fluid.Executor(place)

def train_loop():
    feeder = fluid.DataFeeder(feed_list=['gray', 'truth'], place=place)
    exe.run(star_program)
    # incremental training: resume from previously saved parameters
    fluid.io.load_persistables(exe, 'work/model/incremental/', main_program)
    for pass_id in range(EPOCH_NUM):
        step = 0
        for data in train_reader():
            loss = exe.run(main_program, feed=feeder.feed(data), fetch_list=[cost])
            step += 1
            if step % 1000 == 0:
                try:
                    # periodically visualize the current predictions
                    generated_img = exe.run(main_program, feed=feeder.feed(data), fetch_list=[predict])
                    plt.figure(figsize=(15, 6))
                    plt.grid(False)
                    for i in range(10):
                        ab = generated_img[0][i]
                        l = data[i][0][0]
                        a = ab[0]
                        b = ab[1]
                        l = l[:, :, np.newaxis]
                        a = a[:, :, np.newaxis].astype('float64')
                        b = b[:, :, np.newaxis].astype('float64')
                        # recombine the input L channel with the predicted ab channels
                        lab = np.concatenate((l, a, b), axis=2)
                        img = color.lab2rgb(lab)
                        img = transform.rotate(img, 270)
                        img = np.fliplr(img)
                        plt.subplot(2, 5, i + 1)
                        plt.imshow(img)
                        plt.axis('off')
                        plt.xticks([])
                        plt.yticks([])
                    msg = 'Epoch ID={0} Batch ID={1} Loss={2}'.format(pass_id, step, loss[0][0])
                    plt.suptitle(msg, fontsize=20)
                    plt.draw()
                    plt.savefig('{}/{:04d}_{:04d}.png'.format('work/output_img', pass_id, step), bbox_inches='tight')
                    plt.pause(0.01)
                    display.clear_output(wait=True)
                except IOError:
                    print(IOError)
        fluid.io.save_persistables(exe, 'work/model/incremental/', main_program)
        fluid.io.save_inference_model(Params_dirname, ["gray"], [predict], exe)

train_loop()
```
8. Conclusion

This article has walked through the project step by step.

Although suppressing the averaging effect and increasing dispersion improved the saturation of the colorization, the final results still fall well short of reality: only some scenes colorize well, and man-made scenes (such as supermarket interiors) are still handled poorly.

The next step is to further refine the loss function so that the network's colorization can fool human "intuition", rather than merely approximating the ground-truth scene.
This article is also published on the blog "Redflashing" (CSDN).
For takedown requests, contact support@oschina.cn.
This article is part of the "OSC Source Creation Plan"; you are welcome to join and share as well.