Python3讀取深度學習CIFAR-10數據集出現的若干問題解決

今天在看網上的視頻學習深度學習的時候,用到了CIFAR-10數據集。當我興高采烈的運行代碼時,卻發現了一些錯誤:

# -*- coding: utf-8 -*-
import pickle as p
import numpy as np
import os


def load_CIFAR_batch(filename):
""" 載入cifar數據集的一個batch """
with open(filename, 'r') as f:
datadict = p.load(f)
X = datadict['data']
Y = datadict['labels']
X = X.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).astype("float")
Y = np.array(Y)
return X, Y


def load_CIFAR10(ROOT):
""" 載入cifar所有數據 """
xs = []
ys = []
for b in range(1, 6):
f = os.path.join(ROOT, 'data_batch_%d' % (b,))
X, Y = load_CIFAR_batch(f)
xs.append(X)
ys.append(Y)
Xtr = np.concatenate(xs)
Ytr = np.concatenate(ys)
del X, Y
Xte, Yte = load_CIFAR_batch(os.path.join(ROOT, 'test_batch'))
return Xtr, Ytr, Xte, Yte
複製代碼

錯誤代碼以下:

'gbk' codec can't decode byte 0x80 in position 0: illegal multibyte sequence複製代碼
因而乎開始各類搜索問題,問大佬,網上的答案都是相似:


然而並無解決問題!仍是錯誤的!(我大概搜索了一下午吧,都是上面的答案)數據庫

哇,就當我很絕望的時候,我終於發現了一個新奇的答案,抱着試一試的態度,嘗試了一下:

def load_CIFAR_batch(filename):
""" 載入cifar數據集的一個batch """
with open(filename, 'rb') as f:
datadict = p.load(f, encoding='latin1')
X = datadict['data']
Y = datadict['labels']
X = X.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).astype("float")
Y = np.array(Y)
return X, Y複製代碼
居然成功了,這裏沒有報錯了!欣喜之餘,我就很好奇,encoding='latin1'究竟是啥玩意呢,之前沒有見過啊?因而,我搜索了一下,瞭解到:
 Latin1是ISO-8859-1的別名,有些環境下寫做Latin-1。ISO-8859-1編碼是單字節編碼,向下兼容ASCII,其編碼範圍是0x00-0xFF,0x00-0x7F之間徹底和ASCII一致,0x80-0x9F之間是控制字符,0xA0-0xFF之間是文字符號。

由於ISO-8859-1編碼範圍使用了單字節內的全部空間,在支持ISO-8859-1的系統中傳輸和存儲其餘任何編碼的字節流都不會被拋棄。換言之,把其餘任何編碼的字節流看成ISO-8859-1編碼看待都沒有問題。這是個很重要的特性,MySQL數據庫默認編碼是Latin1就是利用了這個特性。ASCII編碼是一個7位的容器,ISO-8859-1編碼是一個8位的容器。
還沒等我高興起來,運行後,又發現了一個問題:

memory error複製代碼
什麼鬼?內存錯誤!哇,原來是數據大小的問題。

X = X.reshape(10000, 3, 32, 32).transpose(0,2,3,1).astype("float")複製代碼
這告訴咱們每批數據都是10000 * 3 * 32 * 32,至關於超過3000萬個浮點數。 float數據類型實際上與float64相同,意味着每一個數字大小佔8個字節。這意味着每一個批次佔用至少240 MB。你加載6這些(5訓練+ 1測試)在總產量接近1.4 GB的數據。

for b in range(1, 2):
f = os.path.join(ROOT, 'data_batch_%d' % (b,))
X, Y = load_CIFAR_batch(f)
xs.append(X)
ys.append(Y)
複製代碼


因此若有可能,如上代碼所示只能一次運行一批。

到此爲止,錯誤基本搞定,下面貼出正確代碼:
# -*- coding: utf-8 -*-
import pickle as p
import numpy as np
import os


def load_CIFAR_batch(filename):
""" 載入cifar數據集的一個batch """
with open(filename, 'rb') as f:
datadict = p.load(f, encoding='latin1')
X = datadict['data']
Y = datadict['labels']
X = X.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).astype("float")
Y = np.array(Y)
return X, Y


def load_CIFAR10(ROOT):
""" 載入cifar所有數據 """
xs = []
ys = []
for b in range(1, 2):
f = os.path.join(ROOT, 'data_batch_%d' % (b,))
X, Y = load_CIFAR_batch(f)
xs.append(X) #將全部batch整合起來
ys.append(Y)
Xtr = np.concatenate(xs) #使變成行向量,最終Xtr的尺寸爲(50000,32,32,3)
Ytr = np.concatenate(ys)
del X, Y
Xte, Yte = load_CIFAR_batch(os.path.join(ROOT, 'test_batch'))
return Xtr, Ytr, Xte, Yte


import numpy as np
from julyedu.data_utils import load_CIFAR10
import matplotlib.pyplot as plt

plt.rcParams['figure.figsize'] = (10.0, 8.0)
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'

# 載入CIFAR-10數據集
cifar10_dir = 'julyedu/datasets/cifar-10-batches-py'
X_train, y_train, X_test, y_test = load_CIFAR10(cifar10_dir)

# 看看數據集中的一些樣本:每一個類別展現一些
print('Training data shape: ', X_train.shape)
print('Training labels shape: ', y_train.shape)
print('Test data shape: ', X_test.shape)
print('Test labels shape: ', y_test.shape)
複製代碼

順便看一下CIFAR-10數據組成:


更多內容,可關注個人我的公衆號
bash

                                                   

相關文章
相關標籤/搜索