GAN生成式對抗網絡(四)——SRGAN超高分辨率圖片重構

論文pdf 地址:https://arxiv.org/pdf/1609.04802v1.pdfgit

個人實際效果

清晰度距離個人期待有距離。 顏色上面存在差距。 解決想法 增長一個顏色判別器。將顏色值反饋給生成器github

srgan論文是創建在gan基礎上的,利用gan生成式對抗網絡,將圖片重構爲高清分辨率的圖片。 github上有開源的srgan項目。因爲開源者,開發時考慮的問題更豐富,技巧更爲高明,致使其代碼都比較難以閱讀和理解。 在爲了充分理解這個論文。這裏結合論文,開源代碼,和本身的理解從新寫了個srgan高清分辨率模型。網絡

##GAN原理 在一個不斷提升判斷能力的判斷器的持續反饋下,不斷改善生成器的生成參數,直到生成器生成的結果可以經過判斷器的判斷。(見本博客其餘文章)app

##SRGAN用到的模塊,及其關係 損失值,根據的這個關係結構計算的。 注意:vgg19是使用已經訓練好的模型,這裏只是拿來提取特徵使用,dom

對於生成器,根據三個運算結果數據,進行隨機梯度的優化調整 ①斷定器生成數據的鑑定結果 ②vgg19的特徵比較狀況 ③生成圖形與理想圖形的mse差距ide

論文中,生成器和判別器的模型圖

生成器結構爲:一層卷積,16層殘差卷積,再將第一層卷積結果+16層殘差結,卷積+2倍反捲積,卷積+2倍反捲積,tanh縮放,產生生成結果。 判別器結構爲:8層卷積+reshape,全鏈接。(論文中,用了兩層。我這裏只用了一層全鏈接,參數量太大,我6G 的gpu內存不夠用) vgg19結構:在vgg19的第四層,返回獲取到的特徵結果,進行MSE對比 注意:BN處理,leaky relu等等處理技巧函數

代碼解釋

import numpy as np
import os
import tensorlayer as tl
import tensorflow as tf

#獲取vgg9.npy中vgg19的參數, 
vgg19_npy_path = "./vgg19.npy"
if not os.path.isfile(vgg19_npy_path):
    print("Please download vgg19.npz from : https://github.com/machrisaa/tensorflow-vgg")
    exit()
npz = np.load(vgg19_npy_path, encoding='latin1').item()
w_params = []
b_params = []
for val in sorted(npz.items()):
    W = np.asarray(val[1][0])
    b = np.asarray(val[1][1])
    # print("  Loading %s: %s, %s" % (val[0], W.shape, b.shape))
    w_params.append(W, )
    b_params.extend(b)


#tensorlayer加載圖片時,用於處理圖片。隨機獲取圖片中 192*192的矩陣, 內存不足時,能夠優化這裏
def crop_sub_imgs_fn(x, is_random=True):
    x = tl.prepro.crop(x, wrg=192, hrg=192, is_random=is_random)
    x = x / (255. / 2.)
    x = x - 1.
    return x
#resize矩陣 內存不足時,能夠優化這裏
def downsample_fn(x):
    x = tl.prepro.imresize(x, size=[48, 48], interp='bicubic', mode=None)
    x = x / (255. / 2.)
    x = x - 1.
    return x

# 參數
config = {
    "epoch": 5,
}

# 內存不夠時,能夠減少這個
batch_size = 10 


