『跟着雨哥學AI』系列:詳解飛槳框架數據管道

點擊左上方藍字關注咱們php

課程簡介:git

「跟着雨哥學AI」是百度飛槳開源框架近期針對高層API推出的系列課。本課程由多位資深飛槳工程師精心打造,不只提供了從數據處理、到模型組網、模型訓練、模型評估和推理部署全流程講解;還提供了豐富的趣味案例,旨在幫助開發者更全面清晰地掌握百度飛槳框架的用法,並可以觸類旁通、靈活使用飛槳框架進行深度學習實踐。github

本章分別對內置數據集、數據集定義、數據加強、數據採樣以及數據加載這幾個功能點進行詳細的講解。web

在上個月發佈的 飛槳開源框架2.0,帶你走進全新高層API,十行代碼搞定深度學習模型開發  中,已經給你們簡單介紹了飛槳高層API的定義、特色、總體框架以及具體API。這節課咱們將介紹飛槳高層API的第一個模塊--數據管道。俗話說『九層之臺,起於壘土』,數據管道是模型訓練過程當中最重要的前置工做。在飛槳的整個框架中,數據管道包含了五個功能點:內置數據集、數據集定義、數據加強、數據採樣以及數據加載。接下來我將分別對這五個功能點進行詳細的講解。好的,那下面就讓咱們進入今天的內容吧。緩存

下載安裝命令

## CPU版本安裝命令
pip install -f https://paddlepaddle.org.cn/pip/oschina/cpu paddlepaddle

## GPU版本安裝命令
pip install -f https://paddlepaddle.org.cn/pip/oschina/gpu paddlepaddle-gpu

什麼是數據管道?微信

在完成深度學習領域的任務時,咱們最早面臨的挑戰就是數據處理,即須要將數據處理成模型可以"看懂"的語言,從而進行模型的訓練。好比,在圖像分類任務中,咱們須要按格式處理圖像數據與其對應的標籤,而後才能將其輸入到模型中,開始訓練。在這個過程當中,咱們須要將圖片數據從jpg、png或其它格式轉換爲numpy array的格式,而後對其進行一些加工,如重置大小、旋轉變換、改變亮度等等,從而進行數據加強。因此,數據的預處理和加載方式很大程度上決定了模型最終的性能水平。傳統框架經常包含着複雜的數據加載模式,多重的預處理操做經常會勸退許多人。而飛槳框架爲了簡化數據管道的流程,對數據管道相關的場景進行了高級封裝,經過很是少許代碼,便可實現數據的處理,更愉快的進行深度學習模型研發。網絡

數據管道詳解框架

在數據管道總共包含5個模塊,分別是飛槳框架內置數據集、自定義數據集、數據加強、數據採樣以及數據加載5個部分。關係圖以下: dom

下面,讓我來一一介紹這些內容。分佈式

 

2.1 內置數據集

內置數據集介紹:

爲了節約你們處理數據時所耗費的時間和精力,飛槳框架將一些咱們經常使用到的數據集做爲領域API對用戶進行開放,用戶經過調用paddle.vision.datasets和paddle.text.datasets便可直接使用領域API,這兩個API內置包含了許多CV和NLP領域相關的常見數據集,具體以下:

import paddle
import numpy as np

paddle.__version__
'2.0.0-rc1'

print('視覺相關數據集:', paddle.vision.datasets.__all__)
print('天然語言相關數據集:', paddle.text.datasets.__all__)
視覺相關數據集: ['DatasetFolder', 'ImageFolder', 'MNIST', 'FashionMNIST', 'Flowers', 'Cifar10', 'Cifar100', 'VOC2012']
天然語言相關數據集: ['Conll05st', 'Imdb', 'Imikolov', 'Movielens', 'UCIHousing', 'WMT14', 'WMT16']

內置數據集使用:

爲了方便你們理解,這裏我演示一下如何使用內置的手寫數字識別的數據集,其餘數據集的使用方式也相似,你們能夠動手試一下哦。具體能夠見下面的代碼,注意,咱們經過使用mode參數用來標識訓練集與測試集。調用數據集接口後,相應的API會自動下載數據集到本機緩存目錄~/.cache/paddle/dataset。

import paddle.vision as vision


