本文主要是用kNN算法對字母圖片進行特徵提取,分類識別。內容以下:git
1、kNN算法介紹算法
K近鄰(kNN,k-NearestNeighbor)分類算法是機器學習算法中最簡單的方法之一。所謂K近鄰,就是k個最近的鄰居的意思,說的是每一個樣本均可以用它最接近的k個鄰居來表明。咱們將樣本分爲訓練樣本和測試樣本。對一個測試樣本 t 進行分類,kNN的作法是先計算樣本 t 到全部訓練樣本的歐氏距離,而後從中找出k個距離最短的訓練樣本,用這k個訓練樣本中出現次數最多的類別表示樣本 t 的類別。數組
歐式距離的計算公式:app
假設每一個樣本有兩個特徵值,如 A :(a1,b1)B:(a2,b2) 則AB的歐式距離爲 機器學習
舉個例子:根據下圖前四位同窗的成績和等級,預測第五位小白同窗的等級。函數
咱們能夠看出:語文和數學成績是一個學生的特徵,等級是一個學生的類別。工具
前四位同窗是訓練樣本,第五位同窗是測試樣本。咱們如今用kNN算法來預測第五位同窗的等級,k取3。學習
按照上面歐式距離公式咱們能夠計算測試
d(5-1)== 7 d(5-2)== 30 網站
d(5-3)== 6 d(5-4)== 19.2
由於 k 取 3,因此咱們尋找3個距離最近的樣本,即編號爲3,1,4的同窗,他們的等級分別是 B,B,A。 這三個樣本的分類中,出現了2次B,一次A,B出現次數最多,因此5號同窗的等級可能爲B
經常使用Python模塊
NumPy:NumPy是Python的一種開源的數值計算擴展。這種工具可用來存儲和處理大型矩陣,比Python自身的嵌套列表結構要高效的多。
PIL:Python Imaging Library,是Python平臺事實上的圖像處理標準庫,功能很是強大,API也簡單易用。但PIL包主要針對Python2,不兼容Python3,因此在Python3中使用Pillow,後者是大牛根據PIL移植過來的,二者用法相同。
上面兩個Python庫均可以經過pip進行安裝。
pip3 install [name]
還有就是Python 自帶標準庫:shutil模塊提供了大量的文件的高級操做,特別針對文件拷貝和刪除,主要功能爲目錄和文件操做以及壓縮操做。operator模塊是Python 的運算符庫,os 模塊是Python的系統的和操做系統相關的函數庫。
2、對圖片進行特徵提取
一、採集手寫字母的圖片素材
有許多提供機器學習數據集的網站,如知乎上的整理 https://www.zhihu.com/question/63383992/answer/222718972 我搜集到的手寫字母圖片資源以下 連接:https://pan.baidu.com/s/1pM329fl 密碼:i725 其中by_class.zip 壓縮包是已經分類好的圖片樣本,能夠直接下載使用
二、提取圖片素材的特徵
最簡單的作法是將圖片轉換爲由0 和1 組成的txt 文件,如
轉換代碼以下:
1 import os 2 import shutil 3 from PIL import Image 4 5 6 # image_file_prefix png圖片所在的文件夾 7 # file_name png png圖片的名字 8 # txt_path_prefix 轉換後txt 文件所在的文件夾 9 def generate_txt_image(image_file_prefix, file_name, txt_path_prefix): 10 """將圖片處理成只有0 和 1 的txt 文件""" 11 # 將png圖片轉換成二值圖並截取四周多餘空白部分 12 image_path = os.path.join(image_file_prefix, file_name) 13 # convert('L') 將圖片轉爲灰度圖 convert('1') 將圖片轉爲二值圖 14 img = Image.open(image_path, 'r').convert('1').crop((32, 32, 96, 96)) 15 # 指定轉換後的寬 高 16 width, height = 32, 32 17 img.thumbnail((width, height), Image.ANTIALIAS) 18 # 將二值圖片轉換爲0 1,存儲到二位數組arr中 19 arr = [] 20 for i in range(width): 21 pixels = [] 22 for j in range(height): 23 pixel = int(img.getpixel((j, i))) 24 pixel = 0 if pixel == 0 else 1 25 pixels.append(pixel) 26 arr.append(pixels) 27 28 # 建立txt文件(mac下使用os.mknod()建立文件須要root權限,這裏改用複製的方式) 29 text_image_file = os.path.join(txt_path_prefix, file_name.split('.')[0] + '.txt') 30 empty_txt_path = "/Users/beiyan/Downloads/empty.txt" 31 shutil.copyfile(empty_txt_path, text_image_file) 32 33 # 寫入文件 34 with open(text_image_file, 'w') as text_file_object: 35 for line in arr: 36 for e in line: 37 text_file_object.write(str(e)) 38 text_file_object.write("\n")
將全部素材轉換爲 txt 後,分爲兩部分:訓練樣本 和 測試樣本。
3、kNN算法實現
一、將txt文件轉爲一維數組的方法:
1 def img2vector(filename, width, height): 2 """將txt文件轉爲一維數組""" 3 return_vector = np.zeros((1, width * height)) 4 fr = open(filename) 5 for i in range(height): 6 line = fr.readline() 7 for j in range(width): 8 return_vector[0, height * i + j] = int(line[j]) 9 return return_vector
二、對測試樣本進行kNN分類,返回測試樣本的類別:
1 import numpy as np 2 import os 3 import operator 4 5 6 # test_set 單個測試樣本 7 # train_set 訓練樣本二維數組 8 # labels 訓練樣本對應的分類 9 # k k值 10 def classify(test_set, train_set, labels, k): 11 """對測試樣本進行kNN分類,返回測試樣本的類別""" 12 # 獲取訓練樣本條數 13 train_size = train_set.shape[0] 14 15 # 計算特徵值的差值並求平方 16 # tile(A,(m,n)),功能是將數組A行重複m次 列重複n次 17 diff_mat = np.tile(test_set, (train_size, 1)) - train_set 18 sq_diff_mat = diff_mat ** 2 19 20 # 計算歐式距離 存儲到數組 distances 21 sq_distances = sq_diff_mat.sum(axis=1) 22 distances = sq_distances ** 0.5 23 24 # 按距離由小到大排序對索引進行排序 25 sorted_index = distances.argsort() 26 27 # 求距離最短k個樣本中 出現最多的分類 28 class_count = {} 29 for i in range(k): 30 near_label = labels[sorted_index[i]] 31 class_count[near_label] = class_count.get(near_label, 0) + 1 32 sorted_class_count = sorted(class_count.items(), key=operator.itemgetter(1), reverse=True) 33 return sorted_class_count[0][0]
三、統計分類錯誤率
1 # train_data_path 訓練樣本文件夾 2 # test_data_path 測試樣本文件夾 3 # k k個最近鄰居 4 def get_error_rate(train_data_path, test_data_path, k): 5 """統計識別錯誤率""" 6 width, height = 32, 32 7 train_labels = [] 8 9 training_file_list = os.listdir(train_data_path) 10 train_size = len(training_file_list) 11 12 # 生成全爲0的訓練集數組 13 train_set = np.zeros((train_size, width * height)) 14 15 # 讀取訓練樣本 16 for i in range(train_size): 17 file = training_file_list[i] 18 file_name = file.split('.')[0] 19 label = str(file_name.split('_')[0]) 20 train_labels.append(label) 21 train_set[i, :] = img2vector(os.path.join(train_data_path, training_file_list[i]), width, height) 22 23 test_file_list = os.listdir(test_data_path) 24 # 識別錯誤的個數 25 error_count = 0.0 26 # 測試樣本的個數 27 test_count = len(test_file_list) 28 29 # 統計識別錯誤的個數 30 for i in range(test_count): 31 file = test_file_list[i] 32 true_label = file.split('.')[0].split('_')[0] 33 34 test_set = img2vector(os.path.join(test_data_path, test_file_list[i]), width, height) 35 test_label = classify(test_set, train_set, train_labels, k) 36 print(true_label, test_label) 37 if test_label != true_label: 38 error_count += 1.0 39 percent = error_count / float(test_count) 40 print("識別錯誤率是:{}".format(str(percent)))
上述完整代碼地址:https://gitee.com/beiyan/machine_learning/tree/master/knn
四、測試結果
訓練樣本: 0-9,a-z,A-Z 共62個字符,每一個字符選取120個訓練樣本 , 一共有7440 個訓練樣本。每一個字符選取20個測試樣本,一共1200個測試樣本。
嘗試改變條件,測得識別正確率以下:
4、kNN算法分析
由上部分結果可知:knn算法對於手寫字母的識別率並不理想。
緣由可能有如下幾個方面:
一、圖片特徵提取過於簡單,圖片邊緣較多空白,且圖片中字母的中心位置未必所有對應
二、由於英文有些字母大小寫比較類似,容易識別錯誤
三、樣本規模較小,每一個字符最多隻有300個訓練樣本,真正的訓練須要海量數據
在後序的文章中嘗試用其餘學習算法提升分類識別率。各位道友有更好的意見也歡迎提出!