class SRGAN(object):
    def __init__(self):
        # with tf.device('/gpu:0'):
        #佔位變量,存儲須要重構的圖片
        self.x = tf.placeholder(tf.float32, shape=[batch_size, 48, 48, 3], name='train_bechanged')
        #佔位變量,存儲須要學習的理想中的圖片
        self.y = tf.placeholder(tf.float32, shape=[batch_size, 192, 192, 3], name='train_target')
        self.init_fake_y = self.generator(self.x)  # 預訓練時生成的假照片
        self.fake_y = self.generator(self.x, reuse=True)  # 所有訓練時生成的假照片

         #佔位變量,存儲須要重構的測試圖片
        self.test_x = tf.placeholder(tf.float32, shape=[1, None, None, 3], name='test_generator')
        #佔位變量,存儲重構後的測試圖片
        self.test_fake_y = self.generator(self.test_x, reuse=True)  # 生成的假照片

        #佔位變量,將生成圖片resize
        self.fake_y_vgg = tf.image.resize_images(
            self.fake_y, size=[224, 224], method=0,
            align_corners=False)
         #佔位變量,將理想圖片resize
        self.real_y_vgg = tf.image.resize_images(
            self.y, size=[224, 224], method=0,
            align_corners=False)
        #提取僞造圖片的特徵
        self.fake_y_feature = self.vgg19(self.fake_y_vgg)  # 假照片的特徵值
        #提取理想圖片的特徵
        self.real_y_feature = self.vgg19(self.real_y_vgg, reuse=True)  # 真照片的特徵值

        # self.pre_dis_logits = self.discriminator(self.fake_y)  # 判別器生成的預測照片的判別值
        self.fake_dis_logits = self.discriminator(self.fake_y, reuse=False)  # 判別器生成的假照片的判別值
        self.real_dis_logits = self.discriminator(self.y, reuse=True)  # 判別器生成的假照片的判別值

        # 預訓練時,判別器的優化根據值
        self.init_mse_loss = tf.losses.mean_squared_error(self.init_fake_y, self.y)

        # 關於判別器的優化根據值
        self.D_loos = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.real_dis_logits,
                                                                             labels=tf.ones_like(
                                                                                 self.real_dis_logits))) + \
                      tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.fake_dis_logits,
                                                                             labels=tf.zeros_like(
                                                                                 self.fake_dis_logits)))

        # 僞造數據判別器的判斷狀況,生成與目標圖像的差距,生成特徵與理想特徵的差距
        self.D_loos_Ge = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.fake_dis_logits, labels=tf.ones_like( self.fake_dis_logits)))
        self.mse_loss = tf.losses.mean_squared_error(self.fake_y, self.y)
        self.loss_vgg = tf.losses.mean_squared_error(self.fake_y_feature, self.real_y_feature)

        #生成器的優化根據值,上面三個值的和
        self.G_loos = 1e-3 * self.D_loos_Ge + 2e-6 * self.loss_vgg + self.mse_loss
       
        #獲取具體條件下的更新變量集合。
        t_vars = tf.trainable_variables()
        self.g_vars = [var for var in t_vars if var.name.startswith('trainGenerator')]
        self.d_vars = [var for var in t_vars if var.name.startswith('discriminator')]



    # 生成器,16層深度殘差+1層初始的深度殘差+2次2倍反捲積+1個卷積
    def generator(slef, input, reuse=False):
        with tf.variable_scope('trainGenerator') as scope:
            if reuse:
                scope.reuse_variables()
            n = tf.layers.conv2d(input, 64, 3, strides=1, padding='SAME', activation=None, use_bias=True,
                                 bias_initializer=None)
            prellu_param = tf.get_variable('p_alpha', n.get_shape()[-1], initializer=tf.constant_initializer(0.0),
                                           dtype=tf.float32)
            n = tf.nn.relu(n) + prellu_param * (n - abs(n)) * 0.02
            # n = tf.nn.relu(n)
            temp = n
            # 開始深度殘差網絡
            for i in range(16):
                nn = tf.layers.conv2d(n, 64, 3, strides=1, padding='SAME', activation=None, use_bias=True,
                                      bias_initializer=None)
                nn = tf.layers.batch_normalization(nn, training=True)
                prellu_param = tf.get_variable('p_alpha' + str(2 * i + 1), n.get_shape()[-1],
                                               initializer=tf.constant_initializer(0.0),
                                               dtype=tf.float32)
                nn = tf.nn.relu(nn) + prellu_param * (nn - abs(nn)) * 0.02

                nn = tf.layers.conv2d(nn, 64, 3, strides=1, padding='SAME', activation=None, use_bias=True,
                                      bias_initializer=None)
                nn = tf.layers.batch_normalization(nn, training=True)
                # prellu_param = tf.get_variable('p_alpha' + str(2 * i + 2), n.get_shape()[-1],
                #                                initializer=tf.constant_initializer(0.0),
                #                                dtype=tf.float32)
                # nn = tf.nn.relu(nn) + prellu_param * (nn - abs(nn)) * 0.02
                n = nn + n

            n = tf.layers.conv2d(n, 64, 3, strides=1, padding='SAME', activation=None, use_bias=True,
                                 bias_initializer=None)
            n = tf.layers.batch_normalization(n, training=True)
            # prellu_param = tf.get_variable('p_alpha_34', n.get_shape()[-1],
            #                                initializer=tf.constant_initializer(0.0),
            #                                dtype=tf.float32)
            # n = tf.nn.relu(n) + prellu_param * (n - abs(n)) * 0.02

            #注意這裏的temp,看論文裏面的生成器結構圖
            n = temp + n

            # 將特徵還原爲圖
            n = tf.layers.conv2d_transpose(n, 256, 3, strides=2, padding='SAME', activation=None, use_bias=True,
                                           bias_initializer=None)

            n = tf.layers.conv2d(n, 256, 3, strides=1, padding='SAME', activation=None, use_bias=True,
                                 bias_initializer=None)
            n = tf.layers.batch_normalization(n, training=True)
            n = tf.nn.relu(n)

            n = tf.layers.conv2d_transpose(n, 256, 3, strides=2, padding='SAME', activation=None, use_bias=True,
                                           bias_initializer=None)
            n = tf.layers.conv2d(n, 256, 3, strides=1, padding='SAME', activation=None, use_bias=True,
                                 bias_initializer=None)
            n = tf.layers.batch_normalization(n, training=True)
            n = tf.nn.relu(n)

            n = tf.layers.conv2d(n, 3, 1, strides=1, padding='SAME', activation=None, use_bias=True,
                                 bias_initializer=None)
            n = tf.nn.tanh(n)
            return n


    #判別器
    def discriminator(self, input, reuse=False):
        # input   size: 384x384
        with tf.variable_scope('discriminator') as scope:
            if reuse:
                scope.reuse_variables()
            # 1
            n = tf.layers.conv2d(input, 64, 3, strides=1, padding='SAME', activation=None, use_bias=True,
                                 bias_initializer=None)
            n = tf.maximum(0.01 * n, n)
            # 2
            n = tf.layers.conv2d(n, 64, 3, strides=2, padding='SAME', activation=None, use_bias=True,
                                 bias_initializer=None)
            n = tf.layers.batch_normalization(n, training=True)
            n = tf.maximum(0.01 * n, n)

            # 3
            n = tf.layers.conv2d(n, 128, 3, strides=1, padding='SAME', activation=None, use_bias=True,
                                 bias_initializer=None)
            n = tf.layers.batch_normalization(n, training=True)
            n = tf.maximum(0.01 * n, n)

            # 4
            n = tf.layers.conv2d(n, 128, 3, strides=2, padding='SAME', activation=None, use_bias=True,
                                 bias_initializer=None)
            n = tf.layers.batch_normalization(n, training=True)
            n = tf.maximum(0.01 * n, n)

            # 5
            n = tf.layers.conv2d(n, 256, 3, strides=1, padding='SAME', activation=None, use_bias=True,
                                 bias_initializer=None)
            n = tf.layers.batch_normalization(n, training=True)
            n = tf.maximum(0.01 * n, n)

            # 6
            n = tf.layers.conv2d(n, 256, 3, strides=2, padding='SAME', activation=None, use_bias=True,
                                 bias_initializer=None)
            n = tf.layers.batch_normalization(n, training=True)
            n = tf.maximum(0.01 * n, n)

            # 7
            n = tf.layers.conv2d(n, 512, 3, strides=1, padding='SAME', activation=None, use_bias=True,
                                 bias_initializer=None)
            n = tf.layers.batch_normalization(n, training=True)
            n = tf.maximum(0.01 * n, n)

            # 8
            n = tf.layers.conv2d(n, 512, 3, strides=2, padding='SAME', activation=None, use_bias=True,
                                 bias_initializer=None)
            n = tf.layers.batch_normalization(n, training=True)
            n = tf.maximum(0.01 * n, n)

            flatten = tf.reshape(n, (input.get_shape()[0], -1))
            # 內存不夠,減少全連接數量
            # f = tf.layers.dense(flatten, 1024)
            # 論文裏面這裏時leaky relu,這我用的dense裏面自帶的
            f = tf.layers.dense(flatten, 1, bias_initializer=tf.contrib.layers.xavier_initializer())

            return f
    #vgg19特徵提取
    def vgg19(self, input, reuse=False):
        VGG_MEAN = [103.939, 116.779, 123.68]
        with tf.variable_scope('vgg19') as scope:
            # if reuse:
            #     scope.reuse_variables()
            # ====================
            print("build model started")
            rgb_scaled = (input + 1) * (255.0 / 2)
            # Convert RGB to BGR
            red, green, blue = tf.split(rgb_scaled, 3, 3)
            assert red.get_shape().as_list()[1:] == [224, 224, 1]
            assert green.get_shape().as_list()[1:] == [224, 224, 1]
            assert blue.get_shape().as_list()[1:] == [224, 224, 1]
            bgr = tf.concat(
                [
                    blue - VGG_MEAN[0],
                    green - VGG_MEAN[1],
                    red - VGG_MEAN[2],
                ], axis=3)
            assert bgr.get_shape().as_list()[1:] == [224, 224, 3]

            # --------------------

            n = tf.nn.conv2d(bgr, w_params[0], name='conv2_1', strides=(1, 1, 1, 1), padding='SAME')
            n = tf.add(n, b_params[0])
            n = tf.nn.relu(n)
            n = tf.nn.conv2d(n, w_params[1], name='conv2_2', strides=(1, 1, 1, 1), padding='SAME')
            n = tf.add(n, b_params[1])
            n = tf.nn.relu(n)
            n = tf.nn.max_pool(n, ksize=(1, 2, 2, 1), strides=(1, 2, 2, 1), padding='SAME')

            # return n

            # two
            n = tf.nn.conv2d(n, w_params[2], strides=(1, 1, 1, 1), padding='SAME')
            n = tf.add(n, b_params[2])
            n = tf.nn.relu(n)
            n = tf.nn.conv2d(n, w_params[3], strides=(1, 1, 1, 1), padding='SAME')
            n = tf.add(n, b_params[3])
            n = tf.nn.relu(n)
            n = tf.nn.max_pool(n, ksize=(1, 2, 2, 1), strides=(1, 2, 2, 1), padding='SAME')
            # three
            n = tf.nn.conv2d(n, w_params[4], strides=(1, 1, 1, 1), padding='SAME')
            n = tf.add(n, b_params[4])
            n = tf.nn.relu(n)
            n = tf.nn.conv2d(n, w_params[5], strides=(1, 1, 1, 1), padding='SAME')
            n = tf.add(n, b_params[5])
            n = tf.nn.relu(n)
            n = tf.nn.conv2d(n, w_params[6], strides=(1, 1, 1, 1), padding='SAME')
            n = tf.add(n, b_params[6])
            n = tf.nn.relu(n)
            n = tf.nn.conv2d(n, w_params[7], strides=(1, 1, 1, 1), padding='SAME')
            n = tf.add(n, b_params[7])
            n = tf.nn.relu(n)
            n = tf.nn.max_pool(n, ksize=(1, 2, 2, 1), strides=(1, 2, 2, 1), padding='SAME')
            # four
            n = tf.nn.conv2d(n, w_params[8], strides=(1, 1, 1, 1), padding='SAME')
            n = tf.add(n, b_params[8])
            n = tf.nn.relu(n)
            n = tf.nn.conv2d(n, w_params[9], strides=(1, 1, 1, 1), padding='SAME')
            n = tf.add(n, b_params[9])
            n = tf.nn.relu(n)

            n = tf.nn.conv2d(n, w_params[10], strides=(1, 1, 1, 1), padding='SAME')
            n = tf.add(n, b_params[10])
            n = tf.nn.relu(n)
            n = tf.nn.conv2d(n, w_params[11], strides=(1, 1, 1, 1), padding='SAME')
            n = tf.add(n, b_params[11])
            n = tf.nn.relu(n)
            n = tf.nn.max_pool(n, ksize=(1, 2, 2, 1), strides=(1, 2, 2, 1), padding='SAME')
            return n

            # # five
            # n = tf.nn.conv2d(n, w_params[12], strides=(1, 1, 1, 1), padding='SAME')
            # n = tf.add(n, b_params[12])
            # n = tf.nn.relu(n)
            # n = tf.nn.conv2d(n, w_params[13], strides=(1, 1, 1, 1), padding='SAME')
            # n = tf.add(n, b_params[13])
            # n = tf.nn.relu(n)
            #
            # n = tf.nn.conv2d(n, w_params[14], strides=(1, 1, 1, 1), padding='SAME')
            # n = tf.add(n, b_params[14])
            # n = tf.nn.relu(n)
            # n = tf.nn.conv2d(n, w_params[15], strides=(1, 1, 1, 1), padding='SAME')
            # n = tf.add(n, b_params[15])
            # n = tf.nn.relu(n)
            # n = tf.nn.max_pool(n, ksize=(1, 2, 2, 1), strides=(1, 2, 2, 1), padding='SAME')
            # return n

            # 這裏拿特徵進行mse對比,不須要後面的全鏈接
            # flatten = tf.reshape(n, (input.get_shape()[0], -1))
            # f = tf.layers.dense(flatten, 4096)
            # f = tf.layers.dense(f, 4096)
            # f = tf.layers.dense(f, 1)
            # return n


gan = SRGAN()
G_OPTIM_init = tf.train.AdamOptimizer(learning_rate=0.001, beta1=0.4).minimize(gan.init_mse_loss, var_list=gan.g_vars)
D_OPTIM = tf.train.AdamOptimizer(learning_rate=0.001, beta1=0.4).minimize(gan.D_loos, var_list=gan.d_vars)
G_OPTIM = tf.train.AdamOptimizer(learning_rate=0.001, beta1=0.4).minimize(gan.G_loos, var_list=gan.g_vars)

saver = tf.train.Saver(max_to_keep=3)

init = tf.global_variables_initializer()

 
#加載路徑文件夾中的訓練圖片,這裏加載的只是圖片目錄。防止內存中加載太多圖片,內存不夠   
train_hr_img_list = sorted(tl.files.load_file_list(path='F:\\theRoleOfCOde\深度學習\SRGAN_PF\gaoqing', regx='.*.png', printable=False))[:100]
#加載圖片  
train_hr_imgs = tl.vis.read_images(train_hr_img_list, path='F:\\theRoleOfCOde\深度學習\SRGAN_PF\gaoqing', n_threads=1)

#加載路徑文件夾中的測試圖片目錄
test_img_list = sorted( tl.files.load_file_list(path='F:\\theRoleOfCOde\深度學習\SRGAN_PF\SRGAN_PF\img\\test', regx='.*.png', printable=False))[ :6]
test_img = tl.vis.read_images(test_img_list, path='F:\\theRoleOfCOde\深度學習\SRGAN_PF\SRGAN_PF\img\\test', n_threads=1)



