import torch

torch.manual_seed(1)    # reproducible

# fake data
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)  # x data (tensor), shape=(100, 1)
y = x.pow(2) + 0.2*torch.rand(x.size())                 # noisy y data (tensor), shape=(100, 1)

def save():
    # build the network
    net1 = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1)
    )
    optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
    loss_func = torch.nn.MSELoss()

    # train
    for t in range(100):
        prediction = net1(x)
        loss = loss_func(prediction, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    torch.save(net1, 'net.pkl')                      # save the entire network
    torch.save(net1.state_dict(), 'net_params.pkl')  # save only the parameters (faster, smaller file)
def restore_net():
    # restore the entire net1 into net2
    net2 = torch.load('net.pkl')    # on PyTorch >= 2.6, add weights_only=False to unpickle a full model
    prediction = net2(x)

def restore_params():
    # build net3 with the same architecture as net1
    net3 = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1)
    )
    # copy the saved parameters into net3
    net3.load_state_dict(torch.load('net_params.pkl'))
    prediction = net3(x)

# save net1 (1. the entire network, 2. only the parameters)
save()

# restore the entire network
restore_net()

# restore the parameters into a freshly built network
restore_params()
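As a quick sanity check (a sketch added here, not part of the original code), the two restore paths should give identical predictions, since both networks end up carrying the same weights:

# assumes save() above has already run, so 'net.pkl' and 'net_params.pkl' exist
net2 = torch.load('net.pkl')          # entire network, architecture included
net3 = torch.nn.Sequential(           # rebuild the same architecture by hand
    torch.nn.Linear(1, 10),
    torch.nn.ReLU(),
    torch.nn.Linear(10, 1)
)
net3.load_state_dict(torch.load('net_params.pkl'))

x_test = torch.unsqueeze(torch.linspace(-1, 1, 5), dim=1)
print(torch.allclose(net2(x_test), net3(x_test)))   # True: same weights, same outputs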
DataLoader is the tool torch provides for wrapping your data. You first convert your own data (numpy arrays or other formats) into Tensors, then put them into this wrapper. The benefit of using DataLoader is that it iterates over your data for you efficiently.
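For instance, if your raw data lives in numpy arrays, the conversion is a one-liner with torch.from_numpy (a minimal sketch; the array names here are made up for illustration):

import numpy as np
import torch
import torch.utils.data as Data

features = np.random.rand(10, 3).astype(np.float32)   # hypothetical raw numpy data
labels = np.random.rand(10, 1).astype(np.float32)

# numpy array -> torch tensor -> Dataset that a DataLoader can wrap
dataset = Data.TensorDataset(torch.from_numpy(features), torch.from_numpy(labels))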
import torch
import torch.utils.data as Data

torch.manual_seed(1)    # reproducible

BATCH_SIZE = 5          # number of samples per mini-batch

x = torch.linspace(1, 10, 10)   # x data (torch tensor)
y = torch.linspace(10, 1, 10)   # y data (torch tensor)

# first wrap the tensors in a Dataset that torch recognises
# (older PyTorch used keyword args data_tensor=/target_tensor=; they were removed in 0.4)
torch_dataset = Data.TensorDataset(x, y)

# put the dataset into a DataLoader
loader = Data.DataLoader(
    dataset=torch_dataset,   # torch TensorDataset format
    batch_size=BATCH_SIZE,   # mini batch size
    shuffle=True,            # whether to shuffle the data (shuffling is usually better)
    num_workers=2,           # subprocesses for loading data; on Windows, put the loop below under `if __name__ == '__main__':`
)

for epoch in range(3):   # train over the whole dataset 3 times
    for step, (batch_x, batch_y) in enumerate(loader):  # each step the loader yields one mini-batch
        # this is where your training step would go...

        # print some data
        print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',
              batch_x.numpy(), '| batch y: ', batch_y.numpy())

"""
Epoch:  0 | Step:  0 | batch x:  [ 6.  7.  2.  3.  1.] | batch y:  [ 5.  4.  9.  8.  10.]
Epoch:  0 | Step:  1 | batch x:  [ 9.  10.  4.  8.  5.] | batch y:  [ 2.  1.  7.  3.  6.]
Epoch:  1 | Step:  0 | batch x:  [ 3.  4.  2.  9.  10.] | batch y:  [ 8.  7.  9.  2.  1.]
Epoch:  1 | Step:  1 | batch x:  [ 1.  7.  8.  5.  6.] | batch y:  [ 10.  4.  3.  6.  5.]
Epoch:  2 | Step:  0 | batch x:  [ 3.  9.  2.  6.  7.] | batch y:  [ 8.  2.  9.  5.  4.]
Epoch:  2 | Step:  1 | batch x:  [ 10.  4.  8.  1.  5.] | batch y:  [ 1.  7.  3.  10.  6.]
"""
When the remaining data is not enough to fill a whole batch, the DataLoader simply returns whatever is left in that epoch.
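With the 10 samples above and batch_size=8, for example, step 0 yields 8 samples and step 1 yields the remaining 2. A short sketch (drop_last=True is the stock DataLoader option for discarding that incomplete final batch instead):

loader = Data.DataLoader(
    dataset=torch_dataset,
    batch_size=8,        # 10 samples -> one batch of 8, then one of 2
    shuffle=True,
    # drop_last=True,    # uncomment to throw away the incomplete final batch
)

for step, (batch_x, batch_y) in enumerate(loader):
    print('Step: ', step, '| batch size: ', batch_x.size(0))
# Step:  0 | batch size:  8
# Step:  1 | batch size:  2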