qwe framework: CNN implementation

CNN implementation

Overview

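I have two implementations in qwe. The first follows the approach from Ng's course: nested loops extract each small patch, and then WX + b is computed for it. This is the simplest and most intuitive approach, but it is extremely slow; even a run on fewer than 10 images nearly grinds to a halt.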

The second approach uses img2col to convert the input into the corresponding matrix and then performs a single matrix multiplication.

Let's look at the first approach.

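import numpy as np

def forward(self, X):
    # Method of the convolution layer; zero_pad is defined elsewhere (not shown here).
    # X: (m, n_H_prev, n_W_prev, n_C_prev), self.W: (f, f, n_C_prev, n_C)
    m, n_H_prev, n_W_prev, n_C_prev = X.shape
    f, _, _, n_C = self.W.shape
    # Output height and width (these replace the redundant read of self.output_size)
    n_H = (n_H_prev - f + 2 * self.pad) // self.stride + 1
    n_W = (n_W_prev - f + 2 * self.pad) // self.stride + 1

    Z = np.zeros((m, n_H, n_W, n_C))
    X_pad = zero_pad(X, self.pad)
    for i in range(m):                  # each sample
        for h in range(n_H):            # each output row
            for w in range(n_W):        # each output column
                vert_start = h * self.stride
                vert_end = vert_start + f
                horiz_start = w * self.stride
                horiz_end = horiz_start + f
                # The f x f x n_C_prev patch under the kernel
                A_slice_prev = X_pad[i, vert_start:vert_end, horiz_start:horiz_end, :]
                for c in range(n_C):    # each kernel
                    Z[i, h, w, c] = conv_single_step(A_slice_prev, self.W[..., c], self.b[..., c])
    return Z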

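def conv_single_step(X, W, b):
    # Convolve one cropped patch with one kernel.
    # X.shape = W.shape = (f, f, prev_channel_size); b holds a single bias value.
    # The bias is added once, after the sum, not once per element.
    return np.sum(np.multiply(X, W)) + float(np.squeeze(b))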

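The loops over m, n_H, n_W and n_C pick out each cropped patch, so the computational cost is m * n_H * n_W * n_C times the cost of one f * f patch computation (across all input channels).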
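To make that cost concrete, here is a small sketch that counts the patch computations and multiply-adds of the loop version. The helper names and the example sizes (10 RGB images of 64x64, eight 3x3 kernels) are assumptions chosen for illustration, not values from qwe.

```python
def conv_output_size(n_prev, f, pad, stride):
    """Output size along one spatial dimension."""
    return (n_prev - f + 2 * pad) // stride + 1


def naive_cost(m, n_H_prev, n_W_prev, n_C_prev, n_C, f, pad, stride):
    """Return (calls to conv_single_step, total multiply-adds) for the loop version."""
    n_H = conv_output_size(n_H_prev, f, pad, stride)
    n_W = conv_output_size(n_W_prev, f, pad, stride)
    calls = m * n_H * n_W * n_C              # one Python-level call per output value
    multiply_adds = calls * f * f * n_C_prev  # each call touches a full patch
    return calls, multiply_adds


# Hypothetical example sizes, for illustration only
calls, multiply_adds = naive_cost(10, 64, 64, 3, 8, f=3, pad=1, stride=1)
print(calls, multiply_adds)
```

The arithmetic itself is modest; it is the hundreds of thousands of tiny Python-level calls that make the loop version slow.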

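The second approach first converts the input into one large matrix and then performs a single matrix multiplication. This replaces many small matrix operations with one large one, and the saving is substantial: the speed can differ by a factor of several tens.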

The principle behind img2col is simple; for details, see Caffe's im2col.

It loops over the patches, flattens each one into a one-dimensional row, and stacks the rows together.

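For a CNN, the height H of this matrix is the number of patches to compute, i.e. m (number of samples) * n_H (rows of the output image) * n_W (columns of the output image), and its width W is f (kernel height) * f (kernel width) * n_C_prev (number of input channels).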

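The parameter matrix W then has to be flattened in the same way: its height is f (kernel height) * f (kernel width) * n_C_prev (number of input channels), and its number of columns is the number of kernels (n_C in the code).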
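A tiny worked example may help with the shapes. The sketch below uses made-up values (one 3x3 single-channel image, two 2x2 kernels, no padding, stride 1) and builds the patch matrix with plain slicing rather than qwe's own img2col.

```python
import numpy as np

# One 3x3 image with a single channel, holding the values 0..8
X = np.arange(9, dtype=float).reshape(1, 3, 3, 1)
f, stride = 2, 1
n_H = n_W = (3 - f) // stride + 1  # 2 output rows and 2 output columns

# Flatten each f x f x n_C_prev patch into one row
rows = []
for h in range(n_H):
    for w in range(n_W):
        patch = X[0, h * stride:h * stride + f, w * stride:w * stride + f, :]
        rows.append(patch.reshape(-1))
patches = np.stack(rows)  # shape (m * n_H * n_W, f * f * n_C_prev) = (4, 4)

# Two hypothetical kernels: all ones and all twos
W = np.ones((f, f, 1, 2))
W[..., 1] = 2.0
W_col = W.reshape(f * f * 1, 2)  # shape (f * f * n_C_prev, n_C) = (4, 2)

Z = patches @ W_col  # shape (4, 2): one row per patch, one column per kernel
print(patches)
print(Z.reshape(1, n_H, n_W, 2)[0, :, :, 0])
```

Each row of `patches` is one window of the image read left to right, top to bottom, and each column of `Z` is the response of one kernel.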

As shown in the figure below:


import numpy as np

def img2col(X, pad, stride, f):
    # Flatten every f x f x n_C_prev patch of X into one row of the result.
    m, n_H_prev, n_W_prev, n_C_prev = X.shape
    n_H = (n_H_prev - f + 2 * pad) // stride + 1
    n_W = (n_W_prev - f + 2 * pad) // stride + 1
    Z = np.zeros((m * n_H * n_W, f * f * n_C_prev))
    X_pad = np.pad(X, ((0, 0), (pad, pad), (pad, pad), (0, 0)), 'constant', constant_values=0)
    row = -1

    for i in range(m):
        for h in range(n_H):
            for w in range(n_W):
                row += 1
                vert_start = h * stride
                horiz_start = w * stride
                # Column order is (row in patch, column in patch, channel)
                for col in range(f * f * n_C_prev):
                    t = col // n_C_prev
                    hh = t // f
                    ww = t % f
                    cc = col % n_C_prev
                    Z[row, col] = X_pad[i, vert_start + hh, horiz_start + ww, cc]
    return Z
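The innermost loops above still run in Python. As a possible pure-NumPy alternative (not the author's code), the sketch below builds the same matrix with `numpy.lib.stride_tricks.sliding_window_view`, which requires NumPy 1.20 or newer. The function name `img2col_vectorized` is made up for this example.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def img2col_vectorized(X, pad, stride, f):
    """Patch matrix with the same (row, column, channel) column order as the loop version."""
    m, n_H_prev, n_W_prev, n_C_prev = X.shape
    X_pad = np.pad(X, ((0, 0), (pad, pad), (pad, pad), (0, 0)), mode="constant")
    # All f x f windows over the two spatial axes:
    # shape (m, H_pad - f + 1, W_pad - f + 1, n_C_prev, f, f)
    windows = sliding_window_view(X_pad, (f, f), axis=(1, 2))
    # Keep only the windows that start on a multiple of the stride
    windows = windows[:, ::stride, ::stride]
    n_H, n_W = windows.shape[1], windows.shape[2]
    # Move the channel axis last so each row is ordered (hh, ww, cc)
    windows = windows.transpose(0, 1, 2, 4, 5, 3)
    return windows.reshape(m * n_H * n_W, f * f * n_C_prev)


# Example with made-up sizes
X = np.arange(2 * 5 * 5 * 3, dtype=float).reshape(2, 5, 5, 3)
cols = img2col_vectorized(X, pad=1, stride=2, f=3)
print(cols.shape)
```

The final `reshape` copies the data, so memory use is the same as for the loop version; only the Python-level looping is removed.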

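def speed_forward(model, X):
    W = model.W
    b = model.b
    stride = model.stride
    pad = model.pad
    # W is assumed to have shape (f, f, n_C_prev, n_C), as in forward(), so that
    # reshaping it matches the (hh, ww, cc) column order produced by img2col.
    f, _, n_C_prev, n_C = W.shape
    m, n_H_prev, n_W_prev, _ = X.shape

    n_H = (n_H_prev - f + 2 * pad) // stride + 1
    n_W = (n_W_prev - f + 2 * pad) // stride + 1

    XX = img2col(X, pad, stride, f)         # (m * n_H * n_W, f * f * n_C_prev)
    WW = W.reshape(f * f * n_C_prev, n_C)   # (f * f * n_C_prev, n_C)
    # Keep the flattened matrices on the model for later use
    model.XX = XX
    model.WW = WW

    # One matrix multiplication for all patches; the bias is broadcast per kernel
    Z = np.dot(XX, WW) + np.reshape(b, (1, n_C))
    return Z.reshape(m, n_H, n_W, n_C)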
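A quick way to gain confidence in the matrix version is to compare it with the loop version on random data. The sketch below is self-contained and uses its own simplified helpers (`conv_naive`, `conv_im2col`, with a bias of shape `(n_C,)`); these names and shapes are assumptions for illustration and are not qwe's API.

```python
import numpy as np


def output_size(n_prev, f, pad, stride):
    return (n_prev - f + 2 * pad) // stride + 1


def conv_naive(X, W, b, pad, stride):
    """Loop version: one multiply-and-sum per output value."""
    m, n_H_prev, n_W_prev, _ = X.shape
    f, _, _, n_C = W.shape
    n_H = output_size(n_H_prev, f, pad, stride)
    n_W = output_size(n_W_prev, f, pad, stride)
    X_pad = np.pad(X, ((0, 0), (pad, pad), (pad, pad), (0, 0)))
    Z = np.zeros((m, n_H, n_W, n_C))
    for i in range(m):
        for h in range(n_H):
            for w in range(n_W):
                patch = X_pad[i, h * stride:h * stride + f, w * stride:w * stride + f, :]
                for c in range(n_C):
                    Z[i, h, w, c] = np.sum(patch * W[..., c]) + b[c]
    return Z


def conv_im2col(X, W, b, pad, stride):
    """Matrix version: flatten all patches, then do one matrix multiplication."""
    m, n_H_prev, n_W_prev, n_C_prev = X.shape
    f, _, _, n_C = W.shape
    n_H = output_size(n_H_prev, f, pad, stride)
    n_W = output_size(n_W_prev, f, pad, stride)
    X_pad = np.pad(X, ((0, 0), (pad, pad), (pad, pad), (0, 0)))
    cols = np.stack([
        X_pad[i, h * stride:h * stride + f, w * stride:w * stride + f, :].reshape(-1)
        for i in range(m) for h in range(n_H) for w in range(n_W)
    ])
    Z = cols @ W.reshape(f * f * n_C_prev, n_C) + b
    return Z.reshape(m, n_H, n_W, n_C)


rng = np.random.default_rng(0)
X = rng.standard_normal((2, 6, 6, 3))
W = rng.standard_normal((3, 3, 3, 4))
b = rng.standard_normal(4)

Z_loop = conv_naive(X, W, b, pad=1, stride=2)
Z_mat = conv_im2col(X, W, b, pad=1, stride=2)
same = np.allclose(Z_loop, Z_mat)
print(Z_mat.shape, same)
```

If the two results disagree, the usual suspects are the axis order of W before the reshape and where the bias is added.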

A time-consuming operation like this is best written as a Cython extension; otherwise the speed is still not ideal. See the Cython extension code.

Backpropagation follows the same idea; for the concrete code, see GitHub.
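For readers who want the gist without opening the repository, here is a minimal sketch of how the backward pass can be written in the same flattened layout. It is not taken from qwe: the function names `conv_backward_cols` and `col2img` are hypothetical, and the formulas are the standard gradients of Z = XX · WW + b.

```python
import numpy as np


def conv_backward_cols(XX, WW, dZ):
    """Gradients in the flattened layout, given dZ of shape (..., n_C)."""
    n_C = WW.shape[1]
    dZ_col = dZ.reshape(-1, n_C)   # one row per patch
    dWW = XX.T @ dZ_col            # same shape as WW
    db = dZ_col.sum(axis=0)        # one bias gradient per kernel
    dXX = dZ_col @ WW.T            # same shape as XX
    return dWW, db, dXX


def col2img(dXX, X_shape, pad, stride, f):
    """Scatter-add patch gradients back onto the (padded) input, then crop the padding."""
    m, n_H_prev, n_W_prev, n_C_prev = X_shape
    n_H = (n_H_prev - f + 2 * pad) // stride + 1
    n_W = (n_W_prev - f + 2 * pad) // stride + 1
    dX_pad = np.zeros((m, n_H_prev + 2 * pad, n_W_prev + 2 * pad, n_C_prev))
    row = 0
    for i in range(m):
        for h in range(n_H):
            for w in range(n_W):
                vs, hs = h * stride, w * stride
                # Overlapping patches accumulate their gradients
                dX_pad[i, vs:vs + f, hs:hs + f, :] += dXX[row].reshape(f, f, n_C_prev)
                row += 1
    return dX_pad[:, pad:pad + n_H_prev, pad:pad + n_W_prev, :]
```

Here `XX` and `WW` are the matrices cached during the forward pass, and `dWW` can be reshaped back to the kernel shape `(f, f, n_C_prev, n_C)`.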