print("訓練集下載中...")
# 訓練數據集
train_dataset = vision.datasets.MNIST(mode='train')
print("訓練集下載完成!")
print("測試集下載中...")
# 驗證數據集
test_dataset = vision.datasets.MNIST(mode='test')
print("測試集下載完成!")
訓練集下載中...
Cache file /home/aistudio/.cache/paddle/dataset/mnist/train-images-idx3-ubyte.gz not found, downloading https://dataset.bj.bcebos.com/mnist/train-images-idx3-ubyte.gz 
Begin to download

Download finished
Cache file /home/aistudio/.cache/paddle/dataset/mnist/train-labels-idx1-ubyte.gz not found, downloading https://dataset.bj.bcebos.com/mnist/train-labels-idx1-ubyte.gz 
Begin to download
........
Download finished
訓練集下載完成!
測試集下載中...
Cache file /home/aistudio/.cache/paddle/dataset/mnist/t10k-images-idx3-ubyte.gz not found, downloading https://dataset.bj.bcebos.com/mnist/t10k-images-idx3-ubyte.gz 
Begin to download

Download finished
Cache file /home/aistudio/.cache/paddle/dataset/mnist/t10k-labels-idx1-ubyte.gz not found, downloading https://dataset.bj.bcebos.com/mnist/t10k-labels-idx1-ubyte.gz 
Begin to download
..
Download finished
測試集下載完成!

內置數據集可視化:

經過上面的步驟,咱們就定義好了訓練集與測試集,接下來,讓咱們來看一下數據集的內容吧。

import numpy as np
import matplotlib.pyplot as plt


train_data_0, train_label_0 = np.array(train_dataset[0][0]), train_dataset[0][1]
train_data_0 = train_data_0.reshape([28, 28])

plt.figure(figsize=(2, 2))
plt.imshow(train_data_0, cmap=plt.cm.binary)
print('train_data0 label is: ' + str(train_label_0))
train_data0 label is: [5]

從上例中能夠看出,train_dataset 是一個 map-style 式的數據集,咱們能夠經過下標直接獲取單個樣本的圖像數據與標籤,從而進行可視化。

Note: map-style 是指能夠經過下標的方式來獲取指定樣本,除此以外,還有 iterable-style 式的數據集,只能經過迭代的方式來獲取樣本,具體說明能夠見下一節。

 

2.2 數據集定義

有同窗提出雖然飛槳框架提供了許多領域數據集供咱們使用,可是在實際的使用場景中,若是咱們須要使用已有的數據來訓練模型怎麼辦呢?別慌,飛槳也貼心地準備了 map-style 的 paddle.io.Dataset 基類 和 iterable-style 的 paddle.io.IterableDataset 基類 ,來完成數據集定義。此外,針對一些特殊的場景,飛槳框架也提供了 paddle.io.TensorDataset 基類,能夠直接處理 Tensor 數據爲 dataset,一鍵完成數據集的定義。

讓咱們來看一下它們的使用方式吧~

paddle.io.Dataset的使用方式:

這個是咱們最推薦使用的API,來完成數據的定義。使用 paddle.io.Dataset,最後會返回一個 map-style 的 Dataset 類。能夠用於後續的數據加強、數據加載等。而使用 paddle.io.Dataset 也很是簡單,只須要按格式完成如下四步便可。

class MyDataset(paddle.io.Dataset):
    """
    步驟一:繼承paddle.io.Dataset類
    """
    def __init__(self, mode='train'):
        """
        步驟二:實現構造函數,定義數據讀取方式,劃分訓練和測試數據集
        """
        super(MyDataset, self).__init__()

        if mode == 'train':
            self.data = [
                ['train_image_0.jpg', '1'],
                ['train_image_1.jpg', '2'],
                ['train_image_2.jpg', '3'],
                ['train_image_3.jpg', '4'],
            ]
        else:
            self.data = [
                ['test_image_0.jpg', '1'],
                ['test_image_1.jpg', '2'],
                ['test_image_2.jpg', '3'],
                ['test_image_3.jpg', '4'],
            ]

    def _load_img(self, image_path):
        # 實際使用時使用Pillow相關庫進行圖片讀取便可,這裏咱們對數據先作個模擬
        image = np.random.randn(32, 32, 3)

        return image

    def __getitem__(self, index):
        """
        步驟三:實現__getitem__方法,定義指定index時如何獲取數據,並返回單條數據(訓練數據,對應的標籤)
        """
        image = self._load_img(self.data[index][0])
        label = self.data[index][1]

        return image, np.array(label, dtype='int64')

    def __len__(self):
        """
        步驟四:實現__len__方法,返回數據集總數目
        """
        return len(self.data)

# 測試定義的數據集
train_dataset = MyDataset(mode='train')
test_dataset = MyDataset(mode='test')

print('=============train dataset=============')
for image, label in train_dataset:
    print('image shape: {}, label: {}'.format(image.shape, label))

print('=============evaluation dataset=============')
for image, label in test_dataset:
    print('image shape: {}, label: {}'.format(image.shape, label))
=============train dataset=============
image shape: (32, 32, 3), label: 1
image shape: (32, 32, 3), label: 2
image shape: (32, 32, 3), label: 3
image shape: (32, 32, 3), label: 4
=============evaluation dataset=============
image shape: (32, 32, 3), label: 1
image shape: (32, 32, 3), label: 2
image shape: (32, 32, 3), label: 3
image shape: (32, 32, 3), label: 4

paddle.io.Dataset實戰:

看了上面的例子,你是否想本身動手寫一個Dataset呢?就讓我用實戰來演示一下:

# 下載訓練集
!wget http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
# 下載訓練集標籤
!wget http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
# 下載測試集
!wget http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
# 下載測試集標籤
!wget http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz

import os
import gzip


class FashionMNISTDataset(paddle.io.Dataset):
    """
    步驟一:繼承paddle.io.Dataset類
    """
    def __init__(self, path='./', mode='train'):
        """
        步驟二:實現構造函數,定義數據讀取方式,劃分訓練和測試數據集
        """
        super(FashionMNISTDataset, self).__init__()

        images_data_path = os.path.join(path,
                               '%s-images-idx3-ubyte.gz'
                               % mode)
        labels_data_path = os.path.join(path,
                               '%s-labels-idx1-ubyte.gz'
                               % mode)
        with gzip.open(labels_data_path, 'rb') as lbpath:
            self.labels = np.frombuffer(lbpath.read(), dtype=np.uint8,
                               offset=8)

        with gzip.open(images_data_path, 'rb') as imgpath:
            self.images = np.frombuffer(imgpath.read(), dtype=np.uint8,
                               offset=16).reshape(len(self.labels), 784)

    def __getitem__(self, index):
        """
        步驟三:實現__getitem__方法,定義指定index時如何獲取數據,並返回單條數據(訓練數據,對應的標籤)
        """
        image = self.images[index]
        label = self.labels[index]

        return image, label

    def __len__(self):
        """
        步驟四:實現__len__方法,返回數據集總數目
        """
        return len(self.images)

# 測試定義的數據集
fashion_mnist_train_dataset = FashionMNISTDataset(mode='train')
fashion_mnist_test_dataset = FashionMNISTDataset(mode='t10k')

# 可視化訓練集
fashion_mnist_train_data_0 = np.array(fashion_mnist_train_dataset[0][0])
fashion_mnist_train_label_0 = fashion_mnist_train_dataset[0][1]
fashion_mnist_train_data_0 = fashion_mnist_train_data_0.reshape([28,28])
plt.figure(figsize=(2,2))
plt.imshow(fashion_mnist_train_data_0, cmap=plt.cm.binary)
print('train_data0 label is: ' + str(fashion_mnist_train_label_0))
train_data0 label is: 9

paddle.io.IterableDataset 的使用方式

使用 paddle.io.IterableDataset,最後會返回一個 iterable-style 的 Dataset 類。而使用 paddle.io.IterableDataset 也很是簡單,只須要按格式完成如下兩步便可。

import math
import paddle
import numpy as np
from paddle.io import IterableDataset, DataLoader, get_worker_info

