基於飛槳復現 GLCLC 模型,對殘次圖片實現圖像補全

【飛槳開發者說】侯繼旭,海南師範大學自動化本科在讀,PPDE飛槳開發者技術專家,研究方向爲目標檢測、對抗生成網絡等php

本次復現使用的數據集是CelebA人臉數據集,這是一個大規模的人臉屬性數據集,是由香港中文大學湯曉鷗教授實驗室公佈的大型人臉識別數據集,擁有超過20萬張名人圖像,已下載放置在此項目的數據集中,人臉屬性有40多種。

本文項目代碼github地址:

https://github.com/Eric-Hjx/P...

模型摘要git

在此篇論文中,做者們提出了Globally and Locally Consistent Image Completion方法,可使得圖像的缺失部分自動補全,局部和整圖保持一致。做者經過全卷積網絡,能夠補全圖片中任何形狀的缺失,爲了保持補全後的圖像與原圖的一致性,做者使用全局(整張圖片)和局部(缺失補所有分)兩種鑑別器來訓練。全局鑑別器查看整個圖像以評估它是否做爲總體是連貫的,而局部鑑別器僅查看以完成區域爲中心的小區域來確保所生成的補丁的局部一致性。github

接着對圖像補全網絡訓練以欺騙兩個內容鑑別器網絡,這要求它生成整體以及細節上與真實沒法區分的圖像。咱們證實了咱們的方法能夠用來完成各類各樣的場景。此外,與PatchMatch等基於補丁的方法相比,咱們的方法能夠生成圖像中未出現的碎片,這使咱們可以天然地完成具備熟悉且高度特定的結構(如面部)的對象的圖像。算法

該論文的方法,徹底以卷積網絡做爲基礎,使用了GAN網絡的思路,設計了兩部分(三個網絡),一部分用於生成圖像,即補全網絡,一部分用於鑑別生成圖像是否與原圖像一致,即全局鑑別器和局部鑑別器。網絡結構圖以下所示:網絡

網絡介紹:

  1. 補全網絡:補全網絡是徹底卷積的,目的是用來修復圖像。
  2. 全局鑑別器:以完整的圖像做爲輸入,識別場景的全局一致性。
  3. 局部鑑別器:只關注完成區域周圍的一個小區域,以判斷更詳細的外觀質量。

基於飛槳實現GLCLC算法app

下面咱們基於飛槳開源深度學習框架動手實現 GLCLC 算法,介紹神經網絡代碼實現內容,主要使用了卷積、反捲積、空洞卷積、正則、激活函數等方法搭建了補全網絡及鑑別網絡。框架

1. 補全網絡結構

補全網絡部分,做者採用12層卷積網絡對輸入圖像進行encoding,獲得一張原圖16分之一大小的網格。而後再對該網格採用4層卷積網絡進行decoding。爲了保證生成區域儘可能不模糊,文中下降分辨率的操做是使用strided convolution 的方式進行的,並且只用了兩次,將圖片的size 變爲原來的四分之一。同時在中間層還使用了空洞卷積來增大感覺野,在儘可能獲取更大範圍內的圖像信息的同時不損失額外的信息,從而獲得復原圖像。下表爲補全網絡各層參數分佈狀況。dom

輸入爲RGB圖像與二進制掩碼(須要填充的區域以1填充)的組合圖像;輸出爲RGB圖像。ide

1.  # 搭建補全網絡
    
2.  def generator(x):
    
3.      # conv1
    
4.      conv1 = fluid.layers.conv2d(input=x,num_filters=64,filter_size=5,dilation=1,stride=1,padding='SAME',name='generator_conv1',data_format='NHWC')
    
5.      conv1 = fluid.layers.batch_norm(conv1, momentum=0.99, epsilon=0.001)
    
6.      conv1 = fluid.layers.relu(conv1, name=None)
    
7.      # conv2
    
8.      conv2 = fluid.layers.conv2d(input=conv1,num_filters=128,filter_size=3,dilation=1,stride=2,padding='SAME',name='generator_conv2',data_format='NHWC')
    
9.      conv2 = fluid.layers.batch_norm(conv2, momentum=0.99, epsilon=0.001)
    
10.      conv2 = fluid.layers.relu(conv2, name=None)
    
11.      # conv3
    
12.      conv3 = fluid.layers.conv2d(input=conv2,num_filters=128,filter_size=3,dilation=1,stride=1,padding='SAME',name='generator_conv3',data_format='NHWC')
    
13.      conv3 = fluid.layers.batch_norm(conv3, momentum=0.99, epsilon=0.001)
    
14.      conv3 = fluid.layers.relu(conv3, name=None)
    
15.      # conv4
    
16.      conv4 = fluid.layers.conv2d(input=conv3,num_filters=256,filter_size=3,dilation=1,stride=2,padding='SAME',name='generator_conv4',data_format='NHWC')
    
17.      conv4 = fluid.layers.batch_norm(conv4, momentum=0.99, epsilon=0.001)
    
18.      conv4 = fluid.layers.relu(conv4, name=None)
    
19.      # conv5
    
20.      conv5 = fluid.layers.conv2d(input=conv4,num_filters=256,filter_size=3,dilation=1,stride=1,padding='SAME',name='generator_conv5',data_format='NHWC')
    
21.      conv5 = fluid.layers.batch_norm(conv5, momentum=0.99, epsilon=0.001)
    
22.      conv5 = fluid.layers.relu(conv5, name=None)
    
23.      # conv6
    
24.      conv6 = fluid.layers.conv2d(input=conv5,num_filters=256,filter_size=3,dilation=1,stride=1,padding='SAME',name='generator_conv6',data_format='NHWC')
    
25.      conv6 = fluid.layers.batch_norm(conv6, momentum=0.99, epsilon=0.001)
    
26.      conv6 = fluid.layers.relu(conv6, name=None)
    
27.      # 空洞卷積
    
28.      # dilated1
    
29.      dilated1 = fluid.layers.conv2d(input=conv6,num_filters=256,filter_size=3,dilation=2,padding='SAME',name='generator_dilated1',data_format='NHWC')
    
30.      dilated1 = fluid.layers.batch_norm(dilated1, momentum=0.99, epsilon=0.001)
    
31.      dilated1 = fluid.layers.relu(dilated1, name=None)
    
32.      # dilated2
    
33.      dilated2 = fluid.layers.conv2d(input=dilated1,num_filters=256,filter_size=3,dilation=4,padding='SAME',name='generator_dilated2',data_format='NHWC') #stride=1
    
34.      dilated2 = fluid.layers.batch_norm(dilated2, momentum=0.99, epsilon=0.001)
    
35.      dilated2 = fluid.layers.relu(dilated2, name=None)
    
36.      # dilated3
    
37.      dilated3 = fluid.layers.conv2d(input=dilated2,num_filters=256,filter_size=3,dilation=8,padding='SAME',name='generator_dilated3',data_format='NHWC')
    
38.      dilated3 = fluid.layers.batch_norm(dilated3, momentum=0.99, epsilon=0.001)
    
39.      dilated3 = fluid.layers.relu(dilated3, name=None)
    
40.      # dilated4
    
41.      dilated4 = fluid.layers.conv2d(input=dilated3,num_filters=256,filter_size=3,dilation=16,padding='SAME',name='generator_dilated4',data_format='NHWC')
    
42.      dilated4 = fluid.layers.batch_norm(dilated4, momentum=0.99, epsilon=0.001)
    
43.      dilated4 = fluid.layers.relu(dilated4, name=None)
    
44.      # conv7
    
45.      conv7 = fluid.layers.conv2d(input=dilated4,num_filters=256,filter_size=3,dilation=1,name='generator_conv7',data_format='NHWC')
    
46.      conv7 = fluid.layers.batch_norm(conv7, momentum=0.99, epsilon=0.001)
    
47.      conv7 = fluid.layers.relu(conv7, name=None)
    
48.      # conv8
    
49.      conv8 = fluid.layers.conv2d(input=conv7,num_filters=256,filter_size=3,dilation=1,stride=1,padding='SAME',name='generator_conv8',data_format='NHWC')
    
50.      conv8 = fluid.layers.batch_norm(conv8, momentum=0.99, epsilon=0.001)
    
51.      conv8 = fluid.layers.relu(conv8, name=None)
    
52.      # deconv1
    
53.      deconv1 = fluid.layers.conv2d_transpose(input=conv8, num_filters=128, output_size=[64,64],stride = 2,name='generator_deconv1',data_format='NHWC')
    
54.      deconv1 = fluid.layers.batch_norm(deconv1, momentum=0.99, epsilon=0.001)
    
55.      deconv1 = fluid.layers.relu(deconv1, name=None)
    
56.      # conv9
    
57.      conv9 = fluid.layers.conv2d(input=deconv1,num_filters=128,filter_size=3,dilation=1,stride=1,padding='SAME',name='generator_conv9',data_format='NHWC')
    
58.      conv9 = fluid.layers.batch_norm(conv9, momentum=0.99, epsilon=0.001)
    
59.      conv9 = fluid.layers.relu(conv9, name=None)
    
60.      # deconv2
    
61.      deconv2 = fluid.layers.conv2d_transpose(input=conv9, num_filters=64, output_size=[128,128],stride = 2,name='generator_deconv2',data_format='NHWC')
    
62.      deconv2 = fluid.layers.batch_norm(deconv2, momentum=0.99, epsilon=0.001)
    
63.      deconv2 = fluid.layers.relu(deconv2, name=None)
    
64.      # conv10
    
65.      conv10 = fluid.layers.conv2d(input=deconv2,num_filters=32,filter_size=3,dilation=1,stride=1,padding='SAME',name='generator_conv10',data_format='NHWC')
    
66.      conv10 = fluid.layers.batch_norm(conv10, momentum=0.99, epsilon=0.001)
    
67.      conv10 = fluid.layers.relu(conv10, name=None)
    
68.      # conv11
    
69.      x = fluid.layers.conv2d(input=conv10,num_filters=3,filter_size=3,dilation=1,stride=1,padding='SAME',name='generator_conv11',data_format='NHWC')
    
70.      x = fluid.layers.tanh(x)
    
71.      return x

2. 內容鑑別器

內容鑑別器分爲了兩個部分,一個全局鑑別器(Global Discriminator)以及一個局部鑑別器(Local Discriminator)。全局鑑別器是將一張完整的圖像做爲輸入數據,對圖像的全局一致性作出判斷;局部鑑別器僅在以填充區域爲中心的原圖像四分之一大小區域上觀測,對此部分圖像的一致性作出判斷。經過採用上述兩個不一樣的鑑別器,可使得最終的網絡,不但能夠對圖像全局一致性作判斷,而且可以經過局部鑑別方法,優化生成圖的細節,最終能產生更好的圖片填充效果。函數

在原文中,做者設定的全局鑑別網絡輸入是256X256X3的圖片,局部網絡輸入是128X128X3的圖片。原始論文中,全局網絡和局部網絡都會經過使用5X5的卷積層、2X2的stride下降圖像分辨率,經過全鏈接,分別獲得一個1024維的向量。而後,做者將全局和局部兩個鑑別器的輸出鏈接成一個2048維向量,再經過一個全鏈接,而後用sigmoid函數對總體的圖像的一致性進行打分判別。但在本次實驗,爲了能下降訓練難度,設定全局鑑別網絡輸入是128X128X3的圖片,局部網絡輸入是64X64X3的圖片。

1.  # 搭建內容鑑別器
    
2.  def discriminator(global_x, local_x):
    
3.      def global_discriminator(x):
    
4.          # conv1
    
5.          conv1 = fluid.layers.conv2d(input=x,num_filters=64,filter_size=5,dilation=1,stride=2,padding='SAME',name='discriminator_global_conv1',data_format='NHWC')
    
6.          conv1 = fluid.layers.batch_norm(conv1, momentum=0.99, epsilon=0.001)
    
7.          conv1 = fluid.layers.relu(conv1, name=None)
    
8.          # conv2
    
9.          conv2 = fluid.layers.conv2d(input=conv1,num_filters=128,filter_size=5,dilation=1,stride=2,padding='SAME',name='discriminator_global_conv2',data_format='NHWC')
    
10.          conv2 = fluid.layers.batch_norm(conv2, momentum=0.99, epsilon=0.001)
    
11.          conv2 = fluid.layers.relu(conv2, name=None)
    
12.          # conv3
    
13.          conv3 = fluid.layers.conv2d(input=conv2,num_filters=256,filter_size=5,dilation=1,stride=2,padding='SAME',name='discriminator_global_conv3',data_format='NHWC')
    
14.          conv3 = fluid.layers.batch_norm(conv3, momentum=0.99, epsilon=0.001)
    
15.          conv3 = fluid.layers.relu(conv3, name=None)
    
16.          # conv4
    
17.          conv4 = fluid.layers.conv2d(input=conv3,num_filters=512,filter_size=5,dilation=1,stride=2,padding='SAME',name='discriminator_global_conv4',data_format='NHWC')
    
18.          conv4 = fluid.layers.batch_norm(conv4, momentum=0.99, epsilon=0.001)
    
19.          conv4 = fluid.layers.relu(conv4, name=None)
    
20.          # conv5
    
21.          conv5 = fluid.layers.conv2d(input=conv4,num_filters=512,filter_size=5,dilation=1,stride=2,padding='SAME',name='discriminator_global_conv5',data_format='NHWC')
    
22.          conv5 = fluid.layers.batch_norm(conv5, momentum=0.99, epsilon=0.001)
    
23.          conv5 = fluid.layers.relu(conv5, name=None)
    
24.          # conv6
    
25.          conv6 = fluid.layers.conv2d(input=conv5,num_filters=512,filter_size=5,dilation=1,stride=2,padding='SAME',name='discriminator_global_conv6',data_format='NHWC')
    
26.          conv6 = fluid.layers.batch_norm(conv6, momentum=0.99, epsilon=0.001)
    
27.          conv6 = fluid.layers.relu(conv6, name=None)
    
28.          # fc
    
29.          x = fluid.layers.fc(input=conv6, size=1024,name='discriminator_global_fc1')
    
30.          return x
    

32.      def local_discriminator(x):
    
33.          # conv1
    
34.          conv1 = fluid.layers.conv2d(input=x,num_filters=64,filter_size=5,dilation=1,stride=2,padding='SAME',name='discriminator_lobal_conv1',data_format='NHWC')
    
35.          conv1 = fluid.layers.batch_norm(conv1, momentum=0.99, epsilon=0.001)
    
36.          conv1 = fluid.layers.relu(conv1, name=None)
    
37.          # conv2
    
38.          conv2 = fluid.layers.conv2d(input=conv1,num_filters=128,filter_size=5,dilation=1,stride=2,padding='SAME',name='discriminator_lobal_conv2',data_format='NHWC')
    
39.          conv2 = fluid.layers.batch_norm(conv2, momentum=0.99, epsilon=0.001)
    
40.          conv2 = fluid.layers.relu(conv2, name=None)
    
41.          # conv3
    
42.          conv3 = fluid.layers.conv2d(input=conv2,num_filters=256,filter_size=5,dilation=1,stride=2,padding='SAME',name='discriminator_lobal_conv3',data_format='NHWC')
    
43.          conv3 = fluid.layers.batch_norm(conv3, momentum=0.99, epsilon=0.001)
    
44.          conv3 = fluid.layers.relu(conv3, name=None)
    
45.          # conv4
    
46.          conv4 = fluid.layers.conv2d(input=conv3,num_filters=512,filter_size=5,dilation=1,stride=2,padding='SAME',name='discriminator_lobal_conv4',data_format='NHWC')
    
47.          conv4 = fluid.layers.batch_norm(conv4, momentum=0.99, epsilon=0.001)
    
48.          conv4 = fluid.layers.relu(conv4, name=None)
    
49.          # conv5
    
50.          conv5 = fluid.layers.conv2d(input=conv4,num_filters=512,filter_size=5,dilation=1,stride=2,padding='SAME',name='discriminator_lobal_conv5',data_format='NHWC')
    
51.          conv5 = fluid.layers.batch_norm(conv5, momentum=0.99, epsilon=0.001)
    
52.          conv5 = fluid.layers.relu(conv5, name=None)
    
53.          # fc
    
54.          x = fluid.layers.fc(input=conv5, size=1024,name='discriminator_lobal_fc1')
    
55.          return x
    

57.      global_output = global_discriminator(global_x)
    
58.      local_output = local_discriminator(local_x)
    
59.      print('global_output',global_output.shape)
    
60.      print('local_output',local_output.shape)
    
61.      output = fluid.layers.concat([global_output, local_output], axis=1)
    
62.      output = fluid.layers.fc(output, size=1,name='discriminator_concatenation_fc1')
    

64.      return output

3. 損失函數

生成網絡使用weighted Mean Squared Error (MSE)做爲損失函數,計算原圖與生成圖像像素之間的差別,表達式以下所示:

鑑別器網絡使用GAN損失函數,其目標是最大化生成圖像和原始圖像的類似機率,表達式以下所示:

最後結合二者損失,造成下式:

網絡訓練

原文做者使用4個K80 GPU,使用的輸入圖像大小是256*256,訓練了2個月才訓練完成。

本項目爲了縮短訓練時間,僅採用了此論文核心思想、網絡結構、優化目標等,並對訓練方式及部分細節作了簡化。使用的輸入圖像大小:128*128,訓練方式設定爲:先訓練生成器再將生成器和判別器一塊兒訓練。

1.  # 生成器優先迭代次數
    
2.  NUM_TRAIN_TIMES_OF_DG = 100
    
3.  # 總迭代輪次
    
4.  epoch = 200
    

6.  step_num = int(len(x_train) / BATCH_SIZE)
    

8.  np.random.shuffle(x_train)
    

10.  for pass_id in range(epoch):
    
11.      # 訓練生成器
    
12.      if pass_id <= NUM_TRAIN_TIMES_OF_DG:
    
13.          g_loss_value = 0
    
14.          for i in tqdm.tqdm(range(step_num)):
    
15.              x_batch = x_train[i * BATCH_SIZE:(i + 1) * BATCH_SIZE]
    
16.              points_batch, mask_batch = get_points()
    
17.              # print(x_batch.shape)
    
18.              # print(mask_batch.shape)
    
19.              dg_loss_n = exe.run(dg_program,
    
20.                                   feed={'x': x_batch, 
    
21.                                          'mask':mask_batch,},
    
22.                                   fetch_list=[dg_loss])[0]
    
23.              g_loss_value += dg_loss_n
    
24.          print('Pass_id:{}, Completion loss: {}'.format(pass_id, g_loss_value))
    

26.          np.random.shuffle(x_test)
    
27.          x_batch = x_test[:BATCH_SIZE]
    

29.          completion_n = exe.run(dg_program, 
    
30.                          feed={'x': x_batch, 
    
31.                                  'mask': mask_batch,},
    
32.                          fetch_list=[completion])[0][0]
    
33.          # 修復圖片
    
34.          sample = np.array((completion_n + 1) * 127.5, dtype=np.uint8)
    
35.          # 原圖
    
36.          x_im = np.array((x_batch[0] + 1) * 127.5, dtype=np.uint8)
    
37.          # 挖空洞輸入圖
    
38.          input_im_data = x_im * (1 - mask_batch[0])
    
39.          input_im = np.array(input_im_data + np.ones_like(x_im) * mask_batch[0] * 255, dtype=np.uint8)
    
40.          output_im = np.concatenate((x_im,input_im,sample),axis=1)
    
41.          #print(output_im.shape)
    
42.          cv2.imwrite('./output/pass_id:{}.jpg'.format(pass_id), cv2.cvtColor(output_im, cv2.COLOR_RGB2BGR))
    
43.          # 保存模型
    
44.          save_pretrain_model_path = 'models/'
    
45.          # 建立保持模型文件目錄
    
46.          #os.makedirs(save_pretrain_model_path)
    
47.          fluid.io.save_params(executor=exe, dirname=save_pretrain_model_path, main_program=dg_program)
    

49.      # 生成器判斷器一塊兒訓練
    
50.      else:
    
51.          g_loss_value = 0
    
52.          d_loss_value = 0
    
53.          for i in tqdm.tqdm(range(step_num)):
    
54.              x_batch = x_train[i * BATCH_SIZE:(i + 1) * BATCH_SIZE]
    
55.              points_batch, mask_batch = get_points()
    
56.              dg_loss_n = exe.run(dg_program,
    
57.                                   feed={'x': x_batch, 
    
58.                                          'mask':mask_batch,},
    
59.                                   fetch_list=[dg_loss])[0]
    
60.              g_loss_value += dg_loss_n
    

62.              completion_n = exe.run(dg_program, 
    
63.                                  feed={'x': x_batch, 
    
64.                                          'mask': mask_batch,},
    
65.                                  fetch_list=[completion])[0]
    
66.              local_x_batch = []
    
67.              local_completion_batch = []
    
68.              for i in range(BATCH_SIZE):
    
69.                  x1, y1, x2, y2 = points_batch[i]
    
70.                  local_x_batch.append(x_batch[i][y1:y2, x1:x2, :])
    
71.                  local_completion_batch.append(completion_n[i][y1:y2, x1:x2, :])
    
72.              local_x_batch = np.array(local_x_batch)
    
73.              local_completion_batch = np.array(local_completion_batch)
    
74.              d_loss_n  = exe.run(d_program,
    
75.                                  feed={'x': x_batch, 'mask': mask_batch, 'local_x': local_x_batch, 'global_completion': completion_n, 'local_completion': local_completion_batch},
    
76.                                  fetch_list=[d_loss])[0]
    
77.              d_loss_value += d_loss_n
    
78.          print('Pass_id:{}, Completion loss: {}'.format(pass_id, g_loss_value))
    
79.          print('Pass_id:{}, Discriminator loss: {}'.format(pass_id, d_loss_value))
    

81.          np.random.shuffle(x_test)
    
82.          x_batch = x_test[:BATCH_SIZE]
    
83.          completion_n = exe.run(dg_program, 
    
84.                          feed={'x': x_batch, 
    
85.                                  'mask': mask_batch,},
    
86.                          fetch_list=[completion])[0][0]
    
87.          # 修復圖片
    
88.          sample = np.array((completion_n + 1) * 127.5, dtype=np.uint8)
    
89.          # 原圖
    
90.          x_im = np.array((x_batch[0] + 1) * 127.5, dtype=np.uint8)
    
91.          # 挖空洞輸入圖
    
92.          input_im_data = x_im * (1 - mask_batch[0])
    
93.          input_im = np.array(input_im_data + np.ones_like(x_im) * mask_batch[0] * 255, dtype=np.uint8)
    
94.          output_im = np.concatenate((x_im,input_im,sample),axis=1)
    
95.          #print(output_im.shape)
    
96.          cv2.imwrite('./output/pass_id:{}.jpg'.format(pass_id), cv2.cvtColor(output_im, cv2.COLOR_RGB2BGR))
    
97.          # 保存模型
    
98.          save_pretrain_model_path = 'models/'
    
99.          # 建立保持模型文件目錄
    
100.          #os.makedirs(save_pretrain_model_path)
    
101.          fluid.io.save_params(executor=exe, dirname=save_pretrain_model_path, main_program = dg_program)

結果展現

項目總結

整個訓練過程,花了9小時左右,共訓練了100次補全網絡+45次補全網絡和鑑別網絡。

Image Completion Result 中的 Input 是挖洞後輸入補全網絡的圖像,在 Output 看到, Input 圖像上挖的洞已經被補上了,這說明如今的訓練結果已經能在必定程度上補全圖像的缺失部分了。因爲本項目實現時在硬件及時間方面受限,所以對原文中的方法進行了簡化,訓練方法和數據樣本處理較原論文有所調整作了調整,沒法達到原論文效果,但相較於原做者兩個月的訓練時間對比,這樣的訓練方式也是可取的。

如想到達到原論文的精準的小夥伴,能夠在本項目基礎上修改訓練策略~在此附上原論文訓練程序圖

本項目使用了飛槳開源深度學習框架,在AI Studio上完成了數據處理、模型訓練、效果預測等整個工做過程,很是感謝AI Studio給咱們提供的GPU在線訓練環境,對於在深度學習道路上硬件條件上不足的學生來講簡直是很是大的幫助。

若是你對這個小實驗感興趣,也能夠本身來嘗試一下,整個項目包括數據集與相關代碼已公開在AI Studio上,歡迎小夥伴們Fork。

https://aistudio.baidu.com/ai...

如在使用過程當中有問題,可加入飛槳官方QQ羣進行交流:1108045677。

若是您想詳細瞭解更多飛槳的相關內容,請參閱如下文檔。

相關文章
相關標籤/搜索