Because PyTorch automatically discards the intermediate results of the computation graph, hook functions are needed to capture these values.
Hooks come in two flavors, Variable hooks and nn.Module hooks, and their usage is similar.
```python
import torch
from torch.autograd import Variable

grad_list = []   # gradients recorded for y
grad_listx = []  # gradients recorded for x

def print_grad(grad):
    grad_list.append(grad)

def print_gradx(grad):
    grad_listx.append(grad)

x = Variable(torch.randn(2, 1), requires_grad=True)
y = x * x + 2
z = torch.mean(torch.pow(y, 2))
lr = 1e-3

# Register the hooks before calling backward().
y.register_hook(print_grad)
x.register_hook(print_gradx)
z.backward()

# Manual gradient-descent step using the recorded gradient.
x.data -= lr * x.grad.data

print("x.grad.data-------------")
print(x.grad.data)
print("y-------------")
print(grad_list)
print("x-------------")
print(grad_listx)
```
- Output: the hooks record the gradient of y and the gradient of x, and x.data is then updated in place using x.grad.
```
/opt/conda/bin/python2.7 /root/rjw/pytorch_test/pytorch_exe03.py
x.grad.data-------------
 32.3585
 14.8162
[torch.FloatTensor of size 2x1]

y-------------
[Variable containing:
  7.1379
  4.5970
 [torch.FloatTensor of size 2x1]
]
x-------------
[Variable containing:
  32.3585
  14.8162
 [torch.FloatTensor of size 2x1]
]

Process finished with exit code 0
```
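As a cross-check, the recorded values follow from the chain rule: with z = mean(y²) and y = x² + 2, we get dz/dy = y (the mean over two elements cancels the power-rule factor 2) and dz/dx = dz/dy · 2x = 2x(x² + 2). A minimal sketch using the current `torch.Tensor` API (Variable and Tensor were merged in PyTorch 0.4), with a fixed input so the numbers are reproducible; the `save_grad` helper is illustrative, not part of PyTorch:

```python
import torch

grads = {}

def save_grad(name):
    # Returns a hook closure that stores the incoming gradient under `name`.
    def hook(grad):
        grads[name] = grad
    return hook

x = torch.tensor([[1.0], [2.0]], requires_grad=True)
y = x * x + 2
z = torch.mean(torch.pow(y, 2))

# Hooks must be registered before backward() runs.
hy = y.register_hook(save_grad('y'))
hx = x.register_hook(save_grad('x'))
z.backward()

# dz/dy = y, dz/dx = y * 2x = 2x(x^2 + 2)
print(grads['y'])  # tensor([[3.], [6.]])
print(grads['x'])  # tensor([[ 6.], [24.]])

# Each register_hook call returns a handle; remove() detaches the hook.
hy.remove()
hx.remove()
```

The handle returned by `register_hook` is worth keeping: leaving hooks attached across many iterations keeps appending (or here, overwriting) recorded tensors and can leak memory.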
`register_forward_hook` & `register_backward_hook`

Similar to `register_hook`, these register hooks on a module for the forward or backward pass. The hook function runs each time the corresponding pass finishes. A forward hook has the signature `hook(module, input, output) -> None`, while a backward hook has the signature `hook(module, grad_input, grad_output) -> Tensor or None`.
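A minimal sketch of a module-level forward hook; the `nn.Linear` layer, the shapes, and the `captured` dict are illustrative choices, not anything prescribed by the API:

```python
import torch
import torch.nn as nn

captured = {}

def forward_hook(module, input, output):
    # Called after every forward pass of the module.
    # `input` is a tuple of the positional inputs; `output` is the result.
    captured['input'] = input[0].detach()
    captured['output'] = output.detach()

layer = nn.Linear(4, 2)
handle = layer.register_forward_hook(forward_hook)

x = torch.randn(3, 4)
out = layer(x)

print(captured['output'].shape)  # torch.Size([3, 2])

handle.remove()  # detach the hook once it is no longer needed
```

For the backward direction, note that recent PyTorch versions deprecate `register_backward_hook` in favor of `register_full_backward_hook`, which delivers well-defined `grad_input`/`grad_output` tuples for modules with multiple inputs.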