pytorch ImageFolder的覆寫

在為數據分類任務訓練分類器的時候(比如貓狗分類),我們經常會使用 PyTorch 的 ImageFolder:

CLASS torchvision.datasets.ImageFolder(root, transform=None, target_transform=None, loader=<function default_loader>, is_valid_file=None)

使用方法可參見《pytorch torchvision.ImageFolder的使用》。

這裏想實現的是:覆寫(繼承並重寫)該類,使其既能保留原有的特性,又能實現自己的功能。

首先分析一下其源代碼(注意:以下為較舊版本的 torchvision 實現,尚不支持上面簽名中的 is_valid_file 參數):

# File suffixes accepted as images (lower case, leading dot included).
# Entries are matched with str.endswith against the lower-cased file name
# (see has_file_allowed_extension), so the leading dot matters: a bare 'webp'
# would also accept names such as 'notwebp'.
IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp']

class ImageFolder(DatasetFolder):
    """Image dataset stored as one sub-directory per class.

    Expected layout::

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/asd932_.png

    This is ``DatasetFolder`` with the accepted file suffixes fixed to
    ``IMG_EXTENSIONS``.

    Args:
        root (string): Root directory path.
        transform (callable, optional): Applied to every loaded image; must
            return the transformed image, e.g. ``transforms.RandomCrop``.
        target_transform (callable, optional): Applied to every class index.
        loader (callable, optional): Loads an image given its file path.

    Attributes:
        classes (list): Class names (sub-directory names).
        class_to_idx (dict): Maps each class name to its class index.
        imgs (list): Same list object as ``samples``:
            ``(image path, class_index)`` tuples.
    """

    def __init__(self, root, transform=None, target_transform=None,
                 loader=default_loader):
        parent = super(ImageFolder, self)
        parent.__init__(
            root=root,
            loader=loader,
            extensions=IMG_EXTENSIONS,
            transform=transform,
            target_transform=target_transform,
        )
        # Second name for the sample list; both attributes refer to one list.
        self.imgs = self.samples

ImageFolder的代碼很簡單,主要是繼承了DatasetFolder類,真正的邏輯都在下面的輔助函數和DatasetFolder中:

def has_file_allowed_extension(filename, extensions):
    """Report whether *filename* ends with one of the given suffixes.

    Only the file name is lower-cased before comparing, so the suffixes in
    ``extensions`` should themselves be lower case (e.g. ``'.jpg'``) to get a
    case-insensitive match.

    Args:
        filename (string): File name or path to check.
        extensions (iterable of strings): Accepted suffixes,
            e.g. ``['.jpg', '.png']``.

    Returns:
        bool: True if the filename ends with one of given extensions
        (a single bool, stopping at the first match).
    """
    lowered = filename.lower()
    for suffix in extensions:
        if lowered.endswith(suffix):
            return True
    return False


def make_dataset(dir, class_to_idx, extensions):
    """Collect ``(image path, class index)`` pairs for every class folder.

    For each class name in ``class_to_idx`` (visited in sorted order) the
    directory ``dir/<class name>`` is walked recursively. Every file that
    passes ``has_file_allowed_extension`` is kept. Class names without a
    matching directory under ``dir`` are skipped.

    Args:
        dir (string): Root directory; a leading ``~`` is expanded.
        class_to_idx (dict): Maps class (sub-directory) name to class index.
        extensions (iterable of strings): Accepted file suffixes.

    Returns:
        list: ``[(image path, class index), ...]``. Classes are in sorted
        order, and so are the walked directories and the file names within
        each directory.
    """
    images = []
    # NOTE: the parameter name `dir` shadows the builtin but is kept so that
    # keyword-argument callers keep working.
    root_dir = os.path.expanduser(dir)
    for target in sorted(class_to_idx):
        class_dir = os.path.join(root_dir, target)
        if not os.path.isdir(class_dir):
            continue

        class_index = class_to_idx[target]
        # os.walk yields (current dir path, sub-dir names, file names) for each
        # level. Sorting both the walk and the names makes the sample order
        # independent of filesystem listing order.
        for root, _, fnames in sorted(os.walk(class_dir)):
            for fname in sorted(fnames):
                # Keep only files with an accepted extension.
                if has_file_allowed_extension(fname, extensions):
                    images.append((os.path.join(root, fname), class_index))

    return images

