[譯]使用 Python 實現接縫裁剪算法

接縫裁剪是一種新型的裁剪圖像的方式,它不會丟失圖像中的重要內容。這一般被稱之爲「內容感知」裁剪或圖像重定向。你能夠從這張照片中感覺一下這個算法:html

照片由 Unsplash 用戶 Pietro De Grandi 提供前端

變成下面這張:python

正如你所看到的,圖像中的很是重要內容 —— 船隻,都保留下來了。該算法去除了一些岩層和水(讓船看起來更靠近)。核心算法能夠參考 Shai Avidan 和 Ariel Shamir 的原始論文 Seam Carving for Content-Aware Image Resizing。在這篇文章中,我將展現如何在 Python 中基本實現該算法。android

概要

該算法的工做原理以下:ios

  1. 爲每一個像素分派一個能量值(energy)
  2. 找到能量最低的像素的 8 聯通區域
  3. 刪除該區域內全部的像素
  4. 重複 1-3,直到刪除所須要保留的行/列數

接下來,假設咱們只是嘗試裁剪圖像的寬度,即刪除列。對於刪除行來講也是相似的,至於緣由最後會說明。git

如下是 Python 代碼須要引入的包:github

import sys

import numpy as np
from imageio import imread, imwrite
from scipy.ndimage.filters import convolve

# tqdm 並非必需的,但它能夠向咱們展現一個漂亮的進度條
from tqdm import trange
複製代碼

能量圖

第一步是計算每一個像素的能量值,論文中定義了許多不一樣的可使用的能量函數。咱們來使用最基礎的那個:算法

這意味着什麼呢?I 表明圖像,因此這個式子告訴咱們,對於圖像中的每一個像素和每一個通道,咱們執行如下幾個步驟:後端

  • 找到 x 軸的偏導數
  • 找到 y 軸的偏導數
  • 將他們的絕對值求和

這就是該像素的能量值。那麼問題就來了,「你怎麼計算圖像的導數?」,維基百科上的 Image derivations(圖像導數)給咱們展現了許多不一樣的計算圖像導數的方法。咱們將使用 Sobel 濾波器。這是一個在圖像上的每一個通道上的計算的convolutional kernel(卷積核)。如下是圖像的兩個不一樣方向的過濾器:數組

直觀地說,咱們能夠認爲第一個濾波器是將每一個像素替換爲它上邊的值和下邊的值之差。第二個過濾器將每一個像素替換爲它右邊的值和左邊的值之差。這種濾波器捕捉到的是每一個像素相鄰所構成的 3x3 區域中像素的整體趨勢。事實上,這種方法與邊緣檢測算法也有關係。計算能量圖的方式很是簡單:

def calc_energy(img):
    filter_du = np.array([
        [1.0, 2.0, 1.0],
        [0.0, 0.0, 0.0],
        [-1.0, -2.0, -1.0],
    ])
    # 將一個 2D 的濾波器轉爲 3D 的濾波器,爲每一個通道設置相同的濾波器:R,G,B
    filter_du = np.stack([filter_du] * 3, axis=2)

    filter_dv = np.array([
        [1.0, 0.0, -1.0],
        [2.0, 0.0, -2.0],
        [1.0, 0.0, -1.0],
    ])
    # 將一個 2D 的濾波器轉爲 3D 的濾波器,爲每一個通道設置相同的濾波器:R,G,B
    filter_dv = np.stack([filter_dv] * 3, axis=2)

    img = img.astype('float32')
    convolved = np.absolute(convolve(img, filter_du)) + np.absolute(convolve(img, filter_dv))

    # 咱們將紅綠色藍三通道中的能量相加
    energy_map = convolved.sum(axis=2)

    return energy_map
複製代碼

可視化能量圖後,咱們能夠看到:

顯然,像天空和水的靜止部分這樣變化最小的區域,具備很是低的能量(暗的部分)。當咱們運行接縫裁剪算法的時候,被移除的線條通常都與圖像的這些部分緊密相關,同時試圖保留高能量部分(亮的部分)。

### 找到最小能量的接縫(seam)

咱們下一個目標就是找到一條從圖像頂部到圖像底部的能量最小的路徑。這條線必須是 8 聯通的:這意味着線中的每一個像素均可以他經過邊或叫角碰到線中的下一個像素。舉個例子,這就是下圖中的紅色線條:

因此咱們怎麼找到這條線呢?事實證實,這個問題能夠很好地使用動態規劃來解決!

讓咱們建立一個名爲 M 的 2D 數組 來存儲每一個像素的最小能量值。若是您不熟悉動態規劃,這簡單來講就是,從圖像頂部到該點的全部可能接縫(seam)中的最小能量即爲 M[i,j]。所以,M 的最後一行中就將包含從圖像頂部到底部的最小能量。咱們須要今後回溯以查找此接縫中存在的像素,因此咱們將保留這些值,存儲在名爲backtrack 的 2D 數組中。

def minimum_seam(img):
    r, c, _ = img.shape
    energy_map = calc_energy(img)

    M = energy_map.copy()
    backtrack = np.zeros_like(M, dtype=np.int)

    for i in range(1, r):
        for j in range(0, c):
            # 處理圖像的左邊緣,防止索引到 -1
            if j == 0:
                idx = np.argmin(M[i - 1, j:j + 2])
                backtrack[i, j] = idx + j
                min_energy = M[i - 1, idx + j]
            else:
                idx = np.argmin(M[i - 1, j - 1:j + 2])
                backtrack[i, j] = idx + j - 1
                min_energy = M[i - 1, idx + j - 1]

            M[i, j] += min_energy

    return M, backtrack
複製代碼

刪除最小能量的接縫中的像素

而後咱們就能夠刪除有着最低能量的接縫中的像素,返回新的圖片:

def carve_column(img):
    r, c, _ = img.shape

    M, backtrack = minimum_seam(img)

    # 建立一個(r,c)矩陣,全部值都爲 True
    # 咱們將刪除圖像中矩陣裏全部爲 False 的對應的像素
    mask = np.ones((r, c), dtype=np.bool)

    # 找到 M 最後一行中最小元素的那一列的索引
    j = np.argmin(M[-1])

    for i in reversed(range(r)):
        # 標記這個像素以後須要刪除
        mask[i, j] = False
        j = backtrack[i, j]

    # 由於圖像是三通道的,咱們將 mask 轉爲 3D 的
    mask = np.stack([mask] * 3, axis=2)

    # 刪除 mask 中全部爲 False 的位置所對應的像素,並將
    # 他們從新調整爲新圖像的尺寸
    img = img[mask].reshape((r, c - 1, 3))

    return img
複製代碼

對每列重複操做

全部的基礎工做都已作完了!如今,咱們只要一次次地運行 carve_column 函數,直到咱們刪除到了所需的列數。咱們再建立一個 crop_c 函數,圖像和縮放因子做爲輸入。若是圖像的尺寸爲(300,600),而且咱們想要將其減少到(150,600),scale_c 設置爲 0.5 便可。

def crop_c(img, scale_c):
    r, c, _ = img.shape
    new_c = int(scale_c * c)

    for i in trange(c - new_c): # 若是你不想用 tqdm,這裏將 trange 改成 range
        img = carve_column(img)

    return img
複製代碼

將它們合在一塊兒

咱們能夠添加一個 main 函數,讓代碼能夠經過命令行調用:

def main():
    scale = float(sys.argv[1])
    in_filename = sys.argv[2]
    out_filename = sys.argv[3]

    img = imread(in_filename)
    out = crop_c(img, scale)
    imwrite(out_filename, out)

if __name__ == '__main__':
    main()
複製代碼

而後運行這段代碼:

python carver.py 0.5 image.jpg cropped.jpg
複製代碼

cropped.jpg 如今應該顯示如下這樣的圖像:

![]https://user-gold-cdn.xitu.io/2018/7/12/1648d13cb3f0ab58?w=400&h=533&f=jpeg&s=57795)

行應該怎麼處理呢?

而後,咱們能夠開始研究怎麼修改咱們的循環來換個方向處理數據。或者...只需旋轉圖像就能夠運行 crop_c

def crop_r(img, scale_r):
    img = np.rot90(img, 1, (0, 1))
    img = crop_c(img, scale_r)
    img = np.rot90(img, 3, (0, 1))
    return img
複製代碼

將這段代碼添加到 main 函數中,如今咱們也能夠裁剪行!

def main():
    if len(sys.argv) != 5:
        print('usage: carver.py <r/c> <scale> <image_in> <image_out>', file=sys.stderr)
        sys.exit(1)

    which_axis = sys.argv[1]
    scale = float(sys.argv[2])
    in_filename = sys.argv[3]
    out_filename = sys.argv[4]

    img = imread(in_filename)

    if which_axis == 'r':
        out = crop_r(img, scale)
    elif which_axis == 'c':
        out = crop_c(img, scale)
    else:
        print('usage: carver.py <r/c> <scale> <image_in> <image_out>', file=sys.stderr)
        sys.exit(1)
    
    imwrite(out_filename, out)
複製代碼

運行代碼:

python carver.py r 0.5 image2.jpg cropped.jpg
複製代碼

而後咱們就能夠把這張圖:

Photo by Brent Cox on Unsplash

變成這樣:

總結

我但願你是愉快而又收穫地讀到這裏的。我很享受實現這篇論文的過程,並打算構建一個這個算法更快的版本。好比說,使用相同的計算過的圖像接縫去除多個接縫。在個人實驗中,這可使算法更快,每次迭代能夠幾乎線性地移除接縫,但質量明顯降低。另外一個優化是計算 GPU 上的能量圖,在這裏探討的

