圖像匹配 | NCC 歸一化互相關損失 | 代碼 + 講解

  • 文章轉載自:微信公衆號「機器學習煉丹術」
  • 做者:煉丹兄(已受權)
  • 做者聯繫方式:微信cyx645016617(歡迎交流共同進步)

本次的內容主要講解NCCNormalized cross-correlation 歸一化互相關。python

兩張圖片是不是同一個內容,如今深度學習的方案天然是用神經網絡,比方說:孿生網絡的架構作人面識別等等;微信

在傳統的非參數方法中,常見的也有相關係數等。我在上一片文章voxelmorph的模型的學習中發現,在醫學圖像配準任務(不限於醫學),衡量兩個圖片類似的度量有一種叫作NCC的網絡

而這個NCC就是Normalized Cross-Correlation歸一化互相關係數。架構

1 互相關係數

若是你知道互相關係數,那麼你就能很好的理解歸一化互相關係數。機器學習

相關係數的計算公式以下:ide

\[r(X,Y) = \frac{Cov(X,Y)}{\sqrt{Var(X)Var(Y)}} \]

公式中的X,Y分別表示兩個圖片,\(Cov(X,Y)\)表示兩個圖片的協方差,\(Var(X)\)表示X自身的方差;函數

2 歸一化互相關NCC

若是把一張圖片,按照必定的像素,比方說9x9的一個框滑動,那麼就能夠把圖片分紅不少的9x9的小圖片,那麼NCC就是X,Y兩張大圖片中的對應的小圖片的互相關係數的平均值。學習

這裏看一下協方差的計算方式:
\(Cov(X,Y) = E[(X-E(X))(Y-E(Y))]\)spa

方差的計算爲:
\(Var(X) = E[(X-E(X))^2]\)code

其實NCC不難理解,可是如何用代碼計算呢?固然咱們能夠一行一行遍歷求解,可是這樣時間複雜度太高,因此咱們作好仍是選擇矩陣運算。

3 NCC損失函數的代碼

class NCC:
    """
    Local (over window) normalized cross correlation loss.
    """

    def __init__(self, win=None):
        self.win = win

    def loss(self, y_true, y_pred):

        I = y_true
        J = y_pred

        # get dimension of volume
        # assumes I, J are sized [batch_size, *vol_shape, nb_feats]
        ndims = len(list(I.size())) - 2
        assert ndims in [1, 2, 3], "volumes should be 1 to 3 dimensions. found: %d" % ndims

        # set window size
        win = [9] * ndims if self.win is None else self.win

        # compute filters
        sum_filt = torch.ones([1, 1, *win]).to("cuda")

        pad_no = math.floor(win[0]/2)

        if ndims == 1:
            stride = (1)
            padding = (pad_no)
        elif ndims == 2:
            stride = (1,1)
            padding = (pad_no, pad_no)
        else:
            stride = (1,1,1)
            padding = (pad_no, pad_no, pad_no)

        # get convolution function
        conv_fn = getattr(F, 'conv%dd' % ndims)

        # compute CC squares
        I2 = I * I
        J2 = J * J
        IJ = I * J

        I_sum = conv_fn(I, sum_filt, stride=stride, padding=padding)
        J_sum = conv_fn(J, sum_filt, stride=stride, padding=padding)
        I2_sum = conv_fn(I2, sum_filt, stride=stride, padding=padding)
        J2_sum = conv_fn(J2, sum_filt, stride=stride, padding=padding)
        IJ_sum = conv_fn(IJ, sum_filt, stride=stride, padding=padding)

        win_size = np.prod(win)
        u_I = I_sum / win_size
        u_J = J_sum / win_size

        cross = IJ_sum - u_J * I_sum - u_I * J_sum + u_I * u_J * win_size
        I_var = I2_sum - 2 * u_I * I_sum + u_I * u_I * win_size
        J_var = J2_sum - 2 * u_J * J_sum + u_J * u_J * win_size

        cc = cross * cross / (I_var * J_var + 1e-5)

        return -torch.mean(cc)

這段代碼其實不是很好看懂,我思考了好久才明白。其中的關鍵就在於如何理解:

# compute CC squares
        I2 = I * I
        J2 = J * J
        IJ = I * J

        I_sum = conv_fn(I, sum_filt, stride=stride, padding=padding)
        J_sum = conv_fn(J, sum_filt, stride=stride, padding=padding)
        I2_sum = conv_fn(I2, sum_filt, stride=stride, padding=padding)
        J2_sum = conv_fn(J2, sum_filt, stride=stride, padding=padding)
        IJ_sum = conv_fn(IJ, sum_filt, stride=stride, padding=padding)

        win_size = np.prod(win)
        u_I = I_sum / win_size
        u_J = J_sum / win_size

        cross = IJ_sum - u_J * I_sum - u_I * J_sum + u_I * u_J * win_size
        I_var = I2_sum - 2 * u_I * I_sum + u_I * u_I * win_size
        J_var = J2_sum - 2 * u_J * J_sum + u_J * u_J * win_size

咱們能夠纔到,這個cross應該是協方差部分,I_var和J_var是方差部分。

咱們對協方差公式進行推導:\(Cov(X,Y) = E[(X-E(X))(Y-E(Y))]\)
\(=E[XY-XE(Y)-YE(X)+E(X)E(Y)]\)

這樣恰好和cross對應上。

  • IJ_sum = E[XY]
  • u_J * I_sum = E[XE(Y)]
  • u_I * u_J * win_size = E[E(X)E(Y)]

對方差公式進行推導:\(Var(X) = E[(X-E(X))^2]=E[X^2-2XE(X)+E(X)^2]\)

  • J2_sum = E(X^2)
  • 2 * u_J * J_sum = E[2XE(X)]
  • u_J * u_J * win_size = E[E(X)^2]
相關文章
相關標籤/搜索