PyTorch筆記之 Dataset 和 Dataloader

1、簡介

在 PyTorch 中,咱們的數據集每每會用一個類去表示,在訓練時用 Dataloader 產生一個 batch 的數據html

https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-pygit

好比官方例子中對 CIFAR10 圖像數據集進行分類,就有用到這樣的操做,具體代碼以下所示github

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform) trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2) testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform) testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)

簡單說,用 一個類 抽象地表示數據集,而 Dataloader 做爲迭代器,每次產生一個 batch 大小的數據,節省內存dom

2、Dataset

Dataset 是 PyTorch 中用來表示數據集的一個抽象類,咱們的數據集能夠用這個類來表示,至少覆寫下面兩個方法便可函數

這返回數據前能夠進行適當的數據處理,好比將原文用一串數字序列表示測試

  • __len__:數據集大小
  • __getitem__:實現這個方法後,能夠經過下標的方式( dataset[i] )的來取得第 $i$ 個數據

下面咱們來爲編寫一個類表示一個情感二分類數據集,繼續用蘇神整理的數據集url

https://github.com/bojone/bert4keras/tree/master/examples/datasetsspa

數據集沒有表頭,只有2列,一列是評論(文本),另外一列是標籤,以製表符進行分隔.net

from torch.utils.data import Dataset, DataLoader import pandas as pd class SentimentDataset(Dataset): def __init__(self, path_to_file): self.dataset = pd.read_csv(path_to_file, sep="\t", names=["text", "label"]) def __len__(self): return len(self.dataset) def __getitem__(self, idx): text = self.dataset.loc[idx, "text"] label = self.dataset.loc[idx, "label"] sample = {"text": text, "label": label} return sample

3、Dataloader

3.1基本使用

Dataloader 就是一個迭代器,最基本的使用就是傳入一個 Dataset 對象,它就會根據參數 batch_size 的值生成一個 batch 的數據code

if __name__ == "__main__": sentiment_dataset = SentimentDataset("sentiment.test.data") sentiment_dataloader = DataLoader(sentiment_dataset, batch_size=4, shuffle=True, num_workers=2) for idx, batch_samples in enumerate(sentiment_dataloader): text_batchs, text_labels = batch_samples["text"], batch_samples["label"] print(text_batchs)

 3.2Sampler

PyTorch 提供了 Sampler 模塊,用來對數據進行採樣,能夠在 DataLoader 的經過 sampler 參數調用

通常咱們的加載訓練集的 dataloader ,shuffle參數都會設置爲True ,這時候使用了一個默認的採樣器——RandomSampler

當 shuffle 設置爲 False 時,默認使用的是 SequencetialSampler,其實就是按順序取出數據集中的元素

在 PyTorch 中默認實現瞭如下 Sampler,若是咱們要使用別的 Sampler, shuffle 要設置爲 False

  • SequentialSampler
  • RandomSampler
  • WeightedSampler
  • SubsetRandomSampler

SubsetRandomSampler 經常使用來將數據集劃分爲訓練集和測試集,好比這裏就訓練集和測試集按7:3 進行分割

n_train = len(sentiment_train_set) split = n_train // 3 
indices
= list(range(n_train)) train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[split:]) valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:split])
train_loader
= DataLoader(sentiment_train_set, sampler=train_sampler, shuffle=False) valid_loader = DataLoader(sentiment_train_set, sampler=valid_sampler, shuffle=False)

具體推薦下面的博文,講得挺詳細的

一文弄懂Pytorch的DataLoader, DataSet, Sampler之間的關係

https://www.cnblogs.com/marsggbo/p/11541054.html

Pytorch Sampler詳解

https://www.cnblogs.com/marsggbo/p/11541054.html

3.3collate_fn

能夠用來進行一些數據處理,好比在文本任務中,通常因爲文本長度不一致,咱們須要進行截斷或者填充。對於圖片,咱們則但願它們有一樣的尺寸

我麼能夠編寫一個函數,而後用這個參數調用它,下面是一個簡單的例子,咱們把文本截斷成只有10個字符

def truncate(data_list): """傳進一個batch_size大小的數據"""
  for data in data_list: text = data["text"] data["text"]=text[:10] return data_list test_loader = DataLoader(sentiment_train_set, batch_size=batch_size, shuffle=True, num_workers=2, collate_fn=truncate)

咱們能夠看看返回的內容是否已經通過截斷了

for i in test_loader: print(i) break

這時候返回的是一個列表而不是字典了,其中一個 batch 的返回結果以下,咱們能夠看到這裏一個樣本放在了一個字典中

[{'text': '看了一個通宵,實在是', 'label': 1}, 。。。, {'text': '看了攜程的其餘用戶評', 'label': 0}]

下面是沒有使用 collate_fn 的返回結果,它會將數據和標籤分開,存放在一塊兒,以下所示,

{

'text':['3月1號訂的,3月15號還沒到貨 客服天天說下個工做日能到貨已經連續5天了 我無語。想早點兒看這本書的人仍是去陶寶或卓越上訂吧,尤爲是廣東省的朋友.噹噹送貨太沒保證了.',。。。, '很是純樸的故事,但包含了主人公坎坷的一輩子,活着就是痛苦,不得不佩服生命的韌性'],

'label': tensor([1, 。。。, 0])

}

相關文章
相關標籤/搜索