class DatasetFolder(data.Dataset):
    """A generic data loader where the samples are arranged in this way: ::

        root/class_x/xxx.ext
        root/class_x/xxy.ext
        root/class_x/xxz.ext

        root/class_y/123.ext
        root/class_y/nsdf3.ext
        root/class_y/asd932_.ext

    Args:
        root (string): Root directory path.
        loader (callable): Loads a sample given its file path.
        extensions (list[string]): Accepted file suffixes, e.g. ``['.jpg', '.png']``.
        transform (callable, optional): Applied to each loaded sample; returns
            the transformed sample. E.g. ``transforms.RandomCrop`` for images.
        target_transform (callable, optional): Applied to each class index.

    Attributes:
        classes (list): Sorted class names, e.g. ``['cat', 'dog']``.
        class_to_idx (dict): Maps class_name -> class_index,
            e.g. ``{'cat': 0, 'dog': 1}``.
        samples (list): ``(sample path, class_index)`` tuples.
        targets (list): Class index of every sample, in ``samples`` order.

    Raises:
        RuntimeError: If no file with an accepted extension is found.
    """

    def __init__(self, root, loader, extensions, transform=None, target_transform=None):
        # Class names and indices, e.g. ['cat', 'dog'] and {'cat': 0, 'dog': 1}.
        classes, class_to_idx = self._find_classes(root)
        # Label every file: [(image path, class index), ...].
        samples = make_dataset(root, class_to_idx, extensions)
        if not samples:
            # str.format instead of '+' concatenation, so a non-str root (e.g. a
            # pathlib.Path, which os.scandir/os.walk accept) produces this
            # message instead of a TypeError. Identical text for str roots.
            raise RuntimeError(
                "Found 0 files in subfolders of: {}\n"
                "Supported extensions are: {}".format(root, ",".join(extensions)))

        self.root = root
        self.loader = loader
        self.extensions = extensions

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        # Class index of every sample, aligned with self.samples.
        self.targets = [s[1] for s in samples]

        self.transform = transform
        self.target_transform = target_transform

    def _find_classes(self, dir):
        """Find the class folders in a dataset.

        Args:
            dir (string): Root directory path.

        Returns:
            tuple: ``(classes, class_to_idx)``. ``classes`` are the immediate
            sub-directory names of ``dir`` in sorted order, e.g.
            ``['cat', 'dog']``. ``class_to_idx`` maps each name to its
            position in that list, e.g. ``{'cat': 0, 'dog': 1}``.

        Ensures:
            No class is a subdirectory of another (only the first level is listed).
        """
        if sys.version_info >= (3, 5):
            # Faster and available in Python 3.5 and above
            classes = [d.name for d in os.scandir(dir) if d.is_dir()]
        else:
            # Same result via os.listdir for older interpreters.
            classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
        classes.sort()
        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        return classes, class_to_idx

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self):
        return len(self.samples)

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
        fmt_str += '    Root Location: {}\n'.format(self.root)
        # Multi-line transform reprs are re-indented to line up under the label.
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = '    Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str

此時想要覆寫ImageFolder,代碼如下。該子類的每個樣本除了給定索引的圖像外,還會隨機選取另一張圖像,一併返回兩張圖像及其類別索引:

class CustomImageFolder(ImageFolder):
    """ImageFolder whose items pair each image with a randomly chosen second one.

    ``dataset[i]`` returns the image at index ``i`` together with an image
    drawn uniformly at random from the whole dataset (it may be ``i``
    itself), plus the class index of each.
    """

    def __init__(self, root, transform=None, target_transform=None):
        """
        Args:
            root (string): Root directory path (``root/<class>/<image>`` layout).
            transform (callable, optional): Applied to both loaded images.
            target_transform (callable, optional): Applied to both class
                indices. Optional and defaults to None, so existing
                ``CustomImageFolder(root, transform)`` calls are unaffected.
        """
        super(CustomImageFolder, self).__init__(
            root, transform=transform, target_transform=target_transform)
        # Every valid index 0 .. len(self) - 1; the random partner is drawn from here.
        self.indices = range(len(self))

    def __getitem__(self, index1):
        """
        Args:
            index1 (int): Index of the first image.

        Returns:
            tuple: (img1, img2, label1, label2). img1/label1 belong to
            ``index1``, img2/label2 to a randomly drawn index.
        """
        # Draw the partner first (same RNG consumption order as before).
        index2 = random.choice(self.indices)

        # self.imgs is the same list as self.samples: [(image path, class index), ...]
        path1, label1 = self.imgs[index1]
        path2, label2 = self.imgs[index2]

        img1 = self.loader(path1)
        img2 = self.loader(path2)
        if self.transform is not None:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
        # Mirror DatasetFolder.__getitem__, which also applies target_transform.
        if self.target_transform is not None:
            label1 = self.target_transform(label1)
            label2 = self.target_transform(label2)

        return img1, img2, label1, label2
相關文章
相關標籤/搜索