批數據訓練

把要訓練的數據轉換成數據集,而後把數據集放到loader中加載,能夠有效的幫助你迭代數據。多線程

DataLoader 是 torch 給你用來包裝你的數據的工具. 因此你要講本身的 (numpy array 或其餘) 數據形式裝換成 Tensor, 而後再放進這個包裝器中. 工具

代碼:學習

import torch
import torch.utils.data as Data
torch.manual_seed(1)    # reproducible

BATCH_SIZE = 5      # 批訓練的數據個數

x = torch.linspace(1, 10, 10)       # x data (torch tensor)
y = torch.linspace(10, 1, 10)       # y data (torch tensor)

# 先轉換成 torch 能識別的 Dataset
torch_dataset = Data.TensorDataset( x,y)
# 把 dataset 放入 DataLoader
loader = Data.DataLoader(
    dataset=torch_dataset,      # torch TensorDataset format
    batch_size=BATCH_SIZE,      # mini batch size
    shuffle=True,               # 要不要打亂數據 (打亂比較好)
    num_workers=2,              # 多線程來讀數據
)
if __name__ == '__main__':#多進程的時候使用
    for epoch in range(3):   # 訓練全部!整套!數據 3for step, (batch_x, batch_y) in enumerate(loader):  # 每一步 loader 釋放一小批數據用來學習
            # 假設這裏就是你訓練的地方...

            # 打出來一些數據
            print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',
                  batch_x.numpy(), '| batch y: ', batch_y.numpy())

若是BATCH_SIZE=8,則第一個batch輸入數據8個,第二個batch輸入數據2個(不夠八個)spa

BATCH_SIZE = 8      # 批訓練的數據個數

...

for ...:
    for ...:
        ...
        print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',
              batch_x.numpy(), '| batch y: ', batch_y.numpy())
"""
Epoch:  0 | Step:  0 | batch x:  [  6.   7.   2.   3.   1.   9.  10.   4.] | batch y:  [  5.   4.   9.   8.  10.   2.   1.   7.]
Epoch:  0 | Step:  1 | batch x:  [ 8.  5.] | batch y:  [ 3.  6.]
Epoch:  1 | Step:  0 | batch x:  [  3.   4.   2.   9.  10.   1.   7.   8.] | batch y:  [  8.   7.   9.   2.   1.  10.   4.   3.]
Epoch:  1 | Step:  1 | batch x:  [ 5.  6.] | batch y:  [ 6.  5.]
Epoch:  2 | Step:  0 | batch x:  [  3.   9.   2.   6.   7.  10.   4.   8.] | batch y:  [ 8.  2.  9.  5.  4.  1.  7.  3.]
Epoch:  2 | Step:  1 | batch x:  [ 1.  5.] | batch y:  [ 10.   6.]
"""
相關文章
相關標籤/搜索