[pytorch] 自定義激活函數中的注意事項

如何在pytorch中使用自定義的激活函數?html

若是自定義的激活函數是可導的,那麼能夠直接寫一個python function來定義並調用,由於pytorch的autograd會自動對其求導。python

若是自定義的激活函數不是可導的,好比相似於ReLU的分段可導的函數,須要寫一個繼承torch.autograd.Function的類,並自行定義forward和backward的過程app

在pytorch中提供了定義新的autograd function的tutorial: https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html, tutorial以ReLU爲例介紹了在forward, backward中須要自行定義的內容。dom

 1 import torch
 2 
 3 
 4 class MyReLU(torch.autograd.Function):
 5     """
 6     We can implement our own custom autograd Functions by subclassing
 7     torch.autograd.Function and implementing the forward and backward passes
 8     which operate on Tensors.
 9     """
10 
11     @staticmethod
12     def forward(ctx, input):
13         """
14         In the forward pass we receive a Tensor containing the input and return
15         a Tensor containing the output. ctx is a context object that can be used
16         to stash information for backward computation. You can cache arbitrary
17         objects for use in the backward pass using the ctx.save_for_backward method.
18         """
19         ctx.save_for_backward(input)
20         return input.clamp(min=0)
21 
22     @staticmethod
23     def backward(ctx, grad_output):
24         """
25         In the backward pass we receive a Tensor containing the gradient of the loss
26         with respect to the output, and we need to compute the gradient of the loss
27         with respect to the input.
28         """
29         input, = ctx.saved_tensors
30         grad_input = grad_output.clone()
31         grad_input[input < 0] = 0
32         return grad_input
33 
34 
35 dtype = torch.float
36 device = torch.device("cpu")
37 # device = torch.device("cuda:0") # Uncomment this to run on GPU
38 
39 # N is batch size; D_in is input dimension;
40 # H is hidden dimension; D_out is output dimension.
41 N, D_in, H, D_out = 64, 1000, 100, 10
42 
43 # Create random Tensors to hold input and outputs.
44 x = torch.randn(N, D_in, device=device, dtype=dtype)
45 y = torch.randn(N, D_out, device=device, dtype=dtype)
46 
47 # Create random Tensors for weights.
48 w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True)
49 w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True)
50 
51 learning_rate = 1e-6
52 for t in range(500):
53     # To apply our Function, we use Function.apply method. We alias this as 'relu'.
54     relu = MyReLU.apply
55 
56     # Forward pass: compute predicted y using operations; we compute
57     # ReLU using our custom autograd operation.
58     y_pred = relu(x.mm(w1)).mm(w2)
59 
60     # Compute and print loss
61     loss = (y_pred - y).pow(2).sum()
62     print(t, loss.item())
63 
64     # Use autograd to compute the backward pass.
65     loss.backward()
66 
67     # Update weights using gradient descent
68     with torch.no_grad():
69         w1 -= learning_rate * w1.grad
70         w2 -= learning_rate * w2.grad
71 
72         # Manually zero the gradients after updating weights
73         w1.grad.zero_()
74         w2.grad.zero_()

可是若是定義ReLU函數時,沒有使用以上正確的方法,而是直接自定義的函數,會出現什麼問題呢?ide

這裏對比了使用以上MyReLU和自定義函數:no_back的實驗結果。函數

1 def no_back(x):
2     return x * (x > 0).float()

代碼:ui

N, D_in, H, D_out = 2, 3, 4, 5

# Create random Tensors to hold input and outputs.
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# Create random Tensors for weights.
origin_w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True)
origin_w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True)

learning_rate = 1e-3

def myReLU(func, x, y, origin_w1, origin_w2, learning_rate,N = 2, D_in = 3, H = 4, D_out = 5):
    w1 = deepcopy(origin_w1)
    w2 = deepcopy(origin_w2)
    for t in range(5):
        # Forward pass: compute predicted y using operations; we compute
        # ReLU using our custom autograd operation.
        y_pred = func(x.mm(w1)).mm(w2)

        # Compute and print loss
        loss = (y_pred - y).pow(2).sum()
        print("------", t, loss.item(), "------------")

        # Use autograd to compute the backward pass.
        loss.backward()

        # Update weights using gradient descent
        with torch.no_grad():
            print('w1 = ')
            print(w1)
            print('---------------------')
            print("x.mm(w1) = ")
            print(x.mm(w1))
            print('---------------------')
            print('func(x.mm(w1))')
            print(func(x.mm(w1)))
            print('---------------------')
            print("w1.grad:", w1.grad)
            # print("w2.grad:",w2.grad)
            print('---------------------')

            w1 -= learning_rate * w1.grad
            w2 -= learning_rate * w2.grad

            # Manually zero the gradients after updating weights
            w1.grad.zero_()
            w2.grad.zero_()
            print('========================')
            print()


myReLU(func = MyReLU.apply, x = x, y = y, origin_w1 = origin_w1, origin_w2 = origin_w2, learning_rate = learning_rate, N = 2, D_in = 3, H = 4, D_out = 5)
print('============')
print('============')
print('============')
myReLU(func = no_back, x = x, y = y, origin_w1 = origin_w1, origin_w2 = origin_w2, learning_rate = learning_rate, N = 2, D_in = 3, H = 4, D_out = 5)

對於使用了MyReLU.apply的實驗結果爲:this

 1 ------ 0 20.18220329284668 ------------
 2 w1 = 
 3 tensor([[ 0.7070,  2.5772,  0.7987,  2.2287],
 4         [ 0.7425, -0.6309,  0.3268, -1.5072],
 5         [ 0.6930, -2.6128,  0.1949,  0.8819]], requires_grad=True)
 6 ---------------------
 7 x.mm(w1) = 
 8 tensor([[-0.9788,  1.0135, -0.4164,  1.8834],
 9         [-0.7692, -1.8556, -0.7085, -0.9849]])
10 ---------------------
11 func(x.mm(w1))
12 tensor([[0.0000, 1.0135, 0.0000, 1.8834],
13         [0.0000, 0.0000, 0.0000, 0.0000]])
14 ---------------------
15 w1.grad: tensor([[  0.0000,   0.0499,   0.0000,   0.1881],
16         [  0.0000,  -4.4962,   0.0000, -16.9378],
17         [  0.0000,  -0.2401,   0.0000,  -0.9043]])
18 ---------------------
19 ========================
20 
21 ------ 1 19.546737670898438 ------------
22 w1 = 
23 tensor([[ 0.7070,  2.5772,  0.7987,  2.2285],
24         [ 0.7425, -0.6265,  0.3268, -1.4903],
25         [ 0.6930, -2.6126,  0.1949,  0.8828]], requires_grad=True)
26 ---------------------
27 x.mm(w1) = 
28 tensor([[-0.9788,  1.0078, -0.4164,  1.8618],
29         [-0.7692, -1.8574, -0.7085, -0.9915]])
30 ---------------------
31 func(x.mm(w1))
32 tensor([[0.0000, 1.0078, 0.0000, 1.8618],
33         [0.0000, 0.0000, 0.0000, 0.0000]])
34 ---------------------
35 w1.grad: tensor([[  0.0000,   0.0483,   0.0000,   0.1827],
36         [  0.0000,  -4.3446,   0.0000, -16.4493],
37         [  0.0000,  -0.2320,   0.0000,  -0.8782]])
38 ---------------------
39 ========================
40 
41 ------ 2 18.94647789001465 ------------
42 w1 = 
43 tensor([[ 0.7070,  2.5771,  0.7987,  2.2283],
44         [ 0.7425, -0.6221,  0.3268, -1.4738],
45         [ 0.6930, -2.6123,  0.1949,  0.8837]], requires_grad=True)
46 ---------------------
47 x.mm(w1) = 
48 tensor([[-0.9788,  1.0023, -0.4164,  1.8409],
49         [-0.7692, -1.8591, -0.7085, -0.9978]])
50 ---------------------
51 func(x.mm(w1))
52 tensor([[0.0000, 1.0023, 0.0000, 1.8409],
53         [0.0000, 0.0000, 0.0000, 0.0000]])
54 ---------------------
55 w1.grad: tensor([[  0.0000,   0.0467,   0.0000,   0.1775],
56         [  0.0000,  -4.2009,   0.0000, -15.9835],
57         [  0.0000,  -0.2243,   0.0000,  -0.8534]])
58 ---------------------
59 ========================
60 
61 ------ 3 18.378826141357422 ------------
62 w1 = 
63 tensor([[ 0.7070,  2.5771,  0.7987,  2.2281],
64         [ 0.7425, -0.6179,  0.3268, -1.4578],
65         [ 0.6930, -2.6121,  0.1949,  0.8846]], requires_grad=True)
66 ---------------------
67 x.mm(w1) = 
68 tensor([[-0.9788,  0.9969, -0.4164,  1.8206],
69         [-0.7692, -1.8607, -0.7085, -1.0040]])
70 ---------------------
71 func(x.mm(w1))
72 tensor([[0.0000, 0.9969, 0.0000, 1.8206],
73         [0.0000, 0.0000, 0.0000, 0.0000]])
74 ---------------------
75 w1.grad: tensor([[  0.0000,   0.0451,   0.0000,   0.1726],
76         [  0.0000,  -4.0644,   0.0000, -15.5391],
77         [  0.0000,  -0.2170,   0.0000,  -0.8296]])
78 ---------------------
79 ========================
80 
81 ------ 4 17.841421127319336 ------------
82 w1 = 
83 tensor([[ 0.7070,  2.5770,  0.7987,  2.2280],
84         [ 0.7425, -0.6138,  0.3268, -1.4423],
85         [ 0.6930, -2.6119,  0.1949,  0.8854]], requires_grad=True)
86 ---------------------
87 x.mm(w1) = 
88 tensor([[-0.9788,  0.9918, -0.4164,  1.8008],
89         [-0.7692, -1.8623, -0.7085, -1.0100]])
90 ---------------------
91 func(x.mm(w1))
92 tensor([[0.0000, 0.9918, 0.0000, 1.8008],
93         [0.0000, 0.0000, 0.0000, 0.0000]])
94 ---------------------
95 w1.grad: tensor([[  0.0000,   0.0437,   0.0000,   0.1679],
96         [  0.0000,  -3.9346,   0.0000, -15.1145],
97         [  0.0000,  -0.2101,   0.0000,  -0.8070]])
98 ---------------------
99 ========================
View Code

對於使用了no_back的實驗結果爲:編碼

 1 ------ 0 20.18220329284668 ------------
 2 w1 = 
 3 tensor([[ 0.7070,  2.5772,  0.7987,  2.2287],
 4         [ 0.7425, -0.6309,  0.3268, -1.5072],
 5         [ 0.6930, -2.6128,  0.1949,  0.8819]], requires_grad=True)
 6 ---------------------
 7 x.mm(w1) = 
 8 tensor([[-0.9788,  1.0135, -0.4164,  1.8834],
 9         [-0.7692, -1.8556, -0.7085, -0.9849]])
10 ---------------------
11 func(x.mm(w1))
12 tensor([[-0.0000, 1.0135, -0.0000, 1.8834],
13         [-0.0000, -0.0000, -0.0000, -0.0000]])
14 ---------------------
15 w1.grad: tensor([[  0.0000,   0.0499,   0.0000,   0.1881],
16         [  0.0000,  -4.4962,   0.0000, -16.9378],
17         [  0.0000,  -0.2401,   0.0000,  -0.9043]])
18 ---------------------
19 ========================
20 
21 ------ 1 19.546737670898438 ------------
22 w1 = 
23 tensor([[ 0.7070,  2.5772,  0.7987,  2.2285],
24         [ 0.7425, -0.6265,  0.3268, -1.4903],
25         [ 0.6930, -2.6126,  0.1949,  0.8828]], requires_grad=True)
26 ---------------------
27 x.mm(w1) = 
28 tensor([[-0.9788,  1.0078, -0.4164,  1.8618],
29         [-0.7692, -1.8574, -0.7085, -0.9915]])
30 ---------------------
31 func(x.mm(w1))
32 tensor([[-0.0000, 1.0078, -0.0000, 1.8618],
33         [-0.0000, -0.0000, -0.0000, -0.0000]])
34 ---------------------
35 w1.grad: tensor([[  0.0000,   0.0483,   0.0000,   0.1827],
36         [  0.0000,  -4.3446,   0.0000, -16.4493],
37         [  0.0000,  -0.2320,   0.0000,  -0.8782]])
38 ---------------------
39 ========================
40 
41 ------ 2 18.94647789001465 ------------
42 w1 = 
43 tensor([[ 0.7070,  2.5771,  0.7987,  2.2283],
44         [ 0.7425, -0.6221,  0.3268, -1.4738],
45         [ 0.6930, -2.6123,  0.1949,  0.8837]], requires_grad=True)
46 ---------------------
47 x.mm(w1) = 
48 tensor([[-0.9788,  1.0023, -0.4164,  1.8409],
49         [-0.7692, -1.8591, -0.7085, -0.9978]])
50 ---------------------
51 func(x.mm(w1))
52 tensor([[-0.0000, 1.0023, -0.0000, 1.8409],
53         [-0.0000, -0.0000, -0.0000, -0.0000]])
54 ---------------------
55 w1.grad: tensor([[  0.0000,   0.0467,   0.0000,   0.1775],
56         [  0.0000,  -4.2009,   0.0000, -15.9835],
57         [  0.0000,  -0.2243,   0.0000,  -0.8534]])
58 ---------------------
59 ========================
60 
61 ------ 3 18.378826141357422 ------------
62 w1 = 
63 tensor([[ 0.7070,  2.5771,  0.7987,  2.2281],
64         [ 0.7425, -0.6179,  0.3268, -1.4578],
65         [ 0.6930, -2.6121,  0.1949,  0.8846]], requires_grad=True)
66 ---------------------
67 x.mm(w1) = 
68 tensor([[-0.9788,  0.9969, -0.4164,  1.8206],
69         [-0.7692, -1.8607, -0.7085, -1.0040]])
70 ---------------------
71 func(x.mm(w1))
72 tensor([[-0.0000, 0.9969, -0.0000, 1.8206],
73         [-0.0000, -0.0000, -0.0000, -0.0000]])
74 ---------------------
75 w1.grad: tensor([[  0.0000,   0.0451,   0.0000,   0.1726],
76         [  0.0000,  -4.0644,   0.0000, -15.5391],
77         [  0.0000,  -0.2170,   0.0000,  -0.8296]])
78 ---------------------
79 ========================
80 
81 ------ 4 17.841421127319336 ------------
82 w1 = 
83 tensor([[ 0.7070,  2.5770,  0.7987,  2.2280],
84         [ 0.7425, -0.6138,  0.3268, -1.4423],
85         [ 0.6930, -2.6119,  0.1949,  0.8854]], requires_grad=True)
86 ---------------------
87 x.mm(w1) = 
88 tensor([[-0.9788,  0.9918, -0.4164,  1.8008],
89         [-0.7692, -1.8623, -0.7085, -1.0100]])
90 ---------------------
91 func(x.mm(w1))
92 tensor([[-0.0000, 0.9918, -0.0000, 1.8008],
93         [-0.0000, -0.0000, -0.0000, -0.0000]])
94 ---------------------
95 w1.grad: tensor([[  0.0000,   0.0437,   0.0000,   0.1679],
96         [  0.0000,  -3.9346,   0.0000, -15.1145],
97         [  0.0000,  -0.2101,   0.0000,  -0.8070]])
98 ---------------------
99 ========================
View Code

