我最近在學使用Pytorch寫GAN代碼,發現有些代碼在訓練部分細節有略微不一樣,其中有的人用到了detach()函數截斷梯度流,有的人沒用detch(),取而代之的是在損失函數在反向傳播過程當中將backward(retain_graph=True),本文經過兩個 gan 的代碼,介紹它們的做用,並分析,不一樣的更新策略對程序效率的影響。html
這兩個 GAN 的實現中,有兩種不一樣的訓練策略:node
爲了減小網絡垃圾,GAN的原理網上一大堆,我這裏就不重複贅述了,想要詳細瞭解GAN原理的朋友,能夠參考我專題文章:神經網絡結構:生成式對抗網絡(GAN)。算法
須要瞭解的知識:網絡
detach():截斷node反向傳播的梯度流,將某個node變成不須要梯度的Varibale,所以當反向傳播通過這個node時,梯度就不會從這個node往前面傳播。dom
咱們直接下面進入本文正題,即,在 pytorch 中,detach 和 retain_graph 是幹什麼用的?本文將藉助三段 GAN 的實現代碼,來舉例介紹它們的做用。函數
咱們分析循環中一個 step 的代碼:post
valid = torch.Tensor(imgs.size(0), 1).fill_(1.0).to(device) # 真實標籤,都是1 fake = torch.Tensor(imgs.size(0), 1).fill_(0.0).to(device) # 假標籤,都是0 # ######################## # 訓練判別器 # # ######################## real_imgs = imgs.to(device) # 真實圖片 z = torch.randn((imgs.shape[0], 100)).to(device) # 噪聲 gen_imgs = generator(z) # 從噪聲中生成假數據 pred_gen = discriminator(gen_imgs) # 判別器對假數據的輸出 pred_real = discriminator(real_imgs) # 判別器對真數據的輸出 optimizer_D.zero_grad() # 把判別器中全部參數的梯度歸零 real_loss = adversarial_loss(pred_real, valid) # 判別器對真實樣本的損失 fake_loss = adversarial_loss(pred_gen, fake) # 判別器對假樣本的損失 d_loss = (real_loss + fake_loss) / 2 # 兩項損失相加取平均 # 下面這行代碼十分重要,將在正文着重講解 d_loss.backward(retain_graph=True) # retain_graph=True 十分重要,不然計算圖內存將會被釋放 optimizer_D.step() # 判別器參數更新 # ######################## # 訓練生成器 # # ######################## g_loss = adversarial_loss(pred_gen, valid) # 生成器的損失函數 optimizer_G.zero_grad() # 生成器參數梯度歸零 g_loss.backward() # 生成器的損失函數梯度反向傳播 optimizer_G.step() # 生成器參數更新
代碼講解ui
鑑別器的損失函數d_loss是由real_loss和fake_loss組成的,而fake_loss又是noise通過generator來的。這樣一來咱們對d_loss進行反向傳播,不只會計算discriminator 的梯度還會計算generator 的梯度(雖然這一步optimizer_D.step()只更新 discriminator 的參數),所以下面在更新generator參數時,要先將generator參數的梯度清零,避免受到discriminator loss 回傳過來的梯度影響。url
generator 的 損失在回傳時,一樣要通過 discriminator 網絡才能傳遞迴自身(系統從輸入噪聲到 Discriminator 輸出,從頭至尾只有一次前向傳播,而有兩次反向傳播,故在第一次反向傳播時,鑑別器要設置 backward(retain graph=True),保持計算圖不被釋放。由於 pytorch 默認一個計算圖只計算一次反向傳播,反向傳播後,這個計算圖的內存就會被釋放,因此用這個參數控制計算圖不被釋放。所以,在回傳梯度時,一樣也計算了一遍 discriminator 的參數梯度,只不過此次 discriminator 的參數不更新,只更新 generator 的參數,即 optimizer_G.step()。同時,咱們看到,下一個 step 首先將 discriminator 的梯度重置爲 0,就是爲了防止 generator loss 反向傳播時順帶計算的梯度對其形成影響(還有上一步 discriminator loss 回傳時累積的梯度)。spa
綜上,咱們看到,爲了完成一步參數更新,咱們進行了兩次反向傳播,第一次反向傳播爲了更新 discriminator 的參數,但多餘計算了 generator 的梯度。第二次反向傳播爲了更新 generator 的參數,可是計算了 discriminator 的梯度,所以在寫一個step,須要當即清零discriminator梯度。
若是你實在看不懂,就照着這個形式寫代碼就好了,反正形式都幫大家寫好了。
這種策略我遇到的比較多,也是先訓練鑑別器,再訓練生成器
鑑別器訓練階段,noise 從 generator 輸入,輸出 fake data,而後 detach 一下,隨着 true data 一塊兒輸入 discriminator,計算 discriminator 損失,並更新 discriminator 參數。生成器訓練階段,把沒通過 detach 的 fake data 輸入到discriminator 中,計算 generator loss,再反向傳播梯度,更新 generator 的參數。這種策略,計算了兩次 discriminator 梯度,一次 generator 梯度。感受這種比較符合先更新 discriminator 的習慣。缺點是,以前的 generator 生成的計算圖得保留着,直到 discriminator 更新完,再釋放。
valid = torch.Tensor(imgs.size(0), 1).fill_(1.0).to(device) # 真實標籤,都是1 fake = torch.Tensor(imgs.size(0), 1).fill_(0.0).to(device) # 假標籤,都是0 # ######################## # 訓練判別器 # # ######################## real_imgs = imgs.to(device) # 真實圖片 z = torch.randn((imgs.shape[0], 100)).to(device) # 噪聲 gen_imgs = generator(z) # 從噪聲中生成假數據 pred_gen = discriminator(gen_imgs.detach()) # 假數據detach(),判別器對假數據的輸出 pred_real = discriminator(real_imgs) # 判別器對真數據的輸出 optimizer_D.zero_grad() # 把判別器中全部參數的梯度歸零 real_loss = adversarial_loss(pred_real, valid) # 判別器對真實樣本的損失 fake_loss = adversarial_loss(pred_gen, fake) # 判別器對假樣本的損失 d_loss = (real_loss + fake_loss) / 2 # 兩項損失相加取平均 # 下面這行代碼十分重要,將在正文着重講解 d_loss.backward() # retain_graph=True 十分重要,不然計算圖內存將會被釋放 optimizer_D.step() # 判別器參數更新 # ######################## # 訓練生成器 # # ######################## g_loss = adversarial_loss(pred_gen, valid) # 生成器的損失函數 optimizer_G.zero_grad() # 生成器參數梯度歸零 g_loss.backward() # 生成器的損失函數梯度反向傳播 optimizer_G.step() # 生成器參數更新
咱們分析循環中一個 step 的代碼:
valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False) # 真實樣本的標籤,都是 1 fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False) # 生成樣本的標籤,都是 0 z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim)))) # 噪聲 real_imgs = Variable(imgs.type(Tensor)) # 真實圖片 # ######################## # 訓練生成器 # # ######################## optimizer_G.zero_grad() # 生成器參數梯度歸零 gen_imgs = generator(z) # 根據噪聲生成虛假樣本 g_loss = adversarial_loss(discriminator(gen_imgs), valid) # 用真實的標籤+假樣本,計算生成器損失 g_loss.backward() # 生成器梯度反向傳播,反向傳播通過了判別器,故此時判別器參數也有梯度 optimizer_G.step() # 生成器參數更新,判別器參數雖然有梯度,可是這一步不能更新判別器 # ######################## # 訓練判別器 # # ######################## optimizer_D.zero_grad() # 把生成器損失函數梯度反向傳播時,順帶計算的判別器參數梯度清空 real_loss = adversarial_loss(discriminator(real_imgs), valid) # 真樣本+真標籤:判別器損失 fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake) # 假樣本+假標籤:判別器損失 d_loss = (real_loss + fake_loss) / 2 # 判別器總的損失函數 d_loss.backward() # 判別器損失回傳 optimizer_D.step() # 判別器參數更新
爲了更新生成器參數,用生成器的損失函數計算梯度,而後反向傳播,傳播圖中通過了判別器,根據鏈式法則,不得不順帶計算一下判別器的參數梯度,雖然在這一步不會更新判別器參數。反向傳播事後,noise 到 fake image 再到 discriminator 的輸出這個前向傳播的計算圖就被釋放掉了,後面也不會再用到。
接着更新判別器參數,此時注意到,咱們輸入判別器的是兩部分,一部分是真實數據,另外一部分是生成器的輸出,也就是假數據。注意觀察細節,在判別器前向傳播過程,輸入的假數據被 detach 了,detach 的意思是,這個數據和生成它的計算圖「脫鉤」了,即梯度傳到它那個地方就停了,再也不繼續往前傳播(實際上也不會再往前傳播了,由於 generator 的計算圖在第一次反向傳播事後就被釋放了)。所以,判別器梯度反向傳播,就到它本身身上爲止。
所以,比起第一種策略,這種策略要少計算一次 generator 的全部參數的梯度,同時,也沒必要刻意保存一次計算圖,佔用沒必要要的內存。
但須要注意的是,在第一種策略中,noise 從 generator 輸入,到 discriminator 輸出,只經歷了一次前向傳播,discriminator 端的輸出,被用了兩次,一次是計算 discriminator 的損失函數,另外一次是計算 generator 的損失函數。
而在第這種策略中,noise 從 generator 輸入,到discriminator 輸出,計算 generator 損失,回傳,這一步更新了 generator 的參數,並釋放了計算圖。下一步更新 discriminator 的參數時,generator 的輸出通過 detach 後,又經過了一遍 discriminator,至關於,generator 的輸出先後兩次經過了 discriminator ,獲得相同的輸出。顯然,這也是冗餘的。
綜上,這兩段代碼各有利弊:
第一段代碼,好處是 noise 只進行了一次前向傳播,缺點是,更新 discriminator 參數時,多計算了一次 generator 的梯度,同時,第一次更新 discriminator 須要保留計算圖,保證算 generator loss 時計算圖不被銷燬。
第三段代碼,好處是經過先更新 generator ,使更新後的前向傳播計算圖能夠放心被銷燬,所以不用保留計算圖佔用內存。同時,在更新 discriminator 的時候,也不會像上面的那段代碼,計算冗餘的 generator 的梯度。缺點是,在 discriminator 上,對 generator 的輸出算了兩次前向傳播,第二次又產生了新的計算圖(但比第一次的小)。
一個多計算了一次 generator 梯度,一個多計算一次 discriminator 前向傳播。所以,二者差異不大。若是 discriminator 比generator 複雜,那麼應該採起第一種策略,若是 discriminator 比 generator 簡單,那麼應該採起第三種策略,一般狀況下,discriminator 要比 generator 簡單,故若是效果差很少儘可能採起第三種策略。
可是第三種先更新generator,再更新 discriminator 老是給人感受怪怪得,由於 generator 的更新須要 discriminator 提供準確的 loss 和 gradient,不然豈不是在瞎更新?
可是策略三,立刻用完立刻釋放。綜合來講,仍是策略三最好,策略二其次,策略一最差(差在多計算一次 generator gradient 上,而一般多計算一次 generator gradient 的運算量比多計算一次 discriminator 前向傳播的運算量大),所以,detach 仍是頗有必要的。