mnist的格式說明,以及在python3.x和python 2.x讀取mnist數據集的不一樣

有一個關於mnist的一個事例能夠參考,我以爲寫的很好:http://www.cnblogs.com/x1957/archive/2012/06/02/2531503.html
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
import struct
# from bp import *
from datetime import datetime
# 數據加載器基類
class Loader(object):
    def __init__(self, path, count):
        '''
        初始化加載器
        path: 數據文件路徑
        count: 文件中的樣本個數
        '''
        self.path = path
        self.count = count
    def get_file_content(self):
        '''
        讀取文件內容
        '''
        f = open(self.path, 'rb')
        content = f.read()
        print content[:20]
        f.close()
        return content
 
    def to_int(self,h):
        return struct.unpack('B',h)[0]
 
# 圖像數據加載器
class ImageLoader(Loader):
    def get_picture(self, content, index):
        '''
        內部函數,從文件中獲取圖像
        '''
        start = index * 28 * 28 + 16
        picture = []
        # print(content[16])
        for i in range(28):
            picture.append([])
            for j in range(28):
                picture[i].append(
#在python2.7中,紅色字體部分就是對的,可是在python3.x中,藍色字體纔是對的
                    self.to_int(content[start + i * 28 + j-1:start + i * 28 + j ]))
                    self.to_int(content[start + i * 28 + j]))
        return picture
    def get_one_sample(self, picture):
        '''
        內部函數,將圖像轉化爲樣本的輸入向量
        '''
        sample = []
        for i in range(28):
            for j in range(28):
                sample.append(picture[i][j])
        return sample
    def load(self):
        '''
        加載數據文件,得到所有樣本的輸入向量
        '''
        content = self.get_file_content()
        data_set = []
        for index in range(self.count):
            data_set.append(
                self.get_one_sample(
                    self.get_picture(content, index)))
        return data_set
# 標籤數據加載器
class LabelLoader(Loader):
    def load(self):
        '''
        加載數據文件,得到所有樣本的標籤向量
        '''
        content = self.get_file_content()
        # print content[:15]
        labels = []
        for index in range(self.count):
#在python2.7中,紅色字體部分就是對的,可是在python3.x中,藍色字體纔是對的
            labels.append(self.norm(content[index + 7 :index + 8]))
            labels.append(self.norm(content[index + 8]))
        return labels
    def norm(self, label):
        '''
        內部函數,將一個值轉換爲10維標籤向量
        '''
        label_vec = []
        # print('label is \n')
        # print(label[:20])
        label_value = self.to_int(label)
        for i in range(10):
            if i == label_value:
                label_vec.append(0.9)
            else:
                label_vec.append(0.1)
        return label_vec
def get_training_data_set():
    '''
    得到訓練數據集
    '''
    filename1 = r'E:\workspace\pythonpaper\importment\dataset\train-images.idx3-ubyte'
    filename2 = r'E:\workspace\pythonpaper\importment\dataset\train-labels.idx1-ubyte'
    image_loader = ImageLoader(filename1, 60000)
    label_loader = LabelLoader(filename2, 60000)
    return image_loader.load(), label_loader.load()
def get_test_data_set():
    '''
    得到測試數據集
    '''
    filename3 = r'E:\workspace\pythonpaper\importment\dataset\t10k-images.idx3-ubyte'
    filename4 = r'E:\workspace\pythonpaper\importment\dataset\t10k-labels.idx1-ubyte'
    image_loader = ImageLoader(filename3, 10000)
    label_loader = LabelLoader(filename4, 10000)
    return image_loader.load(), label_loader.load()
 
def train_and_evaluate():
 
    train_data_set, train_labels = get_training_data_set()
    test_data_set, test_labels = get_test_data_set()
    # print '[dataset train:]\n'
    # print train_data_set[:10]
 
if __name__ == '__main__':
    train_and_evaluate()
 
一、mnist數據集格式的介紹 
上面的代碼是我參考的一個教程上的例子,它自己是用python2.7實現的,可是,由於一些緣由,我用的python3.5的環境,在實現這個代碼的時候,出現了一些問題,爲此,我也探究了一下。
mnist數據集是一個idx的文件格式,從網上下載下來的是四個壓縮文件,兩個訓練樣本的壓縮文件,兩個測試樣本的壓縮文件,在導入代碼以前須要把它們解壓縮,解壓後的文件是以idx3-ubyte爲後綴的idx的文件,這個文件是不能直接打開的,因此咱們須要編寫程序把它處理成咱們須要的內容。
以mnist數據集的train-images-idx3-ubyte爲例介紹

TRAINING SET IMAGE FILE (train-images-idx3-ubyte):