對比發現,兩者在梯度大小及更新的數值、loss大小等都是數值相等的,這是否說明對於不可導函數,直接定義函數也能夠取得和正肯定義前向後向過程相同的結果呢?spa

應當注意到一個問題,那就是在MyReLU.apply的實驗結果中,出現數值爲0的地方,顯示爲0.0000,而在no_back的實驗結果中,出現數值爲0的地方,顯示爲-0.0000;

0.0000與-0.0000有什麼區別呢?

參考stack overflow中的解答:https://stackoverflow.com/questions/4083401/negative-zero-in-python

和wikipedia中對於signed zero的介紹:https://en.wikipedia.org/wiki/Signed_zero

在python中兩者是顯然不一樣的對象,可是在數值比較時,兩者的值顯示爲相等。

-0.0 == +0.0 == 0

在Python 中使它們數值相等的設定,是在儘可能避免爲code引入bug.

>>> a = 3.4
>>> b =4.4
>>> c = -0.0
>>> d = +0.0
>>> a*c
-0.0
>>> b*d
0.0
>>> a*c == b*d
True
>>> 

雖然看起來,它們在使用中並無什麼區別,可是在計算機內部對它們的編碼表示並不相同。

在對於整數的1+7位元的符號數值表示法中,負零是用二進制代碼10000000表示的。在8位元二進制反碼中,負零是用二進制代碼11111111表示,但補碼表示法則沒有負零的概念。在IEEE 754二進制浮點數算術標準中,指數和尾數爲零、符號位元爲一的數就是負零。

IBM的普通十進制算數編碼規範中,運用十進制來表示浮點數。這裏負零被表示爲指數爲編碼內任意合法數值、全部係數均爲零、符號位元爲一的數。

 ~(wikipedia)

在數值分析中,也常將-0看作從負數區間無限趨近於0的值,將+0看作從正數區間無限趨近於0的值,兩者在數值上近似相等,但在某些操做中卻可能產生不一樣的結果。

好比 divmod,會沿用數值的sign:

>>> divmod(-0.0,100)
(-0.0, 0.0)
>>> divmod(+0.0,100)
(0.0, 0.0)

好比 atan2, (介紹詳見https://en.wikipedia.org/wiki/Atan2)

 

atan2(+0, +0) = +0;  

atan2(+0, −0) = +π;  ( 當y是位於y軸正半軸,無限趨近於0的值;x是位於x軸負半軸,無限趨近於0的值,=> 能夠看作是在第二象限中位於x軸負半軸的一點 => $\theta夾角爲$\pi$)

atan2(−0, +0) = −0;  ( 能夠看作是在第四象限中位於x軸正半軸的一點 => $\theta夾角爲-0)

atan2(−0, −0) = −π.

用代碼驗證:

>>> math.atan2(0.0, 0.0) == math.atan2(-0.0, 0.0)
True 
>>> math.atan2(0.0, -0.0) == math.atan2(-0.0, -0.0)
False

因此,儘管在上面自定義激活函數時,將不可導函數強行加入到pytorch的autograd中運算,數值結果相同;可是注意到-0.0000的出現是程序有bug的提示,嚴謹考慮仍須要規範定義,如MyReLU。

相關文章
相關標籤/搜索