基於飛槳PaddlePaddle實現的Sub-Pixel圖像超分辨率python
1.項目介紹
本文則參考論文:Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network ,使用飛槳最新的分支版本,實現了一個輕量級圖像的超分辨率模型,旨在帶領各位小夥伴快速瞭解飛槳框架2.0,也能夠在此基礎上修改、優化模型,實現本身的超分辨率算法。git
飛槳PaddlePaddle最近迎來了重大更新,進入了2. 0時代。AI Studio也同步上線了最新版本得在線編程環境,又送免費GPU算力,這波羊毛不薅都對不起本身啊(手動狗頭)。飛槳框架2.0新添加了許多經常使用的API,豐富的API接口給開發帶來了便利,可以比較輕鬆的完成模型搭建及訓練。若是小夥伴們對本項目感興趣,歡迎來AI Studio Fork 運行嘗試。github
下載安裝命令 ## CPU版本安裝命令 pip install -f https://paddlepaddle.org.cn/pip/oschina/cpu paddlepaddle ## GPU版本安裝命令 pip install -f https://paddlepaddle.org.cn/pip/oschina/gpu paddlepaddle-gpu
AI Studio項目地址:
https://aistudio.baidu.com/aistudio/projectdetail/1109418算法
2.前言
圖像和視頻一般包含着大量的視覺信息,且視覺信息自己具備直觀高效的描述能力,因此隨着信息技術的高速發展,圖像和視頻的應用逐漸遍及人類社會的各個領域。近些年來,在計算機圖像處理,計算機視覺和機器學習等領域中,來自工業界和學術界的許多學者和專家都持續關注着視頻圖像的超分辨率技術這個基礎熱點問題。編程
圖像超分辨率的英文名稱是 Image Super Resolution。它指的是從低分辨率圖像中恢復高分辨率圖像的過程。
這項技術在現實世界中有普遍的應用,最多見的應用場景就是圖片的壓縮傳輸:爲了在同等帶寬下得到更高的圖像質量,超分辨率算法適用於低帶寬時低質量圖像上的加強。除了提高圖像感知的品質,也有助於提高其餘計算機視覺任務,例如遙感領域、醫學成像領域。傳統的超分辨率方法有:基於預測的方法、基於邊緣的方法、基於統計的方法、基於修補的方法、以及稀疏表示方法等。網絡
近些年深度學習技術的快速發展,使得基於深度學習的超分辨率模型性能優異,大量深度學習方法被應用於解決超分辨率任務,早期的表明做有SRCNN和SRGAN,近期CVPR2020上也有很多相關的論文,例如:DRN和USRNet。總的來講,深度學習超分辨率算法之間各不相同,主要是因爲下面幾個主要的方向:不一樣類型的網絡結構、不一樣類型的損失函數、不一樣類型的學習原則和策略等。架構
SRCNN:
SRGAN:
DRN:
USRNet:
app
3.項目背景
3.1摘要:
近年來,基於深度神經網絡的單圖像超分辨率重建模型在重建精度和計算性能方面都有了很大的進展。可是這些算法都太複雜了,效率很低。在本文中,咱們提出了一種新的CNN架構,能夠有效地下降計算的複雜度。在公開數據集上的評估結果代表,該方法的性能明顯優於以前基於CNN的方法(圖像爲+0.15dB),而且比其餘基於CNN的方法快了一個數量級。框架
3.2 網絡結構:
與以往的工做不一樣,此項目在網絡的末端纔將分辨率從LR提升到HR,並從LR特徵圖中超分辨率地解析HR數據。這樣就不須要在更大的HR分辨率下執行大部分超分辨率SR操做。爲此,咱們提出了一種有效的亞像素卷積層來學習圖像和視頻超分辨率的上尺度運算。這樣作有兩個優勢:
每一個LR圖像被直接送入網絡,經過LR空間中的非線性卷積進行特徵提取。因爲輸入分辨率下降,咱們能夠有效地使用較小的過濾器大小來整合相同的信息,同時保持給定的上下文區域。分辨率和濾波器尺寸的減少,大大下降了計算量和內存的開銷,可是足以實時實現超分辨率。
對於一個有圖層的網絡,咱們學習了特徵映射的上尺度過濾器,而不是輸入圖像的一個上尺度過濾器。此外,不使用顯式插值濾波器意味着網絡隱式地學習SR所需的處理。所以,與在第一層向上擴展單個固定濾波器相比,網絡可以學習更好和更復雜的LR到HR映射,這使得模型重建精度的有額外提升。dom
3.3 基於Paddle的代碼:
這裏僅展現了部分關鍵代碼,詳細實現請參考AI Studio項目:
https://aistudio.baidu.com/aistudio/projectdetail/1109418
3.3.1 數據預處理
飛槳框架2.0 爲咱們封裝好了Dataset類,咱們定義數據讀取器類時只須要繼承自它並實現__getitem__返回讀取的內容和__len__方法返回數據的樣本數。這裏,咱們須要數據讀取器返回一張縮小後的圖片和一張沒有縮放的圖片,這兩張圖片都只有Ycbcr通道中的Y通道,由於大量的研究表代表人眼對亮度更敏感,因此咱們這裏只對亮度通道Y進行採樣。
class BSD_data(Dataset): def __init__(self, mode='train', image_path="data/data55873/images/" ): super(BSD_data, self).__init__() self.mode = mode.lower() if self.mode == 'train': self.image_path = os.path.join(image_path,'train') elif self.mode == 'val': self.image_path = os.path.join(image_path,'val') else: raise ValueError('mode must be "train" or "val"') # 原始圖像的縮放大小 self.crop_size = 300 # 縮放倍率 self.upscale_factor = 3 # 縮小後送入神經網絡的大小 self.input_size = self.crop_size // self.upscale_factor # numpy隨機數種子 self.seed=1337 # 圖片集合 self.temp_images = [] # 加載數據 self._parse_dataset() def transforms(self, img): """ 圖像預處理工具,用於將升維(100, 100) => (100, 100,1), 並進行維度轉換 H W C => C H W """ if len(img.shape) == 2: img = np.expand_dims(img, axis=2) return img.transpose((2, 0, 1)) def __getitem__(self, idx): """ 返回 縮小3倍後的圖片 和 原始圖片 """ # 加載原始圖像 img = self._load_img(self.temp_images[idx]) # 將原始圖像縮放到(3, 300, 300) img = img.resize( [self.crop_size,self.crop_size], Image.BICUBIC ) #轉換爲YCbCr圖像 ycbcr = img.convert("YCbCr") # 由於人眼對亮度敏感,因此只取Y通道 y, cb, cr = ycbcr.split() y = np.asarray(y,dtype='float32') y = y / 255.0 # 縮放後的圖像和前面採起同樣的操做 img_ = img.resize( [self.input_size,self.input_size], Image.BICUBIC ) ycbcr_ = img_.convert("YCbCr") y_, cb_, cr_ = ycbcr_.split() y_ = np.asarray(y_,dtype='float32') y_ = y_ / 255.0 # 升維並將HWC轉換爲CHW img_s = self.transforms(y) img_l = self.transforms(y_) # img_s 爲縮小3倍後的圖片(1, 100, 100) # img_l 是原始圖片(1, 300, 300) return img_s , img_l def __len__(self): """ 實現__len__方法,返回數據集總數目 """ return len(self.temp_images) def _sort_images(self, img_dir): """ 對文件夾內的圖像進行按照文件名排序 """ files = [] for item in os.listdir(img_dir): if item.split('.')[-1].lower() in ["jpg",'jpeg','png']: files.append(os.path.join(img_dir, item)) return sorted(files) def _parse_dataset(self): """ 處理數據集 """ self.temp_images = self._sort_images(self.image_path) random.Random(self.seed).shuffle(self.temp_images) def _load_img(self, path): """ 從磁盤讀取圖片 """ with open(path, 'rb') as f: img = Image.open(io.BytesIO(f.read())) img = img.convert('RGB') return img
3.3.2 定義網絡結構:
經過2.2節網絡結構圖,能夠很容易的看出來:圖片通過三層CNN採樣後獲得R的平方個特徵通道,再經過Sub-Pixel層還原成channel個通道(這裏是1通道)圖像。
from paddle.nn import Layer, Conv2D class Sub_Pixel_CNN(Layer): def __init__(self, upscale_factor=3, channels=1): super(Sub_Pixel_CNN, self).__init__() self.conv1 = Conv2D(channels,64,5,stride=1, padding=2) self.conv2 = Conv2D(64,32,3,stride=1, padding=1) self.conv3 = Conv2D(32,channels * (upscale_factor ** 2),3,stride=1, padding=1) def forward(self, x): x = self.conv1(x) x = self.conv2(x) x = self.conv3(x) x = paddle.fluid.layers.pixel_shuffle(x,3) return x
模型封裝及模型可視化
3.3.3 模型封裝
model = paddle.Model(Sub_Pixel_CNN())
3.3.4 模型可視化
model.summary((1, 1, 100, 100))
3.3.5 模型訓練準備
損失函數選用:
這裏選擇了經常使用的的均方差損失函數:MSELoss,其表達式以下圖所示:
有興趣的小夥伴能夠嘗試一下使用PSMR做爲損失函數,可能效果會更好。
model.prepare( paddle.optimizer.Adam( learning_rate=0.001, parameters=model.parameters() ), paddle.nn.MSELoss() )
3.3.6 模型訓練:
# 啓動模型訓練,指定訓練數據集、訓練輪數、批次大小、日誌格式 model.fit(train_dataset, epochs=1, batch_size=16, verbose=1)
3.3.7 結果可視化
從咱們的預測數據集中抽1個張圖片來看看預測的效果,其中lowers是縮放的圖片,prediction是lowers通過卷積超分辨率以後的結果。
psmr_low: 30.381882136539197 psmr_pre: 29.4920122281961
4 .思考與總結
這篇論文發表以前,CNN網絡在超分辨率重建上就取得了很是好的效果,可是網絡結構複雜,不適合在移動端部署。這篇論文使用了一個結構十分簡單的網絡結構,能夠在視頻上實現實時超分辨率,給輕量級的超分辨率算法提供了一個很好的思路。由於時間關係,本項目尚未實現對視頻的實時處理。別急,下一個項目必定會有的!
最後,感謝飛槳和AI Studio深度學習開源平臺提供的支持。本項目全程使用AI Studio完成開發,簡直是窮學生黨的福音啊,V100是真的香!
下載安裝命令 ## CPU版本安裝命令 pip install -f https://paddlepaddle.org.cn/pip/oschina/cpu paddlepaddle ## GPU版本安裝命令 pip install -f https://paddlepaddle.org.cn/pip/oschina/gpu paddlepaddle-gpu
郵箱:(歡迎騷擾,一塊兒探討學習)
juntao.lu@connect.qut.edu.au
近期目標:拿到墨大AI方向研究生的offer(好像有點難度)
愛好倒騰,喜歡航模、航拍,夢想有朝一日實現財富自由,帶着個人小飛機自駕拍遍全國。疫情緣由暫時還在國內,歡迎南京的小夥伴找我面基。
如在使用過程當中有問題,可加入飛槳官方QQ羣進行交流:1108045677。
若是您想詳細瞭解更多飛槳的相關內容,請參閱如下文檔。
飛槳PaddlePaddle項目地址:
GitHub:
https://github.com/PaddlePaddle/PaddlePaddle
Gitee:
https://Gitee.com/PaddlePaddle/PaddlePaddle
飛槳官網地址:
https://www.paddlepaddle.org.cn/
本文分享 CSDN - Ralph Lu。
若有侵權,請聯繫 support@oschina.cn 刪除。
本文參與「OSC源創計劃」,歡迎正在閱讀的你也加入,一塊兒分享。