以前用過sklearn提供的劃分數據集的函數,以爲超級方便。可是在使用TensorFlow和Pytorch的時候一直找不到相似的功能,以前搜索的關鍵字都是「pytorch split dataset」之類的,可是搜出來仍是沒有我想要的。結果今天見鬼了忽然看見了這麼一個函數torch.utils.data.Subset。個人天,爲何超級開心hhhh。終於不用每次都手動劃分數據集了。html
torch.utils.data
Pytorch提供的對數據集進行操做的函數詳見:https://pytorch.org/docs/master/data.html#torch.utils.data.SubsetRandomSamplerpython
torch的這個文件包含了一些關於數據集處理的類:微信
- class torch.utils.data.Dataset: 一個抽象類, 全部其餘類的數據集類都應該是它的子類。並且其子類必須重載兩個重要的函數:len(提供數據集的大小)、getitem(支持整數索引)。
- class torch.utils.data.TensorDataset: 封裝成tensor的數據集,每個樣本都經過索引張量來得到。
- class torch.utils.data.ConcatDataset: 鏈接不一樣的數據集以構成更大的新數據集。
- class torch.utils.data.Subset(dataset, indices): 獲取指定一個索引序列對應的子數據集。
- class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None): 數據加載器。組合了一個數據集和採樣器,並提供關於數據的迭代器。
- torch.utils.data.random_split(dataset, lengths): 按照給定的長度將數據集劃分紅沒有重疊的新數據集組合。
- class torch.utils.data.Sampler(data_source):全部採樣的器的基類。每一個採樣器子類都須要提供 iter 方-法以方便迭代器進行索引 和一個 len方法 以方便返回迭代器的長度。
- class torch.utils.data.SequentialSampler(data_source):順序採樣樣本,始終按照同一個順序。
- class torch.utils.data.RandomSampler(data_source):無放回地隨機採樣樣本元素。
- class torch.utils.data.SubsetRandomSampler(indices):無放回地按照給定的索引列表採樣樣本元素。
- class torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True): 按照給定的機率來採樣樣本。
- class torch.utils.data.BatchSampler(sampler, batch_size, drop_last): 在一個batch中封裝一個其餘的採樣器。
- class torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None):採樣器能夠約束數據加載進數據集的子集。
示例
下面Pytorch提供的劃分數據集的方法以示例的方式給出:dom
SubsetRandomSampler
... dataset = MyCustomDataset(my_path) batch_size = 16 validation_split = .2 shuffle_dataset = True random_seed= 42 # Creating data indices for training and validation splits: dataset_size = len(dataset) indices = list(range(dataset_size)) split = int(np.floor(validation_split * dataset_size)) if shuffle_dataset : np.random.seed(random_seed) np.random.shuffle(indices) train_indices, val_indices = indices[split:], indices[:split] # Creating PT data samplers and loaders: train_sampler = SubsetRandomSampler(train_indices) valid_sampler = SubsetRandomSampler(val_indices) train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=train_sampler) validation_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=valid_sampler) # Usage Example: num_epochs = 10 for epoch in range(num_epochs): # Train: for batch_index, (faces, labels) in enumerate(train_loader): # ...
random_split
... train_size = int(0.8 * len(full_dataset)) test_size = len(full_dataset) - train_size train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])
參考:機器學習
- How do I split a custom dataset into training and test datasets?
- PyTorch系列 (二): pytorch數據讀取
- pytorch: 自定義數據集加載
<b style="color:tomato;"></b>函數
<footer style="color:white;;background-color:rgb(24,24,24);padding:10px;border-radius:10px;"><br> <h3 style="text-align:center;color:tomato;font-size:16px;" id="autoid-2-0-0"><br> <br> <center> <span>微信公衆號:AutoML機器學習</span><br> <img src="https://ask.qcloudimg.com/draft/1215004/21ra82axnz.jpg" style="width:200px;height:200px"> </center> <b>MARSGGBO</b><b style="color:white;"><span style="font-size:25px;">♥</span>原創</b><br> <span>若有意合做或學術討論歡迎私戳聯繫~<br>郵箱:marsggbo@foxmail.com</span> <b style="color:white;"><br> 2019-3-8<p></p> </b><p><b style="color:white;"></b><br> </p></h3><br> </footer>學習