class SplitedIterableDataset(IterableDataset):
    """
    步驟一:繼承paddle.io.Dataset類
    """
    def __init__(self, start, end):
        self.start = start
        self.end = end

    def __iter__(self):
        """
        步驟二:實現__iter__方法,
        """    
        worker_info = get_worker_info()
        if worker_info is None:
            iter_start = self.start
            iter_end = self.end
        else:
            per_worker = int(
                math.ceil((self.end - self.start) / float(
                    worker_info.num_workers)))
            worker_id = worker_info.id
            iter_start = self.start + worker_id * per_worker
            iter_end = min(iter_start + per_worker, self.end)
        for i in range(iter_start, iter_end):
            yield np.array([i])


dataset = SplitedIterableDataset(start=2, end=9)
dataloader = DataLoader(dataset, num_workers=2, batch_size=1, drop_last=True)

for data in dataloader:
    print(data[0].numpy())
[[2]]
[[6]]
[[3]]
[[7]]
[[4]]
[[8]]
[[5]]

paddle.io.TensorDataset的使用方式

上面介紹了兩種數據集的定義方式,分別經過繼承paddle.io.Dataset與paddle.io.IterableDataset就能夠實現。不過,還有一種場景,若是咱們已經有了Tensor類型的數據,想要快速、直接的建立Dataset,而不去實現paddle.io.Dataset的各類方法,能夠麼?這時,咱們就可使用 paddle.io.TensorDataset,直接將 Tensor 類型的 數據與標籤傳入 TensorDataset 類便可。

快來看看這是怎麼實現的吧:

from paddle.io import TensorDataset


input_np = np.random.random([2, 3, 4]).astype('float32')
input_tensor = paddle.to_tensor(input_np)
label_np = np.random.random([2, 1]).astype('int32')
label_tensor = paddle.to_tensor(label_np)

dataset = TensorDataset([input_tensor, label_tensor])

for i in range(len(dataset)):
    input, label = dataset[i]
    print(input, label)
Tensor(shape=[3, 4], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
       [[0.91451722, 0.94088864, 0.52030772, 0.80783033],
        [0.74379814, 0.18669823, 0.41893899, 0.89299613],
        [0.67413408, 0.82801068, 0.02079745, 0.95862854]]) Tensor(shape=[1], dtype=int32, place=CUDAPlace(0), stop_gradient=True,
       [0])
Tensor(shape=[3, 4], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
       [[0.30733261, 0.82390237, 0.99652219, 0.93594497],
        [0.62558615, 0.83836132, 0.34213212, 0.72257715],
        [0.80075997, 0.38913822, 0.25709155, 0.00520579]]) Tensor(shape=[1], dtype=int32, place=CUDAPlace(0), stop_gradient=True,
       [0])

能夠看出,咱們將Tensor類型的 input 與 label 直接傳入TensorDataset中,就能夠完成 Dataset 的定義,徹底不須要實現上述自定義的那四個步驟。在咱們的實際使用中,若是想要簡單的作個測試,徹底能夠直接使用TensorDataset來建立數據集。那麼,使用 TensorDataset 有什麼要求呢?只有一個要求,就是傳入的 Tensor,它們的第一維維度要相同,從上例中能夠看出, input 與 label 的第一維都是2,表明數據集的大小。

 

2.3 數據加強

在訓練模型的過程當中,咱們偶爾會遇到過擬合的問題。這時,最好的作法是增長訓練集的數量,以此提高模型的泛化能力。可是,因爲獲取數據集的成本比較高,因此一般咱們都會採用加強訓練數據的方式對數據進行處理,從而獲得更多不一樣的數據集。

在計算機視覺領域中,常見的數據加強的方式包括隨機裁剪、旋轉變換、改變圖像亮度、對比度等等。能夠看出,這些方法都是對圖像的常見處理方式,因此,飛槳框架直接提供了這類的API,定義在paddle.vision.transofrms下,包含了計算機視覺中對圖像的各類經常使用的處理,具體以下:

print("飛槳支持的數據預處理方式:" + str(paddle.vision.transforms.__all__))
飛槳支持的數據預處理方式:['BaseTransform', 'Compose', 'Resize', 'RandomResizedCrop', 'CenterCrop', 'RandomHorizontalFlip', 'RandomVerticalFlip', 'Transpose', 'Normalize', 'BrightnessTransform', 'SaturationTransform', 'ContrastTransform', 'HueTransform', 'ColorJitter', 'RandomCrop', 'Pad', 'RandomRotation', 'Grayscale', 'ToTensor', 'to_tensor', 'hflip', 'vflip', 'resize', 'pad', 'rotate', 'to_grayscale', 'crop', 'center_crop', 'adjust_brightness', 'adjust_contrast', 'adjust_hue', 'normalize']

