利用CNN卷積神經網絡進行訓練時,進行完卷積運算,還須要接着進行Max pooling池化操做,目的是在儘可能不丟失圖像特徵前期下,對圖像進行downsampling。html
首先看下max pooling的具體操做:整個圖片被不重疊的分割成若干個一樣大小的小塊(pooling size)。每一個小塊內只取最大的數字,再捨棄其餘節點後,保持原有的平面結構得出 output。web
相應的,對於多個feature map,操做以下,本來64張224X224的圖像,通過Max Pooling後,變成了64張112X112的圖像,從而實現了downsampling的目的。網絡
爲何能夠這樣?這裏利用到一個特性:平移不變性(translation invariant),結論的公式證實還無從考證,不過從下面的實例能夠側面證實這點:ide
右上角爲3副橫折位置不同的圖像,分別同左上角的卷積核進行運算,而後再進行3X3大小池化操做之後,咱們發現最後都能獲得相同的識別結果。還有人更通俗理解卷積後再進行池化運算獲得相同的結果,就比如牛逼的球隊分到不一樣的組獲得得到相同的比賽結果同樣。函數
除了Max Pooling,還有一些其它的池化操做,例如:SUM pooling、AVE pooling、MOP pooling、CROW pooling和RMAC pooling等,這裏再也不進行介紹,見末尾參考文章連接。spa
下面利用tensorflow模塊的max_pool函數,實現Max pooling操做:code
# 導入tensorflow庫 import tensorflow as tf # 定義2個行爲4,列爲4,通道爲1的數據集 batches = 2 height = 4 width = 4 channes = 1 dataset = tf.Variable( [ [ [[1.0],[2.0],[5.0],[6.0]], [[3.0],[4.0],[7.0],[8.0]], [[9.0],[10.0],[13.0],[14.0]], [[11.0],[12.0],[15.0],[16.0]] ], [ [[17.0],[18.0],[21.0],[22.0]], [[19.0],[20.0],[23.0],[24.0]], [[25.0],[26.0],[29.0],[30.0]], [[27.0],[28.0],[31.0],[32.0]] ] ]) # 定義Max pooling操做運算,重點理解下ksize和strides兩個參數的含義: # ksize表示不一樣維度Max pooling的大小,因爲batches和channels兩個維度不須要進行Max pooling,因此爲1 # strides表示下個Max pooling位置的跳躍大小,同理,因爲batches和channels兩個維度不須要進行Max pooling,因此爲1 X = tf.placeholder(dtype="float",shape=[None,height,width,channes]) data_max_pool = tf.nn.max_pool(value=X,ksize=[1,2,2,1],strides=[1,2,2,1],padding="VALID") # 開始進行tensorflow計算圖運算 with tf.Session() as sess: sess.run(tf.global_variables_initializer()) input = sess.run(dataset) output = sess.run(data_max_pool,feed_dict = {X:input}) print(input) print("===============================") print(output) # 輸入: # [ # [ # [[ 1.] [ 2.] [ 5.] [ 6.]] # [[ 3.] [ 4.] [ 7.] [ 8.]] # [[ 9.] [10.] [13.] [14.]] # [[11.] [12.] [15.] [16.]] # ] # # [ # [[17.] [18.] [21.] [22.]] # [[19.] [20.] [23.] [24.]] # [[25.] [26.] [29.] [30.]] # [[27.] [28.] [31.] [32.]] # ] # ] # # =============================== # 輸出: # [ # [ # [[ 4.] [ 8.]] # [[12.] [16.]] # ] # [ # [[20.] [24.]] # [[28.] [32.]] # ] # ]
參考文章:CNN中的maxpool究竟是什麼原理?htm