使用pycaffe進行mnist手寫數字識別

官方引導教程html

運行環境

win10+python3.5+gpu版本的caffepython

步驟

  1. 下載數據集
  2. 將數據集轉爲lmdb
  3. 訓練
  4. 測試訓練的出來的模型

下載數據集

mnist官網下載下面4個文件windows

t10k-images.idx3-ubyte
t10k-labels.idx1-ubyte
train-images.idx3-ubyte
train-labels.idx1-ubyte

它們的結構在mnist網站上有說明orm

訓練圖片集的打標值文件 (train-labels-idx1-ubyte):htm

[offset]  [type]               [value]                     [description] 
0000     32 bit integer   0x00000801(2049) magic number (MSB first) 
0004     32 bit integer   10000                     標籤值總數 
0008     unsigned byte  ??                          標籤值
0009     unsigned byte  ??                          標籤值
........ 
xxxx     unsigned byte   ??                          標籤值

訓練圖片集文件 (train-images-idx3-ubyte):

[offset]   [type]               [value]                     [description] 
0000     32 bit integer   0x00000803(2051)  magic number 
0004     32 bit integer   10000                      圖片總數 
0008     32 bit integer    28                           單張圖片的長度像素值數量
0012     32 bit integer    28                           單張圖片的高度像素值數量
0016     unsigned byte   ??                          單像素值
0017     unsigned byte   ??                          單像素值 
........ 
xxxx     unsigned byte    ??                          單像素值

測試圖片集的打標值文件 (t10k-labels-idx1-ubyte):

[offset]  [type]               [value]                     [description] 
0000     32 bit integer   0x00000801(2049) magic number (MSB first) 
0004     32 bit integer   10000                     標籤值總數 
0008     unsigned byte  ??                          標籤值
0009     unsigned byte  ??                          標籤值
........ 
xxxx     unsigned byte   ??                          標籤值

標籤值的範圍是0-9

測試圖片集文件 (t10k-images-idx3-ubyte):

[offset]   [type]               [value]                     [description] 
0000     32 bit integer   0x00000803(2051)  magic number 
0004     32 bit integer   10000                      圖片總數 
0008     32 bit integer    28                           單張圖片的長度像素值數量
0012     32 bit integer    28                           單張圖片的高度像素值數量
0016     unsigned byte   ??                          單像素值
0017     unsigned byte   ??                          單像素值 
........ 
xxxx     unsigned byte    ??                          單像素值

數據集轉換爲lmdb

下載的數據集有兩對,一對是訓練數據圖片集和它對應的標籤值, 另外一對是測試圖片集和它對應的標籤值,這兩對文件的結構是同樣的,所以轉換爲lmdb文件時,能夠使用一樣的方法

def orgin_to_lmdb(image_file, label_file, lmdb_save_path, force_update=False):
    mean_file = '{}.binaryproto'.format(lmdb_save_path)

    if os.path.exists(mean_file) and os.path.exists(lmdb_save_path) and force_update == False:
        return

    try:
        shutil.rmtree(lmdb_save_path)
    except:
        pass
    try:
        shutil.rmtree(mean_file)
    except:
        pass

    with open(image_file, 'rb') as image_f:
        with open(label_file, 'rb') as label_f:
            # 讀取標籤文件頭的4個整型
            size = struct.calcsize('>2I')
            magic, num_items = struct.unpack_from('>2I', label_f.read(size))
            print(magic, num_items)

            # 讀取圖片文件頭的4個整型
            size = struct.calcsize('>4I')
            magic, num_images, num_rows, num_columns = struct.unpack_from('>4I', image_f.read(size))
            print(magic, num_images, num_rows, num_columns)

            map_size = num_images*num_rows*num_columns * 1.5

            # 遍歷全部圖片,將文件列表寫入到lmdb中
            with lmdb.open(lmdb_save_path,map_size=map_size) as in_db:
                with in_db.begin(write=True) as in_txn:
                    im_size = num_rows * num_columns
                    label_size = struct.calcsize('>B')
                    im_idx = 0
                    while im_idx < num_images:
                        img_item = struct.unpack_from('>B', label_f.read(label_size))[0]
                        img_buf = image_f.read(im_size)

                        datum = caffe_pb2.Datum(
                            channels=1,  # 數據集裏面的圖片是灰度圖,所以通道數設置爲1
                            width=num_columns,
                            height=num_rows,
                            label=int(img_item),
                            data=img_buf
                        )
                        in_txn.put('{:0>8d}'.format(im_idx).encode('utf8'), datum.SerializeToString())
                        im_idx += 1

    # 生成mean文件
    cmd = '{0} {1} {2}'.format(compute_image_mean, lmdb_save_path, mean_file)
    print(cmd)
    os.system(cmd)

如下代碼能夠打開lmdb查看第一張圖片

# 查看lmdb的第一張圖片
def show_lmdb_first_image(lmdb_save_path):
    with lmdb.open(lmdb_save_path, readonly=True) as lmdb_env:
        lmdb_txn = lmdb_env.begin()
        lmdb_cursor = lmdb_txn.cursor()
        datum = caffe_pb2.Datum()

        lmdb_cursor.first()
        key, value = lmdb_cursor.item()
        datum.ParseFromString(value)

        label = datum.label
        data = caffe.io.datum_to_array(datum)
        print(label, datum.channels, data.shape)
        image = data.transpose(1, 2, 0)
        cv2.imshow('cv2.png', image)
        cv2.waitKey(0)

        cv2.destroyAllWindows()

使用數據集進行訓練

使用caffe代碼目錄下的examples\mnist\lenet_solver.prototxtexamples\mnist\lenet_train_test.prototxt, 須要修改lenet_solver.prototxt中的網絡文件地址爲新的lenet_train_test.prototxt 須要修改lenet_train_test.prototxt的數據層爲剛纔生成的lmdb地址

solver = caffe.SGDSolver('lenet_solver.prototxt')
        solver.solve()

完成以後會產生兩個模型文件lenet_iter_5000.caffemodellenet_iter_10000.caffemodel

測試訓練的出來的模型

須要先生成一個網絡配置文件, 通常是改動訓練時用的網絡配置文件,這裏直接使用examples\mnist\lenet.prototxt

net = caffe.Net(
            'lenet.prototxt', # 網絡配置文件
            caffe.TEST,
            weights='lenet_iter_10000.caffemodel'  # 訓練產生的模型
        )

        transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})
        transformer.set_transpose('data', (2,0,1))
        transformer.set_raw_scale('data', 255)
        # transformer.set_channel_swap('data', (2, 1, 0))  # minist用的是灰度圖 channel只有1,所以無需轉換

        # 由於minist的channel是1, 因此須要轉爲灰度圖color=False
        im = caffe.io.load_image('3.jpg', color=False)  # 打開測試圖片
        net.blobs['data'].data[0] = transformer.preprocess('data', im)
        res = net.forward()
        print(res['prob'].argmax())

測試圖片是用windows測試工具寫的幾個數字,須要黑底白字,而且圖片大小要改成28*28

  • 輸入圖片說明
  • 輸入圖片說明
  • 輸入圖片說明

有的識別會出錯。。。

相關文章
相關標籤/搜索