這是完整的程序:

#!/usr/bin/env python

""" Usage: python carver.py <r/c> <scale> <image_in> <image_out> Copyright 2018 Karthik Karanth, MIT License """

import sys

from tqdm import trange
import numpy as np
from imageio import imread, imwrite
from scipy.ndimage.filters import convolve

def calc_energy(img):
    filter_du = np.array([
        [1.0, 2.0, 1.0],
        [0.0, 0.0, 0.0],
        [-1.0, -2.0, -1.0],
    ])
    # 將一個 2D 的濾波器轉爲 3D 的濾波器,爲每一個通道設置相同的濾波器:R,G,B
    filter_du = np.stack([filter_du] * 3, axis=2)

    filter_dv = np.array([
        [1.0, 0.0, -1.0],
        [2.0, 0.0, -2.0],
        [1.0, 0.0, -1.0],
    ])
    # 將一個 2D 的濾波器轉爲 3D 的濾波器,爲每一個通道設置相同的濾波器:R,G,B
    filter_dv = np.stack([filter_dv] * 3, axis=2)

    img = img.astype('float32')
    convolved = np.absolute(convolve(img, filter_du)) + np.absolute(convolve(img, filter_dv))

    # 咱們將紅綠色藍三通道中的能量相加
    energy_map = convolved.sum(axis=2)

    return energy_map

def crop_c(img, scale_c):
    r, c, _ = img.shape
    new_c = int(scale_c * c)

    for i in trange(c - new_c):
        img = carve_column(img)

    return img

def crop_r(img, scale_r):
    img = np.rot90(img, 1, (0, 1))
    img = crop_c(img, scale_r)
    img = np.rot90(img, 3, (0, 1))
    return img

def carve_column(img):
    r, c, _ = img.shape

    M, backtrack = minimum_seam(img)
    mask = np.ones((r, c), dtype=np.bool)

    j = np.argmin(M[-1])
    for i in reversed(range(r)):
        mask[i, j] = False
        j = backtrack[i, j]

    mask = np.stack([mask] * 3, axis=2)
    img = img[mask].reshape((r, c - 1, 3))
    return img

def minimum_seam(img):
    r, c, _ = img.shape
    energy_map = calc_energy(img)

    M = energy_map.copy()
    backtrack = np.zeros_like(M, dtype=np.int)

    for i in range(1, r):
        for j in range(0, c):
            # 處理圖像的左邊緣,防止索引到 -1
            if j == 0:
                idx = np.argmin(M[i-1, j:j + 2])
                backtrack[i, j] = idx + j
                min_energy = M[i-1, idx + j]
            else:
                idx = np.argmin(M[i - 1, j - 1:j + 2])
                backtrack[i, j] = idx + j - 1
                min_energy = M[i - 1, idx + j - 1]

            M[i, j] += min_energy

    return M, backtrack

def main():
    if len(sys.argv) != 5:
        print('usage: carver.py <r/c> <scale> <image_in> <image_out>', file=sys.stderr)
        sys.exit(1)

    which_axis = sys.argv[1]
    scale = float(sys.argv[2])
    in_filename = sys.argv[3]
    out_filename = sys.argv[4]

    img = imread(in_filename)

    if which_axis == 'r':
        out = crop_r(img, scale)
    elif which_axis == 'c':
        out = crop_c(img, scale)
    else:
        print('usage: carver.py <r/c> <scale> <image_in> <image_out>', file=sys.stderr)
        sys.exit(1)
    
    imwrite(out_filename, out)

if __name__ == '__main__':
    main()
複製代碼

修改於(2018 年 5 月 5 日): 正如一個熱心的 reddit 用戶所說,經過使用 numba 來加速計算繁重的功能,能夠很容易的獲得幾十倍的性能提高。要想體驗 numba,只要在函數 carve_columnminimum_seam 以前加上 @numba.jit。就像下面這樣:

@numba.jit
def carve_column(img):

@numba.jit
def minimum_seam(img):
複製代碼

若是發現譯文存在錯誤或其餘須要改進的地方,歡迎到 掘金翻譯計劃 對譯文進行修改並 PR,也可得到相應獎勵積分。文章開頭的 本文永久連接 即爲本文在 GitHub 上的 MarkDown 連接。


掘金翻譯計劃 是一個翻譯優質互聯網技術文章的社區,文章來源爲 掘金 上的英文分享文章。內容覆蓋 AndroidiOS前端後端區塊鏈產品設計人工智能等領域,想要查看更多優質譯文請持續關注 掘金翻譯計劃官方微博知乎專欄

相關文章
相關標籤/搜索