官方引導教程html
win10+python3.5+gpu版本的caffepython
mnist官網下載下面4個文件windows
train-images-idx3-ubyte.gz: 訓練圖片集 (9912422 bytes)網絡
train-labels-idx1-ubyte.gz: 訓練圖片集的打標值 (28881 bytes)工具
t10k-images-idx3-ubyte.gz: 測試圖片集 (1648877 bytes)測試
t10k-labels-idx1-ubyte.gz: 測試圖片集的打標值(4542 bytes)網站
分別解壓這四個文件獲得:code
t10k-images.idx3-ubyte t10k-labels.idx1-ubyte train-images.idx3-ubyte train-labels.idx1-ubyte
它們的結構在mnist網站上有說明orm
訓練圖片集的打標值文件 (train-labels-idx1-ubyte):htm
[offset] [type] [value] [description] 0000 32 bit integer 0x00000801(2049) magic number (MSB first) 0004 32 bit integer 10000 標籤值總數 0008 unsigned byte ?? 標籤值 0009 unsigned byte ?? 標籤值 ........ xxxx unsigned byte ?? 標籤值
訓練圖片集文件 (train-images-idx3-ubyte):
[offset] [type] [value] [description] 0000 32 bit integer 0x00000803(2051) magic number 0004 32 bit integer 10000 圖片總數 0008 32 bit integer 28 單張圖片的長度像素值數量 0012 32 bit integer 28 單張圖片的高度像素值數量 0016 unsigned byte ?? 單像素值 0017 unsigned byte ?? 單像素值 ........ xxxx unsigned byte ?? 單像素值
測試圖片集的打標值文件 (t10k-labels-idx1-ubyte):
[offset] [type] [value] [description] 0000 32 bit integer 0x00000801(2049) magic number (MSB first) 0004 32 bit integer 10000 標籤值總數 0008 unsigned byte ?? 標籤值 0009 unsigned byte ?? 標籤值 ........ xxxx unsigned byte ?? 標籤值
標籤值的範圍是0-9
測試圖片集文件 (t10k-images-idx3-ubyte):
[offset] [type] [value] [description] 0000 32 bit integer 0x00000803(2051) magic number 0004 32 bit integer 10000 圖片總數 0008 32 bit integer 28 單張圖片的長度像素值數量 0012 32 bit integer 28 單張圖片的高度像素值數量 0016 unsigned byte ?? 單像素值 0017 unsigned byte ?? 單像素值 ........ xxxx unsigned byte ?? 單像素值
下載的數據集有兩對,一對是訓練數據圖片集和它對應的標籤值, 另外一對是測試圖片集和它對應的標籤值,這兩對文件的結構是同樣的,所以轉換爲lmdb文件時,能夠使用一樣的方法
def orgin_to_lmdb(image_file, label_file, lmdb_save_path, force_update=False): mean_file = '{}.binaryproto'.format(lmdb_save_path) if os.path.exists(mean_file) and os.path.exists(lmdb_save_path) and force_update == False: return try: shutil.rmtree(lmdb_save_path) except: pass try: shutil.rmtree(mean_file) except: pass with open(image_file, 'rb') as image_f: with open(label_file, 'rb') as label_f: # 讀取標籤文件頭的4個整型 size = struct.calcsize('>2I') magic, num_items = struct.unpack_from('>2I', label_f.read(size)) print(magic, num_items) # 讀取圖片文件頭的4個整型 size = struct.calcsize('>4I') magic, num_images, num_rows, num_columns = struct.unpack_from('>4I', image_f.read(size)) print(magic, num_images, num_rows, num_columns) map_size = num_images*num_rows*num_columns * 1.5 # 遍歷全部圖片,將文件列表寫入到lmdb中 with lmdb.open(lmdb_save_path,map_size=map_size) as in_db: with in_db.begin(write=True) as in_txn: im_size = num_rows * num_columns label_size = struct.calcsize('>B') im_idx = 0 while im_idx < num_images: img_item = struct.unpack_from('>B', label_f.read(label_size))[0] img_buf = image_f.read(im_size) datum = caffe_pb2.Datum( channels=1, # 數據集裏面的圖片是灰度圖,所以通道數設置爲1 width=num_columns, height=num_rows, label=int(img_item), data=img_buf ) in_txn.put('{:0>8d}'.format(im_idx).encode('utf8'), datum.SerializeToString()) im_idx += 1 # 生成mean文件 cmd = '{0} {1} {2}'.format(compute_image_mean, lmdb_save_path, mean_file) print(cmd) os.system(cmd)
如下代碼能夠打開lmdb查看第一張圖片
# 查看lmdb的第一張圖片 def show_lmdb_first_image(lmdb_save_path): with lmdb.open(lmdb_save_path, readonly=True) as lmdb_env: lmdb_txn = lmdb_env.begin() lmdb_cursor = lmdb_txn.cursor() datum = caffe_pb2.Datum() lmdb_cursor.first() key, value = lmdb_cursor.item() datum.ParseFromString(value) label = datum.label data = caffe.io.datum_to_array(datum) print(label, datum.channels, data.shape) image = data.transpose(1, 2, 0) cv2.imshow('cv2.png', image) cv2.waitKey(0) cv2.destroyAllWindows()
使用caffe代碼目錄下的examples\mnist\lenet_solver.prototxt
和examples\mnist\lenet_train_test.prototxt
, 須要修改lenet_solver.prototxt
中的網絡文件地址爲新的lenet_train_test.prototxt
須要修改lenet_train_test.prototxt
的數據層爲剛纔生成的lmdb地址
solver = caffe.SGDSolver('lenet_solver.prototxt') solver.solve()
完成以後會產生兩個模型文件lenet_iter_5000.caffemodel
和lenet_iter_10000.caffemodel
須要先生成一個網絡配置文件, 通常是改動訓練時用的網絡配置文件,這裏直接使用examples\mnist\lenet.prototxt
net = caffe.Net( 'lenet.prototxt', # 網絡配置文件 caffe.TEST, weights='lenet_iter_10000.caffemodel' # 訓練產生的模型 ) transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape}) transformer.set_transpose('data', (2,0,1)) transformer.set_raw_scale('data', 255) # transformer.set_channel_swap('data', (2, 1, 0)) # minist用的是灰度圖 channel只有1,所以無需轉換 # 由於minist的channel是1, 因此須要轉爲灰度圖color=False im = caffe.io.load_image('3.jpg', color=False) # 打開測試圖片 net.blobs['data'].data[0] = transformer.preprocess('data', im) res = net.forward() print(res['prob'].argmax())
測試圖片是用windows測試工具寫的幾個數字,須要黑底白字,而且圖片大小要改成28*28
有的識別會出錯。。。