#分三種運行方式,
#pre,預訓練判別器
#restore,回覆訓練好的模型,繼續訓練


#訓練一下子,就測試一下效果。將生成的圖片矩陣,保存爲numpy矩陣
#經過工具函數,變化爲圖片查看
#第三種,從零開始訓練
with tf.Session() as sess:
    type = 'go'
    if type == 'restore':
        saver.restore(sess, "./save/nets/ckpt-0-80")
        print('---------------------恢復之前的訓練數據,繼續訓練-----------------------')
        for epoch in range(0):
            for idx in range(0, (len(train_hr_imgs) // 10), batch_size):
                # print(type(train_hr_imgs[idx:idx + batch_size]))
                b_imgs_384 = tl.prepro.threading_data(train_hr_imgs[idx:idx + batch_size], fn=crop_sub_imgs_fn,
                                                      is_random=True)
                b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)
                print('-------------pre_generator:' + str(epoch) + '_' + str(idx) + '----------------')
                for i in range(40):
                    init_mse_loss, _ = sess.run([gan.init_mse_loss, G_OPTIM_init],
                                                feed_dict={
                                                    gan.x: b_imgs_96,
                                                    gan.y: b_imgs_384
                                                })
                    print('init_mse_loss:' + str(init_mse_loss))
            saver.save(sess, "save/nets/better_ge.ckpt")
        for epoch in range(config["epoch"]):
            for idx in range(0, len(train_hr_imgs), batch_size):
                # print(type(train_hr_imgs[idx:idx + batch_size]))
                b_imgs_384 = tl.prepro.threading_data(train_hr_imgs[idx:idx + batch_size], fn=crop_sub_imgs_fn,
                                                      is_random=True)
                b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)
                print('-------------' + str(epoch) + '_' + str(idx) + '----------')
                for i in range(25):
                    loss_D, _ = sess.run([gan.D_loos, D_OPTIM],
                                         feed_dict={
                                             gan.x: b_imgs_96,
                                             gan.y: b_imgs_384
                                         })
                    loss_G, _ = sess.run([gan.G_loos, G_OPTIM],
                                         feed_dict={
                                             gan.x: b_imgs_96,
                                             gan.y: b_imgs_384
                                         })
                    print(loss_D, loss_G)
                if idx % 20 == 0:
                    saver.save(sess, "./save/nets/better_all_" + str(epoch) + "_" + str(idx) + '.ckpt')

                    _imgs = (np.asanyarray(test_img[0:1]) / (255. / 2.)) - 1
                    _imgs = _imgs[:, :, :, 0:3]
                    result_fake_y = sess.run([gan.test_fake_y], feed_dict={
                        gan.test_x: _imgs
                    })  # 生成的假照片
                    # result=sess.run(result_fake_y)
                    strpath = './preImg/result_' + str(epoch) + '_' + str(idx) + '_1.npy'
                    np.save(strpath, result_fake_y)

                    _imgs2 = (np.asanyarray(test_img[1:2]) / (255. / 2.)) - 1
                    _imgs2 = _imgs2[:, :, :, 0:3]
                    result_fake_y = sess.run([gan.test_fake_y], feed_dict={
                        gan.test_x: _imgs2
                    })  # 生成的假照片
                    # result=sess.run(result_fake_y)
                    strpath = './preImg/result_' + str(epoch) + '_' + str(idx) + '_2.npy'
                    np.save(strpath, result_fake_y)
                    # print(type(result_fake_y))
    elif type == 'pre':
        saver.restore(sess, "save/nets/better_all_1_28.ckpt")
        print('---------------------恢復訓練好的模型,開始預測-----------------------')
        for num in range(6):
            _imgs = (np.asanyarray(test_img[num:(num + 1)]) / (255. / 2.)) - 1
            print(_imgs.shape)
            _imgs = _imgs[:, :, :, 0:3]
            # time.sleep(1)
            result_fake_y = sess.run([gan.test_fake_y], feed_dict={
                gan.test_x: _imgs
            })  # 生成的假照片
            strpath = './preImg/pre_result_' + str(num) + '.npy'
            np.save(strpath, result_fake_y)
            print('ok')
    else:
        sess.run(init)
        print('---------------------開始新的訓練-----------------------')
        for epoch in range(2):
            for idx in range(0, len(train_hr_imgs), batch_size):
                # print(type(train_hr_imgs[idx:idx + batch_size]))
                b_imgs_384 = tl.prepro.threading_data(train_hr_imgs[idx:idx + batch_size], fn=crop_sub_imgs_fn,
                                                      is_random=True)
                b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)
                print('-------------pre_generator:' + str(epoch) + '_' + str(idx) + '----------------')
                for i in range(25):
                    init_mse_loss, _ = sess.run([gan.init_mse_loss, G_OPTIM_init],
                                                feed_dict={
                                                    gan.x: b_imgs_96,
                                                    gan.y: b_imgs_384
                                                })
                    print('init_mse_loss:' + str(init_mse_loss))
        saver.save(sess, "save/nets/cnn_mnist_basic_generator.ckpt")
        for epoch in range(config["epoch"]):
            for idx in range(0, len(train_hr_imgs), batch_size):
                # print(type(train_hr_imgs[idx:idx + batch_size]))
                b_imgs_384 = tl.prepro.threading_data(train_hr_imgs[idx:idx + batch_size], fn=crop_sub_imgs_fn,
                                                      is_random=True)
                b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)
                print('-------------' + str(epoch) + '_' + str(idx) + '----------')
                for i in range(25):
                    loss_D, _ = sess.run([gan.D_loos, D_OPTIM],
                                         feed_dict={
                                             gan.x: b_imgs_96,
                                             gan.y: b_imgs_384
                                         })
                    loss_G, _ = sess.run([gan.G_loos, G_OPTIM],
                                         feed_dict={
                                             gan.x: b_imgs_96,
                                             gan.y: b_imgs_384
                                         })
                    print(loss_D, loss_G)
                if idx % 20 == 0:
                    _imgs = (np.asanyarray(test_img[0:1]) / (255. / 2.)) - 1
                    _imgs = _imgs[:, :, :, 0:3]
                    result_fake_y = sess.run([gan.test_fake_y], feed_dict={
                        gan.test_x: _imgs
                    })  # 生成的假照片
                    # result=sess.run(result_fake_y)
                    strpath = './preImg/result_' + str(epoch) + '_' + str(idx) + '_1.npy'
                    np.save(strpath, result_fake_y)

                    _imgs2 = (np.asanyarray(test_img[1:2]) / (255. / 2.)) - 1
                    _imgs2 = _imgs2[:, :, :, 0:3]
                    result_fake_y = sess.run([gan.test_fake_y], feed_dict={
                        gan.test_x: _imgs2
                    })  # 生成的假照片
                    # result=sess.run(result_fake_y)
                    strpath = './preImg/result_' + str(epoch) + '_' + str(idx) + '_2.npy'
                    np.save(strpath, result_fake_y)
                    saver.save(sess, "save/nets/ckpt-" + str(epoch) + '-' + str(idx))
                    # print(type(result_fake_y))

查看效果的工具函數

將numpy矩陣轉換爲圖片工具

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

npz = np.load('../preImg/pre_result_5.npy', encoding='latin1')
print(npz.shape)
data = ((npz[0][0]) + 1) * (255. / 2.)
print(data)

new_im = Image.fromarray(data.astype(np.uint8))
new_im.show()
new_im.save('result.png')
相關文章
相關標籤/搜索