pytorch hook使用

因爲pytorch會自動捨棄圖計算的中間結果,因此想要獲取這些數值就須要使用鉤子函數。html

鉤子函數包括Variable的鉤子和nn.Module鉤子,用法類似。python

import torch from torch.autograd import Variable grad_list = [] grad_listx = [] def print_grad(grad): grad_list.append(grad) def print_gradx(grad): grad_listx.append(grad) x = Variable(torch.randn(2, 1), requires_grad=True) y = x*x + 2 z = torch.mean(torch.pow(y, 2)) lr = 1e-3 y.register_hook(print_grad) x.register_hook(print_gradx) z.backward() x.data -= lr * x.grad.data print("x.grad.data-------------") print(x.grad.data) print("y-------------") print(grad_list) print("x-------------") print(grad_listx)

- 輸出: 記錄了y的梯度,而後x.data=記錄x的梯度app

/opt/conda/bin/python2.7 /root/rjw/pytorch_test/pytorch_exe03.py x.grad.data-------------

 32.3585
 14.8162 [torch.FloatTensor of size 2x1] y------------- [Variable containing: 7.1379
 4.5970 [torch.FloatTensor of size 2x1] ] x------------- [Variable containing: 32.3585
 14.8162 [torch.FloatTensor of size 2x1] ] Process finished with exit code 0

register_forward_hook & register_backward_hook

  • 這兩個函數的功能相似於variable函數的register_hook,可在module前向傳播或反向傳播時註冊鉤子。每次前向傳播執行結束後會執行鉤子函數(hook)。前向傳播的鉤子函數具備以下形式:hook(module, input, output) -> None,而反向傳播則具備以下形式:hook(module, grad_input, grad_output) -> Tensor or None
  • 鉤子函數不該修改輸入和輸出,而且在使用後應及時刪除,以免每次都運行鉤子增長運行負載。鉤子函數主要用在獲取某些中間結果的情景,如中間某一層的輸出或某一層的梯度。這些結果本應寫在forward函數中,但若是在forward函數中專門加上這些處理,可能會使處理邏輯比較複雜,這時候使用鉤子技術就更合適一些。下面考慮一種場景,有一個預訓練好的模型,須要提取模型的某一層(不是最後一層)的輸出做爲特徵進行分類,但又不但願修改其原有的模型定義文件,這時就能夠利用鉤子函數。
  • PyTorch』第十六彈_hook技術
相關文章
相關標籤/搜索