搭建網絡的步驟大體爲如下:python
1.準備數據算法
2. 定義網絡結構model網絡
3. 定義損失函數
4. 定義優化算法 optimizer
5. 訓練
5.1 準備好tensor形式的輸入數據和標籤(可選)
5.2 前向傳播計算網絡輸出output和計算損失函數loss
5.3 反向傳播更新參數
如下三句話一句也不能少:
5.3.1 optimizer.zero_grad() 將上次迭代計算的梯度值清0
5.3.2 loss.backward() 反向傳播,計算梯度值
5.3.3 optimizer.step() 更新權值參數
5.4 保存訓練集上的loss和驗證集上的loss以及準確率以及打印訓練信息。(可選
6. 圖示訓練過程當中loss和accuracy的變化狀況(可選)
7. 在測試集上測試app
代碼註釋都寫的很詳細 函數
1 import torch 2 import torch.nn.functional as F 3 import matplotlib.pyplot as plt 4 5 # 1.準備數據 generate data 6 x=torch.unsqueeze(torch.linspace(-1,1,100),dim=1) 7 print(x.shape) 8 y=x*x+0.2*torch.rand(x.size()) 9 #顯示數據散點圖 10 plt.scatter(x.data.numpy(),y.data.numpy()) 11 12 # 2.定義網絡結構 build net 13 class Net(torch.nn.Module): 14 #n_feature:輸入特徵個數 n_hidden:隱藏層個數 n_output:輸出層個數 15 def __init__(self,n_feature,n_hidden,n_output): 16 # super表示繼承Net的父類,並同時初始化父類的參數 17 super(Net,self).__init__() 18 # nn.Linear表明線性層 表明y=w*x+b 其中w的shape爲[n_hidden,n_feature] b的shape爲[n_hidden] 19 # y=w^T*x+b 這裏w的維度是轉置前的維度 因此是反的 20 self.hidden =torch.nn.Linear(n_feature,n_hidden) 21 self.predict =torch.nn.Linear(n_hidden,n_output) 22 print(self.hidden.weight) 23 print(self.predict.weight) 24 #定義一個前向傳播過程函數 25 def forward(self, x): 26 # n_feature n_hidden n_output 27 #舉例(2,5,1) 2 5 1 28 # - ** - 29 # ** - - - ** - - 30 # - ** - - - ** 31 # ** - - - ** - - 32 # - ** - 33 # 輸入層 隱藏層 輸出層 34 x=F.relu(self.hidden(x)) 35 x=self.predict(x) 36 return x 37 # 實例化一個網絡爲net 38 net = Net(n_feature=1,n_hidden=10,n_output=1) 39 print(net) 40 # 3.定義損失函數 這裏使用均方偏差(mean square error) 41 loss_func=torch.nn.MSELoss() 42 # 4.定義優化器 這裏使用隨機梯度降低 43 optimizer=torch.optim.SGD(net.parameters(),lr=0.2) 44 #定義300遍更新 每10遍顯示一次 45 plt.ion() 46 # 5.訓練 47 for t in range(100): 48 prediction = net(x) # input x and predict based on x 49 loss = loss_func(prediction, y) # must be (1. nn output, 2. target) 50 # 5.3反向傳播三步不可少 51 optimizer.zero_grad() # clear gradients for next train 52 loss.backward() # backpropagation, compute gradients 53 optimizer.step() # apply gradients 54 55 if t % 10 == 0: 56 # plot and show learning process 57 plt.cla() 58 plt.scatter(x.data.numpy(), y.data.numpy()) 59 plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) 60 plt.text(0.5, 0, 'Loss=%.4f' % loss.data.numpy(), fontdict={'size': 20, 'color': 'red'}) 61 plt.show() 62 plt.pause(0.1) 63 64 plt.ioff()
參考:莫煩python測試