那麼該怎麼使用呢?咱們這裏分兩種場景來介紹,一種是飛槳內置數據集使用數據加強,另外一種是自定義數據集使用數據加強。

內置數據集使用數據加強:

內置數據集使用數據加強的方式很是簡單,咱們能夠直接定義一個數據預處理的方式,而後將其做爲參數,在加載內置數據集的時候,傳給 transform 參數便可;而若是咱們想對一個數據集進行多個數據預處理的方式,能夠先定義一個 transform 的容器 Compose,將咱們須要的數據預處理方法以 list 的格式傳入 Compose,而後在加載內置數據集的時候,傳給 transform 參數便可。

具體能夠看下面的例子:

import paddle.vision.transforms as T


# 方式一 只對圖像進行調整亮度的操做
transform = T.BrightnessTransform(0.4)
# 經過transform參數傳遞定義好的數據增方法便可完成對自帶數據集的數據加強
train_dataset_without_transform = vision.datasets.Cifar10(mode='train')
train_dataset_with_transform = vision.datasets.Cifar10(mode='train', transform=transform)

index = 10
print("未調整亮度的圖像")
train_dataset_without_data_0 = np.array(train_dataset_without_transform[index][0])
train_dataset_without_data_0 = train_dataset_without_data_0.astype('float32') / 255.
plt.imshow(train_dataset_without_data_0)
未調整亮度的圖像
<matplotlib.image.AxesImage at 0x7fb13e129090>

print("調整亮度的圖像")
train_dataset_with_data_0 = np.array(train_dataset_with_transform[index][0])
train_dataset_with_data_0 = train_dataset_with_data_0.astype('float32') / 255.
plt.imshow(train_dataset_with_data_0)
調整亮度的圖像
<matplotlib.image.AxesImage at 0x7fb19b1b5f90>

import paddle.vision.transforms as T

# 方式二 對圖像進行多種操做
transform = T.Compose([T.BrightnessTransform(0.4), T.ContrastTransform(0.4)])
# 經過transform參數傳遞定義好的數據增方法便可完成對自帶數據集的數據加強
train_dataset_without_compose = vision.datasets.Cifar10(mode='train')
train_dataset_with_compose = vision.datasets.Cifar10(mode='train', transform=transform)

index = 10
print("未調整的圖像")
train_dataset_without_compose_data_0 = np.array(train_dataset_without_compose[index][0])
train_dataset_without_compose_data_0 = train_dataset_without_compose_data_0.astype('float32') / 255.
plt.imshow(train_dataset_without_compose_data_0)
未調整的圖像
<matplotlib.image.AxesImage at 0x7fb13065fb90>

print("多種調整後的圖像")
train_dataset_with_compose_data_0 = np.array(train_dataset_with_compose[index][0])
train_dataset_with_compose_data_0 = train_dataset_with_compose_data_0.astype('float32') / 255.
plt.imshow(train_dataset_with_compose_data_0)
多種調整後的圖像
<matplotlib.image.AxesImage at 0x7fb1b818c610>

自定義數據集使用數據加強:

針對自定義數據集使用數據加強的方式, 比較直觀的方式是在在數據集的構造函數中進行數據加強方法的定義,以後對__getitem__中返回的數據進行應用。咱們以上述中FashionMNIST數據集爲例來講明,具體以下:

