KNN分類算法實現手寫數字識別

需求:python

利用一個手寫數字「先驗數據」集,使用knn算法來實現對手寫數字的自動識別;git

先驗數據(訓練數據)集:算法

♦數據維度比較大,樣本數比較多。app

♦ 數據集包括數字0-9的手寫體。函數

♦每一個數字大約有200個樣本。測試

♦每一個樣本保持在一個txt文件中。spa

♦手寫體圖像自己的大小是32x32的二值圖,轉換到txt文件保存後,內容也是32x32個數字,0或者1,以下:rest

數據集壓縮包解壓後有兩個目錄:(將這兩個目錄文件夾拷貝的項目路徑下E:/KNNCase/digits/code

♦目錄trainingDigits存放的是大約2000個訓練數據orm

♦目錄testDigits存放大約900個測試數據。

 

模型分析:

一、手寫體由於每一個人,甚至每次寫的字都不會徹底精確一致,因此,識別手寫體的關鍵是「類似度」

二、既然是要求樣本之間的類似度,那麼,首先須要將樣本進行抽象,將每一個樣本變成一系列特徵數據(即特徵向量)

三、手寫體在直觀上就是一個個的圖片,而圖片是由上述圖示中的像素點來描述的,樣本的類似度其實就是像素的位置和顏色之間的組合的類似度

四、所以,將圖片的像素按照固定順序讀取到一個個的向量中,便可很好地表示手寫體樣本

五、抽象出了樣本向量,及類似度計算模型,便可應用KNN來實現

 

python實現:

新建一個KNN.py腳本文件,文件裏面包含四個函數:

1) 一個用來生成將每一個樣本的txt文件轉換爲對應的一個向量,

2) 一個用來加載整個數據集,

3) 一個實現kNN分類算法。

4) 最後就是實現加載、測試的函數。

  1 #!/usr/bin/python
  2 # coding=utf-8
  3 #########################################
  4 # kNN: k Nearest Neighbors
  5 
  6 # 參數:        inX: vector to compare to existing dataset (1xN)
  7 #             dataSet: size m data set of known vectors (NxM)
  8 #             labels: data set labels (1xM vector)
  9 #             k: number of neighbors to use for comparison
 10 
 11 # 輸出:     多數類
 12 #########################################
 13 
 14 from numpy import *
 15 import operator
 16 import os
 17 
 18 
 19 # KNN分類核心方法
 20 def kNNClassify(newInput, dataSet, labels, k):
 21     numSamples = dataSet.shape[0]  # shape[0]表明行數
 22 
 23     # # step 1: 計算歐式距離
 24     # tile(A, reps): 將A重複reps次來構造一個矩陣
 25     # the following copy numSamples rows for dataSet
 26     diff = tile(newInput, (numSamples, 1)) - dataSet  # Subtract element-wise
 27     squaredDiff = diff ** 2  # squared for the subtract
 28     squaredDist = sum(squaredDiff, axis = 1)   # sum is performed by row
 29     distance = squaredDist ** 0.5
 30 
 31     # # step 2: 對距離排序
 32     # argsort()返回排序後的索引
 33     sortedDistIndices = argsort(distance)
 34 
 35     classCount = {}  # 定義一個空的字典
 36     for i in xrange(k):
 37         # # step 3: 選擇k個最小距離
 38         voteLabel = labels[sortedDistIndices[i]]
 39 
 40         # # step 4: 計算類別的出現次數
 41         # when the key voteLabel is not in dictionary classCount, get()
 42         # will return 0
 43         classCount[voteLabel] = classCount.get(voteLabel, 0) + 1
 44 
 45     # # step 5: 返回出現次數最多的類別做爲分類結果
 46     maxCount = 0
 47     for key, value in classCount.items():
 48         if value > maxCount:
 49             maxCount = value
 50             maxIndex = key
 51 
 52     return maxIndex
 53 
 54 # 將圖片轉換爲向量
 55 def  img2vector(filename):
 56     rows = 32
 57     cols = 32
 58     imgVector = zeros((1, rows * cols))
 59     fileIn = open(filename)
 60     for row in xrange(rows):
 61         lineStr = fileIn.readline()
 62         for col in xrange(cols):
 63             imgVector[0, row * 32 + col] = int(lineStr[col])
 64 
 65     return imgVector
 66 
 67 # 加載數據集
 68 def loadDataSet():
 69     # # step 1: 讀取訓練數據集
 70     print "---Getting training set..."
 71     dataSetDir = 'E:/KNNCase/digits/'
 72     trainingFileList = os.listdir(dataSetDir + 'trainingDigits')  # 加載測試數據
 73     numSamples = len(trainingFileList)
 74 
 75     train_x = zeros((numSamples, 1024))
 76     train_y = []
 77     for i in xrange(numSamples):
 78         filename = trainingFileList[i]
 79 
 80         # get train_x
 81         train_x[i, :] = img2vector(dataSetDir + 'trainingDigits/%s' % filename)
 82 
 83         # get label from file name such as "1_18.txt"
 84         label = int(filename.split('_')[0]) # return 1
 85         train_y.append(label)
 86 
 87     # # step 2:讀取測試數據集
 88     print "---Getting testing set..."
 89     testingFileList = os.listdir(dataSetDir + 'testDigits') # load the testing set
 90     numSamples = len(testingFileList)
 91     test_x = zeros((numSamples, 1024))
 92     test_y = []
 93     for i in xrange(numSamples):
 94         filename = testingFileList[i]
 95 
 96         # get train_x
 97         test_x[i, :] = img2vector(dataSetDir + 'testDigits/%s' % filename)
 98 
 99         # get label from file name such as "1_18.txt"
100         label = int(filename.split('_')[0]) # return 1
101         test_y.append(label)
102 
103     return train_x, train_y, test_x, test_y
104 
105 # 手寫識別主流程
106 def testHandWritingClass():
107     # # step 1: 加載數據
108     print "step 1: load data..."
109     train_x, train_y, test_x, test_y = loadDataSet()
110 
111     # # step 2: 模型訓練.
112     print "step 2: training..."
113     pass
114 
115     # # step 3: 測試
116     print "step 3: testing..."
117     numTestSamples = test_x.shape[0]
118     matchCount = 0
119     for i in xrange(numTestSamples):
120         predict = kNNClassify(test_x[i], train_x, train_y, 3)
121         if predict == test_y[i]:
122             matchCount += 1
123     accuracy = float(matchCount) / numTestSamples
124 
125     # # step 4: 輸出結果
126     print "step 4: show the result..."
127     print 'The classify accuracy is: %.2f%%' % (accuracy * 100)

 

KNNTest.py

#!/usr/bin/python
# coding=utf-8

import KNN
KNN.testHandWritingClass()

 

測試結果:

相關文章
相關標籤/搜索