實現一個定製的 Dataset 類
Dataset 類是 PyTorch 圖像數據集中最爲重要的一個類,也是 PyTorch 中全部數據集加載類都應該繼承的父類。其中,父類的兩個特殊成員函數(魔術方法)必須被重載:
- `__getitem__(self, index)` # 支持數據集索引的函數
- `__len__(self)` # 返回數據集的大小
Dataset 的框架:
class CustomDataset(data.Dataset):  # must inherit from data.Dataset
    """Skeleton for a custom dataset: fill in the three TODO sections below.

    As written, the placeholder methods make this an empty dataset:
    ``__len__`` returns 0 and ``__getitem__`` returns ``None``.
    """

    def __init__(self):
        # TODO
        # Initialize file path or list of file names.
        pass

    def __getitem__(self, index):
        # TODO
        # 1. Read the sample at ``index`` from file
        #    (e.g. using numpy.fromfile, PIL.Image.open).
        # 2. Preprocess the loaded data (e.g. torchvision.Transform).
        # 3. Return a data pair (e.g. an image and its label).
        pass

    def __len__(self):
        # TODO
        # You should change 0 to the total size of your dataset.
        return 0
舉例:
class MyDataset(Dataset):
    """Dataset over the image entries found directly inside a directory.

    Args:
        root: Root directory holding the images. Only entries directly under
            ``root`` are listed (no recursion); an entry is kept when its name
            ends with ``.jpg``, ``.png`` or ``.JPG``.
        augment: Optional callable applied to each loaded image for data
            augmentation. When falsy, images are returned without augmentation.
    """

    # Suffixes accepted as images (case-sensitive match on the entry name).
    _IMAGE_SUFFIXES = (".jpg", ".png", ".JPG")

    def __init__(self, root, augment=None):
        # Close the scandir iterator deterministically, and sort the paths so
        # that a given index maps to the same file on every run (directory
        # listing order is otherwise arbitrary).
        with os.scandir(root) as entries:
            paths = sorted(
                entry.path
                for entry in entries
                if entry.name.endswith(self._IMAGE_SUFFIXES)
            )
        # Array holding the path of every image in the dataset.
        self.image_files = np.array(paths)
        self.augment = augment

    def __getitem__(self, index):
        """Load the image at ``index``, optionally augment it, and return a tensor."""
        # open_image is the image-reading function; it can be implemented with
        # PIL, OpenCV or a similar library.
        image = open_image(self.image_files[index])
        if self.augment:
            image = self.augment(image)  # apply data augmentation
        # Images handed to PyTorch must be tensors.
        return to_tensor(image)

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.image_files)
下面是官方 MNIST 的例子:
class MNIST(data.Dataset):
    """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.

    Args:
        root (string): Root directory of dataset where ``processed/training.pt``
            and ``processed/test.pt`` exist.
        train (bool, optional): If True, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    urls = [
        'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
    ]
    raw_folder = 'raw'
    processed_folder = 'processed'
    training_file = 'training.pt'
    test_file = 'test.pt'
    classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
               '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']
    class_to_idx = {_class: i for i, _class in enumerate(classes)}

    @property
    def targets(self):
        """Labels of whichever split (train or test) this instance loaded."""
        if self.train:
            return self.train_labels
        else:
            return self.test_labels

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        """Load the requested split from ``root``, downloading first if asked.

        Raises:
            RuntimeError: If the processed files are not present under ``root``
                (after the optional download step).
        """
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')

        # NOTE(review): torch.load deserializes the file contents; only point
        # ``root`` at processed files that come from a trusted source.
        if self.train:
            self.train_data, self.train_labels = torch.load(
                os.path.join(self.root, self.processed_folder, self.training_file))
        else:
            self.test_data, self.test_labels = torch.load(
                os.path.join(self.root, self.processed_folder, self.test_file))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        if self.train:
            img, target = self.train_data[index], self.train_labels[index]
        else:
            img, target = self.test_data[index], self.test_labels[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img.numpy(), mode='L')

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the number of samples in the loaded split."""
        if self.train:
            return len(self.train_data)
        else:
            return len(self.test_data)

    def _check_exists(self):
        """Return True when both processed files exist under ``root``."""
        processed_dir = os.path.join(self.root, self.processed_folder)
        return all(
            os.path.exists(os.path.join(processed_dir, name))
            for name in (self.training_file, self.test_file)
        )

    def _make_folders(self):
        """Create the raw and processed folders under ``root`` if missing."""
        # Each folder is created independently: an already-existing raw folder
        # must not prevent the processed folder from being created.
        for folder in (self.raw_folder, self.processed_folder):
            os.makedirs(os.path.join(self.root, folder), exist_ok=True)

    def download(self):
        """Download the MNIST data if it doesn't exist in processed_folder already."""
        import gzip
        import shutil
        import urllib.request

        if self._check_exists():
            return

        self._make_folders()
        raw_dir = os.path.join(self.root, self.raw_folder)

        # download files
        for url in self.urls:
            print('Downloading ' + url)
            filename = url.rpartition('/')[2]
            file_path = os.path.join(raw_dir, filename)
            # Close the HTTP response as soon as the archive is on disk.
            with urllib.request.urlopen(url) as response, open(file_path, 'wb') as f:
                shutil.copyfileobj(response, f)
            # Strip only a trailing '.gz' from the file name, so a '.gz'
            # appearing elsewhere in the path (e.g. in ``root``) is untouched.
            extracted_name = filename[:-len('.gz')] if filename.endswith('.gz') else filename
            extracted_path = os.path.join(raw_dir, extracted_name)
            with gzip.open(file_path, 'rb') as zip_f, open(extracted_path, 'wb') as out_f:
                shutil.copyfileobj(zip_f, out_f)
            os.unlink(file_path)

        # process and save as torch files
        print('Processing...')

        training_set = (
            read_image_file(os.path.join(raw_dir, 'train-images-idx3-ubyte')),
            read_label_file(os.path.join(raw_dir, 'train-labels-idx1-ubyte')),
        )
        test_set = (
            read_image_file(os.path.join(raw_dir, 't10k-images-idx3-ubyte')),
            read_label_file(os.path.join(raw_dir, 't10k-labels-idx1-ubyte')),
        )
        with open(os.path.join(self.root, self.processed_folder, self.training_file), 'wb') as f:
            torch.save(training_set, f)
        with open(os.path.join(self.root, self.processed_folder, self.test_file), 'wb') as f:
            torch.save(test_set, f)

        print('Done!')

    def __repr__(self):
        """Return a multi-line summary: size, split, root and transforms."""
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += ' Number of datapoints: {}\n'.format(len(self))
        # Use truthiness, like every other branch on ``self.train`` in this class.
        tmp = 'train' if self.train else 'test'
        fmt_str += ' Split: {}\n'.format(tmp)
        fmt_str += ' Root Location: {}\n'.format(self.root)
        tmp = ' Transforms (if any): '
        # Continuation lines of a multi-line repr are indented under the label.
        fmt_str += '{0}{1}\n'.format(tmp, repr(self.transform).replace('\n', '\n' + ' ' * len(tmp)))
        tmp = ' Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, repr(self.target_transform).replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str