class FashionMNISTDataset(paddle.io.Dataset):
    """
    步驟一:繼承paddle.io.Dataset類
    """
    def __init__(self, path='./', mode='train', transform='None'):
        """
        步驟二:實現構造函數,定義數據讀取方式,劃分訓練和測試數據集
        """
        super(FashionMNISTDataset, self).__init__()

        images_data_path = os.path.join(path,
                               '%s-images-idx3-ubyte.gz'
                               % mode)
        labels_data_path = os.path.join(path,
                               '%s-labels-idx1-ubyte.gz'
                               % mode)
        with gzip.open(labels_data_path, 'rb') as lbpath:
            self.labels = np.frombuffer(lbpath.read(), dtype=np.uint8,
                               offset=8)

        with gzip.open(images_data_path, 'rb') as imgpath:
            self.images = np.frombuffer(imgpath.read(), dtype=np.uint8,
                               offset=16).reshape(len(self.labels), 784)
        self.transform = None
        if transform != 'None':
            self.transform = transform

    def __getitem__(self, index):
        """
        步驟三:實現__getitem__方法,定義指定index時如何獲取數據,並返回單條數據(訓練數據,對應的標籤)
        """
        if self.transform:
            image = self.transform(self.images[index].reshape(28, 28))
        else:
            image = self.images[index]
        label = self.labels[index]

        return image, label

    def __len__(self):
        """
        步驟四:實現__len__方法,返回數據集總數目
        """
        return len(self.images)

# 測試未處理的數據集
fashion_mnist_train_dataset_without_transform = FashionMNISTDataset(mode='train')

# 可視化
fashion_mnist_train_dataset_without_transform = np.array(fashion_mnist_train_dataset_without_transform[0][0])
fashion_mnist_train_dataset_without_transform = fashion_mnist_train_dataset_without_transform.reshape([28, 28])
plt.imshow(fashion_mnist_train_dataset_without_transform, cmap=plt.cm.binary)
<matplotlib.image.AxesImage at 0x7fb130421ed0>

# 測試處理的數據集
from paddle.vision.transforms import RandomVerticalFlip
fashion_mnist_train_dataset_with_transform = FashionMNISTDataset(mode='train', transform=RandomVerticalFlip(0.4))

# 可視化
fashion_mnist_train_dataset_with_transform = np.array(fashion_mnist_train_dataset_with_transform[0][0])
fashion_mnist_train_dataset_with_transform = fashion_mnist_train_dataset_with_transform.reshape([28, 28])
plt.imshow(fashion_mnist_train_dataset_with_transform, cmap=plt.cm.binary)
<matplotlib.image.AxesImage at 0x7fb130367b50>

 

2.4 數據加載

當咱們定義了數據集後,就須要加載數據集。咱們能夠經過 paddle.io.DataLoader 完成數據的加載。

train_loader = paddle.io.DataLoader(train_dataset, batch_size=64, shuffle=True)

for batch_id, data in enumerate(train_loader()):
    x_data = data[0]
    y_data = data[1]
    print(x_data.numpy().shape)
    print(y_data.numpy().shape)
    break
(4, 32, 32, 3)
(4,)

DataLoader 能夠加載咱們定義好的數據集。雖然看起來很簡單是很簡單的操做,可是,DataLoader 的參數中包含了許多強大的功能。如 shuffle 設爲 True, 能夠對下標進行隨機打散的操做;drop_last 設爲 True 能夠丟掉最後一個不知足 batch_size 大小的 batch;num_workers 能夠設置多個子進程來加速數據加載。此外,咱們還能夠針對不一樣的數據集,設置不一樣的採樣器,來完成數據的採樣。

 

2.5 數據採樣

飛槳框架提供了多種數據採樣器,用於不一樣的場景,來提高訓練模型的泛化性能。飛槳框架包含的採樣器以下:paddle.io.BatchSampler 、 paddle.io.DistributedBatchSampler 、paddle.io.RandomSampler、paddle.io.SequenceSampler 等,接下來來一一介紹。

from paddle.io import SequenceSampler, RandomSampler, BatchSampler, DistributedBatchSampler

class RandomDataset(paddle.io.Dataset):
    def __init__(self, num_samples):
        self.num_samples = num_samples

    def __getitem__(self, idx):
        image = np.random.random([784]).astype('float32')
        label = np.random.randint(0, 9, (1, )).astype('int64')
        return image, label

    def __len__(self):
        return self.num_samples

train_dataset = RandomDataset(100)

print('-----------------順序採樣----------------')
sampler = SequenceSampler(train_dataset)
batch_sampler = BatchSampler(sampler=sampler, batch_size=10)

for index in batch_sampler:
    print(index)

print('-----------------隨機採樣----------------')
sampler = RandomSampler(train_dataset)
batch_sampler = BatchSampler(sampler=sampler, batch_size=10)

for index in batch_sampler:
    print(index)

print('-----------------分佈式採樣----------------')
batch_sampler = DistributedBatchSampler(train_dataset, num_replicas=2, batch_size=10)

for index in batch_sampler:
    print(index)

-----------------順序採樣----------------
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
[10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
[20, 21, 22, 23, 24, 25, 26, 27, 28, 29]
[30, 31, 32, 33, 34, 35, 36, 37, 38, 39]
[40, 41, 42, 43, 44, 45, 46, 47, 48, 49]
[50, 51, 52, 53, 54, 55, 56, 57, 58, 59]
[60, 61, 62, 63, 64, 65, 66, 67, 68, 69]
[70, 71, 72, 73, 74, 75, 76, 77, 78, 79]
[80, 81, 82, 83, 84, 85, 86, 87, 88, 89]
[90, 91, 92, 93, 94, 95, 96, 97, 98, 99]
-----------------隨機採樣----------------
[9, 7, 54, 93, 84, 14, 12, 46, 67, 72]
[10, 57, 32, 61, 38, 71, 63, 51, 37, 11]
[21, 76, 69, 22, 48, 88, 19, 59, 47, 60]
[89, 85, 31, 80, 91, 30, 50, 52, 39, 3]
[70, 45, 62, 75, 35, 8, 96, 94, 5, 98]
[49, 33, 28, 13, 18, 42, 90, 0, 36, 79]
[81, 15, 6, 78, 40, 86, 2, 23, 95, 43]
[87, 65, 68, 25, 99, 26, 73, 82, 1, 53]
[77, 29, 17, 44, 55, 4, 56, 64, 97, 83]
[66, 41, 16, 74, 92, 34, 27, 24, 58, 20]
-----------------分佈式採樣----------------
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
[20, 21, 22, 23, 24, 25, 26, 27, 28, 29]
[40, 41, 42, 43, 44, 45, 46, 47, 48, 49]
[60, 61, 62, 63, 64, 65, 66, 67, 68, 69]
[80, 81, 82, 83, 84, 85, 86, 87, 88, 89]

總結

恭喜同窗們學會了數據集的下載、數據集的自定義、數據的預處理以及數據的批加載等知識,你們已經能夠很好地應對模型訓練任務的第一步啦。那麼今天的課程到這裏就結束了,對課程內容有疑問或者建議的同窗能夠在評論區留言,看到後我會及時回覆哦,最後但願你們fork一下該項目,否則就找不到這個課程了。我是雨哥,下節課見~

下載安裝命令

## CPU版本安裝命令
pip install -f https://paddlepaddle.org.cn/pip/oschina/cpu paddlepaddle

## GPU版本安裝命令
pip install -f https://paddlepaddle.org.cn/pip/oschina/gpu paddlepaddle-gpu

有任何問題能夠在本項目中評論或到飛槳Github倉庫(連接)提交Issue。

歡迎掃碼加入飛槳框架高層API技術交流羣

·飛槳官網地址·

https://www.paddlepaddle.org.cn/

·飛槳開源框架項目地址·

GitHub: https://github.com/PaddlePaddle/Paddle 

Gitee: https://gitee.com/paddlepaddle/Paddle

微信號 : PaddleOpenSource

飛槳(PaddlePaddle)以百度多年的深度學習技術研究和業務應用爲基礎,是中國首個開源開放、技術領先、功能完備的產業級深度學習平臺,包括飛槳開源平臺和飛槳企業版。飛槳開源平臺包含核心框架、基礎模型庫、端到端開發套件與工具組件,持續開源核心能力,爲產業、學術、科研創新提供基礎底座。飛槳企業版基於飛槳開源平臺,針對企業級需求加強了相應特性,包含零門檻AI開發平臺EasyDL和全功能AI開發平臺BML。EasyDL主要面向中小企業,提供零門檻、預置豐富網絡和模型、便捷高效的開發平臺;BML是爲大型企業提供的功能全面、可靈活定製和被深度集成的開發平臺。

本文同步分享在 博客「飛槳PaddlePaddle」(CSDN)。
若有侵權,請聯繫 support@oschina.cn 刪除。
本文參與「OSC源創計劃」,歡迎正在閱讀的你也加入,一塊兒分享。

相關文章
相關標籤/搜索