需求:python
利用一個手寫數字「先驗數據」集,使用knn算法來實現對手寫數字的自動識別;git
先驗數據(訓練數據)集:算法
♦數據維度比較大,樣本數比較多。app
♦ 數據集包括數字0-9的手寫體。函數
♦每一個數字大約有200個樣本。測試
♦每一個樣本保持在一個txt文件中。spa
♦手寫體圖像自己的大小是32x32的二值圖,轉換到txt文件保存後,內容也是32x32個數字,0或者1,以下:rest
數據集壓縮包解壓後有兩個目錄:(將這兩個目錄文件夾拷貝的項目路徑下E:/KNNCase/digits/)code
♦目錄trainingDigits存放的是大約2000個訓練數據orm
♦目錄testDigits存放大約900個測試數據。
模型分析:
一、手寫體由於每一個人,甚至每次寫的字都不會徹底精確一致,因此,識別手寫體的關鍵是「類似度」
二、既然是要求樣本之間的類似度,那麼,首先須要將樣本進行抽象,將每一個樣本變成一系列特徵數據(即特徵向量)
三、手寫體在直觀上就是一個個的圖片,而圖片是由上述圖示中的像素點來描述的,樣本的類似度其實就是像素的位置和顏色之間的組合的類似度
四、所以,將圖片的像素按照固定順序讀取到一個個的向量中,便可很好地表示手寫體樣本
五、抽象出了樣本向量,及類似度計算模型,便可應用KNN來實現
python實現:
新建一個KNN.py腳本文件,文件裏面包含四個函數:
1) 一個用來生成將每一個樣本的txt文件轉換爲對應的一個向量,
2) 一個用來加載整個數據集,
3) 一個實現kNN分類算法。
4) 最後就是實現加載、測試的函數。
1 #!/usr/bin/python 2 # coding=utf-8 3 ######################################### 4 # kNN: k Nearest Neighbors 5 6 # 參數: inX: vector to compare to existing dataset (1xN) 7 # dataSet: size m data set of known vectors (NxM) 8 # labels: data set labels (1xM vector) 9 # k: number of neighbors to use for comparison 10 11 # 輸出: 多數類 12 ######################################### 13 14 from numpy import * 15 import operator 16 import os 17 18 19 # KNN分類核心方法 20 def kNNClassify(newInput, dataSet, labels, k): 21 numSamples = dataSet.shape[0] # shape[0]表明行數 22 23 # # step 1: 計算歐式距離 24 # tile(A, reps): 將A重複reps次來構造一個矩陣 25 # the following copy numSamples rows for dataSet 26 diff = tile(newInput, (numSamples, 1)) - dataSet # Subtract element-wise 27 squaredDiff = diff ** 2 # squared for the subtract 28 squaredDist = sum(squaredDiff, axis = 1) # sum is performed by row 29 distance = squaredDist ** 0.5 30 31 # # step 2: 對距離排序 32 # argsort()返回排序後的索引 33 sortedDistIndices = argsort(distance) 34 35 classCount = {} # 定義一個空的字典 36 for i in xrange(k): 37 # # step 3: 選擇k個最小距離 38 voteLabel = labels[sortedDistIndices[i]] 39 40 # # step 4: 計算類別的出現次數 41 # when the key voteLabel is not in dictionary classCount, get() 42 # will return 0 43 classCount[voteLabel] = classCount.get(voteLabel, 0) + 1 44 45 # # step 5: 返回出現次數最多的類別做爲分類結果 46 maxCount = 0 47 for key, value in classCount.items(): 48 if value > maxCount: 49 maxCount = value 50 maxIndex = key 51 52 return maxIndex 53 54 # 將圖片轉換爲向量 55 def img2vector(filename): 56 rows = 32 57 cols = 32 58 imgVector = zeros((1, rows * cols)) 59 fileIn = open(filename) 60 for row in xrange(rows): 61 lineStr = fileIn.readline() 62 for col in xrange(cols): 63 imgVector[0, row * 32 + col] = int(lineStr[col]) 64 65 return imgVector 66 67 # 加載數據集 68 def loadDataSet(): 69 # # step 1: 讀取訓練數據集 70 print "---Getting training set..." 71 dataSetDir = 'E:/KNNCase/digits/' 72 trainingFileList = os.listdir(dataSetDir + 'trainingDigits') # 加載測試數據 73 numSamples = len(trainingFileList) 74 75 train_x = zeros((numSamples, 1024)) 76 train_y = [] 77 for i in xrange(numSamples): 78 filename = trainingFileList[i] 79 80 # get train_x 81 train_x[i, :] = img2vector(dataSetDir + 'trainingDigits/%s' % filename) 82 83 # get label from file name such as "1_18.txt" 84 label = int(filename.split('_')[0]) # return 1 85 train_y.append(label) 86 87 # # step 2:讀取測試數據集 88 print "---Getting testing set..." 89 testingFileList = os.listdir(dataSetDir + 'testDigits') # load the testing set 90 numSamples = len(testingFileList) 91 test_x = zeros((numSamples, 1024)) 92 test_y = [] 93 for i in xrange(numSamples): 94 filename = testingFileList[i] 95 96 # get train_x 97 test_x[i, :] = img2vector(dataSetDir + 'testDigits/%s' % filename) 98 99 # get label from file name such as "1_18.txt" 100 label = int(filename.split('_')[0]) # return 1 101 test_y.append(label) 102 103 return train_x, train_y, test_x, test_y 104 105 # 手寫識別主流程 106 def testHandWritingClass(): 107 # # step 1: 加載數據 108 print "step 1: load data..." 109 train_x, train_y, test_x, test_y = loadDataSet() 110 111 # # step 2: 模型訓練. 112 print "step 2: training..." 113 pass 114 115 # # step 3: 測試 116 print "step 3: testing..." 117 numTestSamples = test_x.shape[0] 118 matchCount = 0 119 for i in xrange(numTestSamples): 120 predict = kNNClassify(test_x[i], train_x, train_y, 3) 121 if predict == test_y[i]: 122 matchCount += 1 123 accuracy = float(matchCount) / numTestSamples 124 125 # # step 4: 輸出結果 126 print "step 4: show the result..." 127 print 'The classify accuracy is: %.2f%%' % (accuracy * 100)
KNNTest.py
#!/usr/bin/python # coding=utf-8 import KNN KNN.testHandWritingClass()
測試結果: