深度有趣 | 20 CycleGAN性別轉換

簡介

介紹可用於實現多種非配對圖像翻譯任務的CycleGAN模型,並完成性別轉換任務html

原理

和pix2pix不一樣,CycleGAN不須要嚴格配對的圖片,只須要兩類(domain)便可,例如一個文件夾都是蘋果圖片,另外一個文件夾都是橘子圖片python

使用A和B兩類圖片,就能夠實現A到B的翻譯和B到A的翻譯git

論文官方網站上提供了詳細的例子和介紹,https://junyanz.github.io/CycleGAN/,例如蘋果和橘子、馬和斑馬、夏天和冬天、照片和藝術做品等github

CycleGAN非配對圖像翻譯示例

以及論文的官方Github項目,https://github.com/junyanz/CycleGAN,使用PyTorch實現網絡

CycleGAN由兩個生成器G和F,以及兩個判別器Dx和Dy組成app

CycleGAN模型結構

G接受真的X並輸出假的Y,即完成X到Y的翻譯;F接受真的Y並輸出假的X,即完成Y到X的翻譯;Dx接受真假X並進行判別,Dy接受真假Y並進行判別dom

CycleGAN的損失函數和標準GAN差很少,只是寫兩套而已ide

$$ L_{GAN}(G,D_Y,X,Y)=\mathbb{E}{y\sim p_y}[\log D_Y(y)]+\mathbb{E}{x\sim p_x}[\log(1-D_Y(G(x)))] $$函數

$$ L_{GAN}(F,D_X,Y,X)=\mathbb{E}{x\sim p_x}[\log D_X(x)]+\mathbb{E}{y\sim p_y}[\log(1-D_X(F(y)))] $$網站

除此以外,爲了不mode collasp問題,CycleGAN還考慮了循環一致損失(Cycle Consistency Loss)

$$ L_{cyc}(G,F)=\mathbb{E}_{x\sim p_x}[\left | F(G(x))-x \right |1]+\mathbb{E}{y\sim p_y}[\left | G(F(y))-y \right |_1] $$

所以CycleGAN的總損失以下,G、F、Dx、Dy分別須要min、max其中的部分損失項

$$ L(G,F,D_X,D_Y)=L_{GAN}(G,D_Y,X,Y)+L_{GAN}(F,D_X,Y,X)+\lambda L_{cyc}(G,F) $$

實現

在論文的具體實現中,使用了兩個tricks

  • 使用Least-Square Loss即最小平方偏差代替標準的GAN損失
  • 以G爲例,維護一個歷史假Y圖片集合,例如50張。每次G生成假Y以後將其加到集合中,再從集合中隨機地取出一張假Y,和一張真Y一塊兒輸入給判別器進行判別。這樣一來,假Y集合表明了G根據X生成Y的平均能力,使得訓練更加穩定

使用如下項目訓練CycleGAN模型,https://github.com/vanhuyz/CycleGAN-TensorFlow,主要包括幾個代碼:

  • build_data.py:將圖片數據整理爲tfrecords文件
  • ops.py:定義了一些小的網絡模塊
  • generator.py:生成器的定義
  • discriminator.py:判別器的定義
  • model.py:使用生成器和判別器定義CycleGAN
  • train.py:訓練模型的代碼
  • export_graph.py:將訓練好的模型打包成.pd文件
  • inference.py:使用打包好的.pb文件翻譯圖片,即便用模型進行推斷

生成器和判別器結構以下,若是感興趣能夠進一步閱讀項目源碼

CycleGAN模型細節

性別轉換

使用CelebA中的男性圖片和女性圖片,訓練一個實現性別轉換的CycleGAN

將CelebA數據集中的圖片處理成256*256大小,並按照性別保存至male和female兩個文件夾,分別包含84434張男性圖片和118165張女性圖片

# -*- coding: utf-8 -*-

from imageio import imread, imsave
import cv2
import glob, os
from tqdm import tqdm

data_dir = 'data'
male_dir = 'data/male'
female_dir = 'data/female'

if not os.path.exists(data_dir):
    os.mkdir(data_dir)
if not os.path.exists(male_dir):
    os.mkdir(male_dir)
if not os.path.exists(female_dir):
    os.mkdir(female_dir)

WIDTH = 256
HEIGHT = 256

def read_process_save(read_path, save_path):
    image = imread(read_path)
    h = image.shape[0]
    w = image.shape[1]
    if h > w:
        image = image[h // 2 - w // 2: h // 2 + w // 2, :, :]
    else:
        image = image[:, w // 2 - h // 2: w // 2 + h // 2, :]    
    image = cv2.resize(image, (WIDTH, HEIGHT))
    imsave(save_path, image)

target = 'Male'
with open('list_attr_celeba.txt', 'r') as fr:
    lines = fr.readlines()
    all_tags = lines[0].strip('\n').split()
    for i in tqdm(range(1, len(lines))):
        line = lines[i].strip('\n').split()
        if int(line[all_tags.index(target) + 1]) == 1:
            read_process_save(os.path.join('celeba', line[0]), os.path.join(male_dir, line[0])) # 男
        else:
            read_process_save(os.path.join('celeba', line[0]), os.path.join(female_dir, line[0])) # 女

使用build_data.py將圖片轉換成tfrecords格式

python CycleGAN-TensorFlow/build_data.py --X_input_dir data/male/ --Y_input_dir data/female/ --X_output_file data/male.tfrecords --Y_output_file data/female.tfrecords

使用train.py訓練CycleGAN模型

python CycleGAN-TensorFlow/train.py --X data/male.tfrecords --Y data/female.tfrecords --image_size 256

訓練開始後,會生成checkpoints文件夾,並根據當前日期和時間生成一個子文件夾,例如20180507-0231,其中包括用於顯示tensorboard的events.out.tfevents文件,以及和模型相關的一些文件

使用tensorboard查看模型訓練細節,運行如下命令後訪問6006端口便可

tensorboard --logdir=checkpoints/20180507-0231

如下是迭代185870次以後,tensorboard的IMAGES頁面

CycleGAN模型訓練tensorboard細節

模型訓練沒有迭代次數限制,因此感受效果不錯或者迭代次數差很少了,即可以終止訓練

使用export_graph.py將模型打包成.pb文件,生成的文件在pretrained文件夾中

python CycleGAN-TensorFlow/export_graph.py --checkpoint_dir checkpoints/20180507-0231/ --XtoY_model male2female.pb --YtoX_model female2male.pb --image_size 256

經過inference.py使用模型處理圖片

python CycleGAN-TensorFlow/inference.py --model pretrained/male2female.pb --input Trump.jpg --output Trump_female.jpg --image_size 256
python CycleGAN-TensorFlow/inference.py --model pretrained/female2male.pb --input Hillary.jpg --output Hillary_male.jpg --image_size 256

在代碼中使用模型處理多張圖片

# -*- coding: utf-8 -*-

import tensorflow as tf
import numpy as np
from model import CycleGAN
from imageio import imread, imsave
import glob
import os

image_file = 'face.jpg'
W = 256
result = np.zeros((4 * W, 5 * W, 3))

for gender in ['male', 'female']:
    if gender == 'male':
        images = glob.glob('../faces/male/*.jpg')
        model = '../pretrained/male2female.pb'
        r = 0
    else:
        images = glob.glob('../faces/female/*.jpg')
        model = '../pretrained/female2male.pb'
        r = 2

    graph = tf.Graph()
    with graph.as_default():
        graph_def = tf.GraphDef()
        with tf.gfile.FastGFile(model, 'rb') as model_file:
            graph_def.ParseFromString(model_file.read())
            tf.import_graph_def(graph_def, name='')

        with tf.Session(graph=graph) as sess:
            input_tensor = graph.get_tensor_by_name('input_image:0')
            output_tensor = graph.get_tensor_by_name('output_image:0')

            for i, image in enumerate(images):
                image = imread(image)
                output = sess.run(output_tensor, feed_dict={input_tensor: image})

                with open(image_file, 'wb') as f:
                    f.write(output)

                output = imread(image_file)
                maxv = np.max(output)
                minv = np.min(output)
                output = ((output - minv) / (maxv - minv) * 255).astype(np.uint8)

                result[r * W: (r + 1) * W, i * W: (i + 1) * W, :] = image
                result[(r + 1) * W: (r + 2) * W, i * W: (i + 1) * W, :] = output

os.remove(image_file)
imsave('CycleGAN性別轉換結果.jpg', result)

CycleGAN性別轉換結果

視頻性別轉換

對一段視頻,識別每一幀可能包含的人臉,檢測人臉對應的性別,並使用CycleGAN完成性別的雙向轉換

使用如下項目實現性別的檢測,https://github.com/yu4u/age-gender-estimation,經過Keras訓練模型,能夠檢測出人臉的性別和年齡

舉個例子,使用OpenCV獲取攝像頭圖片,經過dlib檢測人臉,並獲得每個檢測結果對應的年齡和性別

# -*- coding: utf-8 -*-

from wide_resnet import WideResNet
import numpy as np
import cv2
import dlib

depth = 16
width = 8
img_size = 64
model = WideResNet(img_size, depth=depth, k=width)()
model.load_weights('weights.hdf5')

def draw_label(image, point, label, font=cv2.FONT_HERSHEY_SIMPLEX, font_scale=1, thickness=2):
    size = cv2.getTextSize(label, font, font_scale, thickness)[0]
    x, y = point
    cv2.rectangle(image, (x, y - size[1]), (x + size[0], y), (255, 0, 0), cv2.FILLED)
    cv2.putText(image, label, point, font, font_scale, (255, 255, 255), thickness)

detector = dlib.get_frontal_face_detector()
cap = cv2.VideoCapture(0)
cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)

while True:
    ret, image_np = cap.read()
    image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
    img_h = image_np.shape[0]
    img_w = image_np.shape[1]

    detected = detector(image_np, 1)
    faces = []

    if len(detected) > 0:
        for i, d in enumerate(detected):
            x0, y0, x1, y1, w, h = d.left(), d.top(), d.right(), d.bottom(), d.width(), d.height()
            cv2.rectangle(image_np, (x0, y0), (x1, y1), (255, 0, 0), 2)

            x0 = max(int(x0 - 0.25 * w), 0)
            y0 = max(int(y0 - 0.45 * h), 0)
            x1 = min(int(x1 + 0.25 * w), img_w - 1)
            y1 = min(int(y1 + 0.05 * h), img_h - 1)
            w = x1 - x0
            h = y1 - y0
            if w > h:
                x0 = x0 + w // 2 - h // 2
                w = h
                x1 = x0 + w
            else:
                y0 = y0 + h // 2 - w // 2
                h = w
                y1 = y0 + h
            faces.append(cv2.resize(image_np[y0: y1, x0: x1, :], (img_size, img_size)))

        faces = np.array(faces)
        results = model.predict(faces)
        predicted_genders = results[0]
        ages = np.arange(0, 101).reshape(101, 1)
        predicted_ages = results[1].dot(ages).flatten()

        for i, d in enumerate(detected):
            label = '{}, {}'.format(int(predicted_ages[i]), 'F' if predicted_genders[i][0] > 0.5 else 'M')
            draw_label(image_np, (d.left(), d.top()), label)

    cv2.imshow('gender and age', cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR))

    if cv2.waitKey(25) & 0xFF == ord('q'):
        cap.release()
        cv2.destroyAllWindows()
        break

將以上項目和CycleGAN應用於視頻的雙向性別轉換,首先提取出視頻中的人臉,記錄人臉出現的幀數、位置以及對應的性別,視頻共830幀,檢測出721張人臉

# -*- coding: utf-8 -*-

from wide_resnet import WideResNet
import numpy as np
import cv2
import dlib
import pickle

depth = 16
width = 8
img_size = 64
model = WideResNet(img_size, depth=depth, k=width)()
model.load_weights('weights.hdf5')

detector = dlib.get_frontal_face_detector()
cap = cv2.VideoCapture('../friends.mp4')

pos = []
frame_id = -1

while cap.isOpened():
    ret, image_np = cap.read()
    frame_id += 1
    if len((np.array(image_np)).shape) == 0:
        break

    image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
    img_h = image_np.shape[0]
    img_w = image_np.shape[1]
    detected = detector(image_np, 1)

    if len(detected) > 0:
        for d in detected:
            x0, y0, x1, y1, w, h = d.left(), d.top(), d.right(), d.bottom(), d.width(), d.height()
            x0 = max(int(x0 - 0.25 * w), 0)
            y0 = max(int(y0 - 0.45 * h), 0)
            x1 = min(int(x1 + 0.25 * w), img_w - 1)
            y1 = min(int(y1 + 0.05 * h), img_h - 1)
            w = x1 - x0
            h = y1 - y0
            if w > h:
                x0 = x0 + w // 2 - h // 2
                w = h
                x1 = x0 + w
            else:
                y0 = y0 + h // 2 - w // 2
                h = w
                y1 = y0 + h
            
            face = cv2.resize(image_np[y0: y1, x0: x1, :], (img_size, img_size))
            result = model.predict(np.array([face]))
            pred_gender = result[0][0][0]

            if pred_gender > 0.5:
                pos.append([frame_id, y0, y1, x0, x1, h, w, 'F'])
            else:
                pos.append([frame_id, y0, y1, x0, x1, h, w, 'M'])

print(frame_id + 1, len(pos))

with open('../pos.pkl', 'wb') as fw:
    pickle.dump(pos, fw)
    
cap.release()
cv2.destroyAllWindows()

再使用CycleGAN,將原視頻中出現的人臉轉換成相反的性別,並寫入新的視頻文件

# -*- coding: utf-8 -*-

import tensorflow as tf
import numpy as np
from model import CycleGAN
from imageio import imread
import os
import cv2
import pickle
from tqdm import tqdm

with open('../pos.pkl', 'rb') as fr:
    pos = pickle.load(fr)

cap = cv2.VideoCapture('../friends.mp4')
ret, image_np = cap.read()
out = cv2.VideoWriter('../output.mp4', -1, cap.get(cv2.CAP_PROP_FPS), (image_np.shape[1], image_np.shape[0]))

frames = []
while cap.isOpened():
    ret, image_np = cap.read()
    if len((np.array(image_np)).shape) == 0:
        break
    frames.append(cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB))

image_size = 256
image_file = 'face.jpg'
for gender in ['M', 'F']:
    if gender == 'M':
        model = '../pretrained/male2female.pb'
    else:
        model = '../pretrained/female2male.pb'

    graph = tf.Graph()
    with graph.as_default():
        graph_def = tf.GraphDef()
        with tf.gfile.FastGFile(model, 'rb') as model_file:
            graph_def.ParseFromString(model_file.read())
            tf.import_graph_def(graph_def, name='')

        with tf.Session(graph=graph) as sess:
            input_tensor = graph.get_tensor_by_name('input_image:0')
            output_tensor = graph.get_tensor_by_name('output_image:0')

        for i in tqdm(range(len(pos))):
            fid, y0, y1, x0, x1, h, w, g = pos[i]
            if g == gender:
                face = cv2.resize(frames[fid - 1][y0: y1, x0: x1, :], (image_size, image_size))
                output_face = sess.run(output_tensor, feed_dict={input_tensor: face})

                with open(image_file, 'wb') as f:
                    f.write(output_face)

                output_face = imread(image_file)
                maxv = np.max(output_face)
                minv = np.min(output_face)
                output_face = ((output_face - minv) / (maxv - minv) * 255).astype(np.uint8)

                output_face = cv2.resize(output_face, (w, h))
                frames[fid - 1][y0: y1, x0: x1, :] = output_face

for frame in frames:
    out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
            
os.remove(image_file)
cap.release()
out.release()
cv2.destroyAllWindows()

生成的視頻文件只有圖像、沒有聲音,可使用ffmpeg進一步處理

若是沒有ffmpeg則下載並安裝,http://www.ffmpeg.org/download.html

進入命令行,從原始視頻中提取音頻

ffmpeg -i friends.mp4 -f mp3 -vn sound.mp3

將提取的音頻和生成的視頻合成在一塊兒

ffmpeg -i output.mp4 -i sound.mp3 combine.mp4

其餘

項目還提供了四個訓練好的模型,https://github.com/vanhuyz/CycleGAN-TensorFlow/releases,包括蘋果到橘子、橘子到蘋果、馬到斑馬、斑馬到馬,若是感興趣能夠嘗試一下

用CycleGAN不只能夠完成兩類圖片之間的轉換,也能夠實現兩個物體之間的轉換,例如將一我的翻譯成另外一我的

能夠考慮從一部電影中提取出兩個角色對應的圖片,訓練CycleGAN以後,便可將一我的翻譯成另外一我的

還有一些比較大膽的嘗試,提升駕駛技術:用GAN去除(愛情)動做片中的馬賽克和衣服

參考

視頻講解課程

深度有趣(一)

相關文章
相關標籤/搜索