變分自編碼器(variational autoencoder, VAE)是一種生成模型,訓練模型分爲編碼器和解碼器兩部分。python
編碼器將輸入樣本映射爲某個低維分佈,這個低維分佈一般是不一樣維度之間相互獨立的多元高斯分佈,所以編碼器的輸出爲這個高斯分佈的均值與對數方差(由於方差老是大於0,爲了將它映射到$(-\infty,\infty)$,因此加了對數)。在編碼器的分佈中抽樣後,解碼器作的事是將從這個低維抽樣從新解碼,生成與輸入樣本類似的數據。數據能夠是圖像、文字、音頻等。數組
VAE模型的結構不難理解,關鍵在於它的損失函數的定義。咱們要讓解碼器的輸出與編碼器的輸入儘可能類似,這個損失能夠由這兩者之間的二元交叉熵(binary crossentropy)來定義。可是僅由這個做爲最終的目標函數是不夠的。在這樣的目標函數下,不斷的梯度降低,會使編碼器在不一樣輸入下的輸出均值之間差異愈來愈大,而輸出方差則會不斷地趨向於0,也就是對數方差趨向於負無窮。由於只有這樣纔會使從生成分佈獲取的抽樣更加明確,從而讓解碼器能生成與輸入數據更接近的數據,以使損失變得更小。可是這就與生成器的初衷有悖了,生成器的初衷其實是爲了生成更多「全新」的數據,而不是爲了生成與輸入數據「更像」的數據。因此,咱們還要再給目標函數加上編碼器生成分佈的「正則化損失」:生成分佈與標準正態分佈之間的KL散度(相對熵)。讓生成分佈不至於「太極端、太肯定」,從而讓不一樣輸入數據的生成分佈之間有交叉 。因而解碼器經過這些交叉的「緩衝帶」上的抽樣,可以生成「中間數據」,產生意想不到的效果。dom
詳細的分析請看:變分自編碼器VAE:原來是這麼一回事 - 知乎ide
如下使用Keras實現VAE生成圖像,數據集是MNIST。函數
編碼器將MNIST的數字圖像轉換爲2維的正態分佈均值與對數方差。簡單堆疊卷積層與全鏈接層便可,代碼以下:學習
#%%編碼器 import numpy as np import keras from keras import layers,Model,models,utils from keras import backend as K from keras.datasets import mnist img_shape = (28,28,1) latent_dim = 2 input_img = layers.Input(shape=img_shape) x = layers.Conv2D(32,3,padding='same',activation='relu')(input_img) x = layers.Conv2D(64,3,padding='same',activation='relu',strides=2)(x) x = layers.Conv2D(64,3,padding='same',activation='relu')(x) x = layers.Conv2D(64,3,padding='same',activation='relu')(x) inter_shape = K.int_shape(x) x = layers.Flatten()(x) x = layers.Dense(32,activation='relu')(x) encode_mean = layers.Dense(2,name = 'encode_mean')(x) #分佈均值 encode_log_var = layers.Dense(2,name = 'encode_logvar')(x) #分佈對數方差 encoder = Model(input_img,[encode_mean,encode_log_var],name = 'encoder')
解碼器接受2維向量,將這個向量「解碼」爲圖像。一樣也是簡單的堆疊卷積層、逆卷積層與全鏈接層便可,代碼以下:測試
#%%解碼器 input_code = layers.Input(shape=[2]) x = layers.Dense(np.prod(inter_shape[1:]),activation='relu')(input_code) x = layers.Reshape(target_shape=inter_shape[1:])(x) x = layers.Conv2DTranspose(32,3,padding='same',activation='relu',strides=2)(x) x = layers.Conv2D(1,3,padding='same',activation='sigmoid')(x) decoder = Model(input_code,x,name = 'decoder')
整個待訓練模型包括編碼器、抽樣層、解碼器。中間的抽樣操做在獲取編碼器傳出的均值與方差後,經過一個自定義的lambda層來實現。這個抽樣是先從標準正態分佈中抽樣,再經過乘生成分佈的標準差,加上均值來得到。所以這個操做並不會把反向傳播中斷,能夠將編碼器與解碼器的張量流鏈接起來。ui
定義好模型後是損失的定義,如前面所說,最終損失(目標函數)是生成圖像與原圖像之間的二元交叉熵和生成分佈的正則化的平均值。使用add_loss方法來添加模型的損失,這是新出的方法,比寫wrap函數、重寫layer類(《python深度學習》書)來定義方便多了。編碼
代碼以下:spa
#%%總體待訓練模型 def sampling(arg): mean = arg[0] logvar = arg[1] epsilon = K.random_normal(shape=K.shape(mean),mean=0.,stddev=1.) #從標準正態分佈中抽樣 return mean + K.exp(0.5*logvar) * epsilon #獲取生成分佈的抽樣 input_img = layers.Input(shape=img_shape,name = 'img_input') code_mean, code_log_var = encoder(input_img) #獲取生成分佈的均值與方差 x = layers.Lambda(sampling,name = 'sampling')([code_mean, code_log_var]) x = decoder(x) training_model = Model(input_img,x,name = 'training_model') decode_loss = keras.metrics.binary_crossentropy(K.flatten(input_img), K.flatten(x)) kl_loss = -5e-4*K.mean(1+code_log_var-K.square(code_mean)-K.exp(code_log_var)) training_model.add_loss(K.mean(decode_loss+kl_loss)) #新出的方法,方便得很 training_model.compile(optimizer='rmsprop')
由於損失函數並無定義真實數據與預測數據直接的損失,所以fit方法只需傳入輸入便可(不用輸出)。代碼以下:
#%%讀取數據集訓練 (x_train,y_train),(x_test,y_test) = mnist.load_data() x_train = x_train.astype('float32')/255 x_train = x_train[:,:,:,np.newaxis] training_model.fit( x_train, batch_size=512, epochs=100, validation_data=(x_train[:2],None))
使用scipy.stats中的norm.ppf方法在(0.01,0.99)生成20*20個解碼器輸入,這個方法相似在標準正態分佈中抽樣,但並非隨機的,是正態分佈下的等機率。生成的二維點分佈以下圖:
這樣抽樣而不均勻抽樣爲了和編碼器的生成分佈契合,由於編碼器正則化後生成的分佈是靠近標準正態分佈的。而後用解碼器生成圖片,這一部分的代碼以下:
#%%測試 from scipy.stats import norm import numpy as np import matplotlib.pyplot as plt n = 20 x = y = norm.ppf(np.linspace(0.01,0.99,n)) #生成標準正態分佈數 X,Y = np.meshgrid(x,y) #造成網格 X = X.reshape([-1,1]) #數組展平 Y = Y.reshape([-1,1]) input_points = np.concatenate([X,Y],axis=-1)#鏈接爲輸入 for i in input_points: plt.scatter(i[0],i[1]) plt.show() img_size = 28 predict_img = decoder.predict(input_points) pic = np.empty([img_size*n,img_size*n,1]) for i in range(n): for j in range(n): pic[img_size*i:img_size*(i+1), img_size*j:img_size*(j+1)] = predict_img[i*n+j] plt.figure(figsize=(10,10)) plt.axis('off') pic = np.squeeze(pic) plt.imshow(pic,cmap='bone') plt.show()
生成的400張圖: