I had used sklearn's dataset-splitting function before and found it extremely convenient. But when I moved to TensorFlow and PyTorch I could never find anything similar; I kept searching with keywords like "pytorch split dataset", and still never turned up what I wanted. Then today, out of nowhere, I ran into torch.utils.data.Subset. I was ridiculously happy. No more splitting datasets by hand every time.
Pytorch提供的對數據集進行操做的函數詳見:https://pytorch.org/docs/master/data.html#torch.utils.data.SubsetRandomSamplerpython
The torch.utils.data module contains a number of classes and helpers for dataset handling:
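The main ones (as listed on the documentation page linked above) include:

- Dataset, TensorDataset, ConcatDataset, Subset: dataset abstractions, where a Subset is simply a view of another dataset restricted to a list of indices;
- DataLoader: batched (and optionally multi-process) iteration over a dataset;
- Sampler and its subclasses SequentialSampler, RandomSampler, SubsetRandomSampler, WeightedRandomSampler, BatchSampler: strategies for drawing indices from a dataset;
- random_split: a helper function that splits a dataset into non-overlapping subsets of given lengths.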
The dataset-splitting methods PyTorch provides are demonstrated below with two examples: an index-based split using SubsetRandomSampler, and a direct split using torch.utils.data.random_split.
import numpy as np
import torch
from torch.utils.data import SubsetRandomSampler

...  # definition of MyCustomDataset elided

dataset = MyCustomDataset(my_path)
batch_size = 16
validation_split = .2
shuffle_dataset = True
random_seed = 42

# Creating data indices for training and validation splits:
dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(np.floor(validation_split * dataset_size))
if shuffle_dataset:
    np.random.seed(random_seed)
    np.random.shuffle(indices)
train_indices, val_indices = indices[split:], indices[:split]

# Creating PyTorch data samplers and loaders:
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)

train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                           sampler=train_sampler)
validation_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                                sampler=valid_sampler)

# Usage example:
num_epochs = 10
for epoch in range(num_epochs):
    # Train:
    for batch_index, (faces, labels) in enumerate(train_loader):
        pass  # ... training step elided
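One thing to watch out for: DataLoader's sampler option is mutually exclusive with shuffle=True, and PyTorch raises a ValueError if you pass both. Nothing is lost by leaving shuffle at its default, because SubsetRandomSampler already draws its subset's indices in a fresh random order on every pass over the loader.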
The second approach, torch.utils.data.random_split, splits a dataset into non-overlapping subsets of the given lengths in a single call:

...  # construction of full_dataset elided

train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(full_dataset,
                                                            [train_size, test_size])
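And since this post opened with the excitement over torch.utils.data.Subset, here is a minimal sketch of using it directly. It reuses dataset, train_indices, and val_indices from the SubsetRandomSampler example above; Subset(dataset, indices) simply wraps a dataset so that it only exposes the given indices.

from torch.utils.data import DataLoader, Subset

# Wrap the same dataset twice, once per index list; each Subset behaves
# like an ordinary Dataset restricted to its indices.
train_dataset = Subset(dataset, train_indices)
val_dataset = Subset(dataset, val_indices)

# The split now lives at the dataset level, so the loaders can use the
# plain shuffle flag instead of a sampler.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)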