[offset] [type]          [value]          [description] 
0000     32 bit integer  0x00000803(2051) magic number 
0004     32 bit integer  60000            number of images 
0008     32 bit integer  28               number of rows 
0012     32 bit integer  28               number of columns 
0016     unsigned byte   ??               pixel 
0017     unsigned byte   ??               pixel 
........ 
xxxx     unsigned byte   ??               pixel
32bit是說這個數據書32位的,8位=1B(1個字節),所以,32位=4B=4byte,咱們真正要讀出來的是value這一列,可是0000-0015的數據不是咱們須要的,第一個4B是magic的數量,第二個4B是這個文件包含多少個圖像,第三個4B是說一個圖像的有多少行,第四個4B是說一個圖像有多少列,mnist的一個樣本圖像是28*28的。從0016開始,纔是咱們須要的圖像內容,28*28=784,也就是咱們須要784個B才能讀取一個圖像,在0016如下的的description上,寫的是pixel,這是像素的意思,也就是說,一個像素就是一個1byte=1B,舉例,用1B(一個字節)的二進制表示一個十進制的3,二進制就是0000 0011,用十六進制表示3,就是\x03,3的縮寫是ETX。這樣784個像素,784行就是一個圖像樣本了。
下面是mnist數據集的 train-labels-idx1-ubyte文件的結構
TRAINING SET LABEL FILE (train-labels-idx1-ubyte):
    [offset] [type]          [value]          [description]
    0000     32 bit integer  0x00000801(2049) magic number (MSB first)
    0004     32 bit integer  60000            number of items
    0008     unsigned byte   ??               label
    0009     unsigned byte   ??               label
    ........
    xxxx     unsigned byte   ??               label
train-labels-idx1-ubyte文件的結構的讀法和train-images-idx3-ubyte相同,前8個字節不是label的內容,從0008開始纔是一個label的內容,並且,經過觀察這個表格的offset和description字段能夠發現,一個字節是一個label。
 
 二、像素和二進制,十六進制的關係,以及python中print 的輸出的不一樣
先說明一下二進制和十六進制和十進制的關係,以及它們的縮寫的關係
ASCII控制字符
二進制 十進制 十六進制 縮寫 能夠顯示的表示法 名稱/意義
0000 0000 0 00 NUL 空字符(Null)
0000 0001 1 01 SOH 標題開始
0000 0010 2 02 STX 本文開始
0000 0011 3 03 ETX 本文結束
0000 0100 4 04 EOT 傳輸結束
0000 0101 5 05 ENQ 請求
0000 0110 6 06 ACK 確認迴應
0000 0111 7 07 BEL 響鈴
0000 1000 8 08 BS 退格
0000 1001 9 09 HT 水平定位符號
0000 1010 10 0A LF 換行鍵
0000 1011 11 0B VT 垂直定位符號
0000 1100 12 0C FF 換頁鍵
 
從open(filepath,'rb')中讀出來的是二進制的內容,print content[:5],顯示content的前5個元素,一個元素就是一個像素,這樣就有5個像素,而一個像素佔一個二進制,一個二進制就是一個字節,一個字節就是一個十六進制。
 
有個小例子能夠進一步說明一個int包含了4個字節,而一個字節是\x14這樣的形式。
>>> a=20
>>> b=400
>>> t=struct.pack('ii',a,b)
>>> t
'\x14\x00\x00\x00\x90\x01\x00\x00'
>>> len(t)
8
>>> type(a)
<type 'int'>
a是int型的,pack('ii',a,b)中的'ii'是格式,一個i對應了一個int,有兩個i,對應了兩個int,一個int型的a,佔了4個字節(\x14\x00\x00\x00),len輸出的是一個字節\x14就是一個,全部有8個\x這樣的,len(t)就是8個
 
三、struct的介紹
a=20,b=400
struct有三個方法,pack(fmt,val)方法是把val的數據按照fmt的格式轉換爲二進制數據, t=struct.pack('ii',a,b),把a,b轉換爲二進制形式'\x14\x00\x00\x00\x90\x01\x00\x00'
unpack(fmt,val)方法是把val按照fmt的格式把二進制數據轉換爲python能夠讀的數據,unpack('ii',a,b),把a,b轉換爲20,400
struct.unpack_from('>IIII' , buf , index)'>IIII'是說使用大端法從index的位置讀取4個unsinged int32
 
四、python2.7和python3.5對mnist數據集的格式引起的問題
在python2.7中,content輸出的是20個二進制的縮寫,可是在python3.5中,print content[:20]輸出的是20個十六進制。
在python2.7中,在struct.unpack('B',byte)中的content[start+i*28+j],就能夠運行,可是在python3中,這裏就須要寫成[start+i*28+j -1:start+i*28+j ]才能夠運行成功
相關文章
相關標籤/搜索