python pytorch numpy DNN 線性迴歸模型

一、直接奉獻代碼,後期有入門更新,以前一直在學的是TensorFlow,

import torch from torch.autograd import Variable import torch.nn.functional as F import matplotlib.pyplot as plt import numpy as np x_data = np.arange(-2*np.pi,2*np.pi,0.1).reshape(-1,1) y_data = np.sin(x_data).reshape(-1,1) x = torch.unsqueeze(torch.linspace(-5, 5, 100), dim=1)  # 將1維的數據轉換爲2維數據 # y = x.pow(2) + 0.2 * torch.rand(x.size())
y = torch.cos(x) # 將tensor置入Variable中 
x, y = Variable(torch.from_numpy(x_data)).float(), Variable(torch.from_numpy(y_data)).float() print(x.shape,y.shape) # plt.scatter(x.data.numpy(), y.data.numpy()) # plt.show()

# 定義一個構建神經網絡的類 
class Net(torch.nn.Module):  # 繼承torch.nn.Module類
    def __init__(self): super(Net, self).__init__()  # 得到Net類的超類(父類)的構造方法
        # 定義神經網絡的每層結構形式
        # 各個層的信息都是Net類對象的屬性
        self.hidden = torch.nn.Linear(1, 10)  # 隱藏層線性輸出
        self.centre_1 = torch.nn.Linear(10,20) self.predict = torch.nn.Linear(20, 1)  # 輸出層線性輸出

    # 將各層的神經元搭建成完整的神經網絡的前向通路
    def forward(self, x): x = F.tanh(self.hidden(x))  # 對隱藏層的輸出進行relu激活
        x_1 = F.tanh(self.centre_1(x)) x =F.tanh(self.predict(x_1)) return x # 定義神經網絡
 net = Net() print(net)  # 打印輸出net的結構

# 定義優化器和損失函數 
optimizer = torch.optim.SGD(net.parameters(), lr=0.5)  # 傳入網絡參數和學習率
loss_function = torch.nn.MSELoss()  # 最小均方偏差
acc = lambda y1,y2: np.sqrt(np.sum(y1**2+y2**2)/len(y1)) # 神經網絡訓練過程 
plt.ion()  # 動態學習過程展現 
plt.show() for t in range(100): prediction = net(x)  # 把數據x餵給net,輸出預測值
    loss = loss_function(prediction, y)  # 計算二者的偏差,要注意兩個參數的順序
    optimizer.zero_grad()  # 清空上一步的更新參數值
    loss.backward()  # 偏差反相傳播,計算新的更新參數值
    optimizer.step()  # 將計算獲得的更新值賦給net.parameters()

    # 可視化訓練過程
    if (t + 1) % 2 == 0: plt.cla() plt.scatter(x.data.numpy(), y.data.numpy()) plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=2) aucc = acc(prediction.data.numpy(),y.data.numpy()) print("loss={} aucc={}".format(loss.data.numpy(),aucc)) plt.text(-4.5, 1, 'echo=%sL=%.4f acc=%s' % (t+1,loss.data.numpy(),aucc), fontdict={'size': 15, 'color': 'red'}) plt.pause(0.1) print("訓練結束") plt.ioff() plt.show()

相關文章
相關標籤/搜索