This follows an example from the official site: the torch.nn tutorial.
Typically, we subclass Dataset (from torch.utils.data import Dataset, DataLoader) and override its data-reading methods to fit our own data, or use TensorDataset as a more concise way to load data.
We can also manage data with torchvision's ImageFolder. Any of these approaches is enough to read the data on its own; what DataLoader adds is more comprehensive management of batched data.
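As a minimal sketch of the TensorDataset-plus-DataLoader route (the random tensors and the names x, y, ds, dl are illustrative stand-ins for real MNIST data, not from the tutorial):

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Stand-in for MNIST: 100 flattened 28x28 images with integer class labels
x = torch.randn(100, 784)
y = torch.randint(0, 10, (100,))

ds = TensorDataset(x, y)                      # pairs each sample with its label
dl = DataLoader(ds, batch_size=32, shuffle=True)

xb, yb = next(iter(dl))
print(xb.shape, yb.shape)                     # one shuffled batch of 32 samples
```

DataLoader handles batching, shuffling, and iteration, so the training loop only ever sees ready-made (xb, yb) pairs.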
Now to the main topic: to use a CNN on MNIST, the flat vectors must be converted into 2-D image data, so the initial linear data needs reshaping. A common approach is to read the flat data and then reshape it with view inside the model:
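The reshape itself is just a view: a batch of flat 784-vectors becomes a batch of single-channel 28x28 images without copying any data (a small illustration with a random batch):

```python
import torch

xb = torch.randn(64, 784)       # a batch of flattened MNIST images
xb2d = xb.view(-1, 1, 28, 28)   # (batch, channels, height, width)
print(xb2d.shape)               # torch.Size([64, 1, 28, 28])
```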
```python
class Lambda(nn.Module):
    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, x):
        return self.func(x)

def preprocess(x):
    return x.view(-1, 1, 28, 28)

model = nn.Sequential(
    Lambda(preprocess),
    nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 16, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 10, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.AvgPool2d(4),
    Lambda(lambda x: x.view(x.size(0), -1)),
)
```
The tutorial gives another solution: wrap the DataLoader, moving the preprocessing into the generator:
```python
def get_data(train_ds, valid_ds, bs):
    return (
        DataLoader(train_ds, batch_size=bs, shuffle=True),
        DataLoader(valid_ds, batch_size=bs * 2),
    )

def preprocess(x, y):
    return x.view(-1, 1, 28, 28), y

class WrappedDataLoader:
    def __init__(self, dl, func):
        self.dl = dl
        self.func = func

    def __len__(self):
        return len(self.dl)

    def __iter__(self):
        batches = iter(self.dl)
        for b in batches:
            yield (self.func(*b))

train_dl, valid_dl = get_data(train_ds, valid_ds, bs)
train_dl = WrappedDataLoader(train_dl, preprocess)
valid_dl = WrappedDataLoader(valid_dl, preprocess)
```
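To see the wrapper in action, here is a self-contained run with random stand-in data in place of the real MNIST train_ds (the wrapper and preprocess are restated so the snippet runs on its own): iterating the wrapped loader yields batches that are already reshaped to image form.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

def preprocess(x, y):
    return x.view(-1, 1, 28, 28), y

class WrappedDataLoader:
    def __init__(self, dl, func):
        self.dl = dl
        self.func = func

    def __len__(self):
        return len(self.dl)

    def __iter__(self):
        for b in self.dl:
            yield self.func(*b)

# 256 random "flat images" stand in for the real training set
train_ds = TensorDataset(torch.randn(256, 784), torch.randint(0, 10, (256,)))
train_dl = WrappedDataLoader(DataLoader(train_ds, batch_size=64), preprocess)

xb, yb = next(iter(train_dl))
print(xb.shape)  # torch.Size([64, 1, 28, 28]) -- reshaped before the model sees it
```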
The model can then be written like this:
```python
model = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 16, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 10, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    Lambda(lambda x: x.view(x.size(0), -1)),
)
```
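Note the pooling layer also changed: AdaptiveAvgPool2d(1) always reduces the spatial dimensions to 1x1 regardless of the incoming feature-map size, whereas AvgPool2d(4) assumes exactly 4x4. A small shape check on stand-in feature maps (the random tensors are illustrative) makes the difference concrete:

```python
import torch
from torch import nn

# Feature map after the last conv for a 28x28 input: 10 channels, 4x4 spatial
feat = torch.randn(8, 10, 4, 4)
print(nn.AvgPool2d(4)(feat).shape)          # torch.Size([8, 10, 1, 1])
print(nn.AdaptiveAvgPool2d(1)(feat).shape)  # same result on 4x4 input

# A larger input (e.g. 56x56) would leave a 7x7 feature map here;
# adaptive pooling still collapses it to 1x1, AvgPool2d(4) would not.
feat2 = torch.randn(8, 10, 7, 7)
print(nn.AdaptiveAvgPool2d(1)(feat2).shape)  # torch.Size([8, 10, 1, 1])
```

This is why the second model no longer hard-codes the 28x28 input size.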