譯者按: 機器學習原來很簡單啊,不妨動手試試!javascript
原文: Machine Learning with JavaScript : Part 2php
譯者: Fundebughtml
爲了保證可讀性,本文采用意譯而非直譯。另外,本文版權歸原做者全部,翻譯僅用於學習。另外,咱們修正了原文代碼中的錯誤java
上圖使用plot.ly所畫。node
上次咱們用JavaScript實現了線性規劃,此次咱們來聊聊KNN算法。git
KNN是k-Nearest-Neighbours的縮寫,它是一種監督學習算法。KNN算法能夠用來作分類,也能夠用來解決迴歸問題。github
GitHub倉庫: machine-learning-with-js算法
簡單地說,KNN算法由那離本身最近的K個點來投票決定待分類數據歸爲哪一類。shell
若是待分類的數據有這些鄰近數據,NY: 7, NJ: 0, IN: 4,即它有7個NY鄰居,0個NJ鄰居,4個IN鄰居,則這個數據應該歸類爲NY。npm
假設你在郵局工做,你的任務是爲郵遞員分配信件,目標是最小化到各個社區的投遞旅程。不妨假設一共有7個街區。這就是一個實際的分類問題。你須要將這些信件分類,決定它屬於哪一個社區,好比上東城、曼哈頓下城等。
最壞的方案是隨意分配信件分配給郵遞員,這樣每一個郵遞員會拿到各個社區的信件。
最佳的方案是根據信件地址進行分類,這樣每一個郵遞員只須要負責鄰近社區的信件。
也許你是這樣想的:"將鄰近3個街區的信件分配給同一個郵遞員"。這時,鄰近街區的個數就是k。你能夠不斷增長k,直到得到最佳的分配方案。這個k就是分類問題的最佳值。
像上次同樣,咱們將使用mljs的KNN模塊ml-knn來實現。
每個機器學習算法都須要數據,此次我將使用IRIS數據集。其數據集包含了150個樣本,都屬於鳶尾屬下的三個亞屬,分別是山鳶尾、變色鳶尾和維吉尼亞鳶尾。四個特徵被用做樣本的定量分析,它們分別是花萼和花瓣的長度和寬度。
$ npm install ml-knn@2.0.0 csvtojson prompt
ml-knn: k-Nearest-Neighbours模塊,不一樣版本的接口可能不一樣,這篇博客使用了2.0.0
csvtojson: 用於將CSV數據轉換爲JSON
prompt: 在控制檯輸入輸出數據
IRIS數據集由加州大學歐文分校提供。
curl https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data > iris.csv
假設你已經初始化了一個NPM項目,請在index.js中輸入如下內容:
const KNN = require('ml-knn'); const csv = require('csvtojson'); const prompt = require('prompt'); var knn; const csvFilePath = 'iris.csv'; // 數據集 const names = ['sepalLength', 'sepalWidth', 'petalLength', 'petalWidth', 'type']; let seperationSize; // 分割訓練和測試數據 let data = [], X = [], y = []; let trainingSetX = [], trainingSetY = [], testSetX = [], testSetY = [];
seperationSize用於分割數據和測試數據
使用csvtojson模塊的fromFile方法加載數據:
csv( { noheader: true, headers: names }) .fromFile(csvFilePath) .on('json', (jsonObj) => { data.push(jsonObj); // 將數據集轉換爲JS對象數組 }) .on('done', (error) => { seperationSize = 0.7 * data.length; data = shuffleArray(data); dressData(); });
咱們將seperationSize設爲樣本數目的0.7倍。注意,若是訓練數據集過小的話,分類效果將變差。
因爲數據集是根據種類排序的,因此須要使用shuffleArray函數對數據進行混淆,這樣才能方便分割出訓練數據。這個函數的定義請參考StackOverflow的提問How to randomize (shuffle) a JavaScript array?:
function shuffleArray(array) { for (var i = array.length - 1; i > 0; i--) { var j = Math.floor(Math.random() * (i + 1)); var temp = array[i]; array[i] = array[j]; array[j] = temp; } return array; }
數據集中每一條數據能夠轉換爲一個JS對象:
{ sepalLength: ‘5.1’, sepalWidth: ‘3.5’, petalLength: ‘1.4’, petalWidth: ‘0.2’, type: ‘Iris-setosa’ }
在使用KNN算法訓練數據以前,須要對數據進行這些處理:
將屬性(sepalLength, sepalWidth,petalLength,petalWidth)由字符串轉換爲浮點數. (parseFloat)
將分類 (type)用數字表示
function dressData() { let types = new Set(); data.forEach((row) => { types.add(row.type); }); let typesArray = [...types]; data.forEach((row) => { let rowArray, typeNumber; rowArray = Object.keys(row).map(key => parseFloat(row[key])).slice(0, 4); typeNumber = typesArray.indexOf(row.type); // Convert type(String) to type(Number) X.push(rowArray); y.push(typeNumber); }); trainingSetX = X.slice(0, seperationSize); trainingSetY = y.slice(0, seperationSize); testSetX = X.slice(seperationSize); testSetY = y.slice(seperationSize); train(); }
function train() { knn = new KNN(trainingSetX, trainingSetY, { k: 7 }); test(); }
train方法須要2個必須的參數: 輸入數據,即花萼和花瓣的長度和寬度;實際分類,即山鳶尾、變色鳶尾和維吉尼亞鳶尾。另外,第三個參數是可選的,用於提供調整KNN算法的內部參數。我將k參數設爲7,其默認值爲5。
訓練好模型以後,就可使用測試數據來檢查準確性了。咱們主要對預測出錯的個數比較感興趣。
function test() { const result = knn.predict(testSetX); const testSetLength = testSetX.length; const predictionError = error(result, testSetY); console.log(`Test Set Size = ${testSetLength} and number of Misclassifications = ${predictionError}`); predict(); }
比較預測值與真實值,就能夠獲得出錯個數:
function error(predicted, expected) { let misclassifications = 0; for (var index = 0; index < predicted.length; index++) { if (predicted[index] !== expected[index]) { misclassifications++; } } return misclassifications; }
任意輸入屬性值,就能夠獲得預測值
function predict() { let temp = []; prompt.start(); prompt.get(['Sepal Length', 'Sepal Width', 'Petal Length', 'Petal Width'], function(err, result) { if (!err) { for (var key in result) { temp.push(parseFloat(result[key])); } console.log(`With ${temp} -- type = ${knn.predict(temp)}`); } }); }
完整的程序index.js是這樣的:
const KNN = require('ml-knn'); const csv = require('csvtojson'); const prompt = require('prompt'); var knn; const csvFilePath = 'iris.csv'; // 數據集 const names = ['sepalLength', 'sepalWidth', 'petalLength', 'petalWidth', 'type']; let seperationSize; // 分割訓練和測試數據 let data = [], X = [], y = []; let trainingSetX = [], trainingSetY = [], testSetX = [], testSetY = []; csv( { noheader: true, headers: names }) .fromFile(csvFilePath) .on('json', (jsonObj) => { data.push(jsonObj); // 將數據集轉換爲JS對象數組 }) .on('done', (error) => { seperationSize = 0.7 * data.length; data = shuffleArray(data); dressData(); }); function dressData() { let types = new Set(); data.forEach((row) => { types.add(row.type); }); let typesArray = [...types]; data.forEach((row) => { let rowArray, typeNumber; rowArray = Object.keys(row).map(key => parseFloat(row[key])).slice(0, 4); typeNumber = typesArray.indexOf(row.type); // Convert type(String) to type(Number) X.push(rowArray); y.push(typeNumber); }); trainingSetX = X.slice(0, seperationSize); trainingSetY = y.slice(0, seperationSize); testSetX = X.slice(seperationSize); testSetY = y.slice(seperationSize); train(); } // 使用KNN算法訓練數據 function train() { knn = new KNN(trainingSetX, trainingSetY, { k: 7 }); test(); } // 測試訓練的模型 function test() { const result = knn.predict(testSetX); const testSetLength = testSetX.length; const predictionError = error(result, testSetY); console.log(`Test Set Size = ${testSetLength} and number of Misclassifications = ${predictionError}`); predict(); } // 計算出錯個數 function error(predicted, expected) { let misclassifications = 0; for (var index = 0; index < predicted.length; index++) { if (predicted[index] !== expected[index]) { misclassifications++; } } return misclassifications; } // 根據輸入預測結果 function predict() { let temp = []; prompt.start(); prompt.get(['Sepal Length', 'Sepal Width', 'Petal Length', 'Petal Width'], function(err, result) { if (!err) { for (var key in result) { temp.push(parseFloat(result[key])); } console.log(`With ${temp} -- type = ${knn.predict(temp)}`); } }); } // 混淆數據集的順序 function shuffleArray(array) { for (var i = array.length - 1; i > 0; i--) { var j = Math.floor(Math.random() * (i + 1)); var temp = array[i]; array[i] = array[j]; array[j] = temp; } return array; }
在控制檯執行node index.js
$ node index.js
輸出以下:
Test Set Size = 45 and number of Misclassifications = 2 prompt: Sepal Length: 1.7 prompt: Sepal Width: 2.5 prompt: Petal Length: 0.5 prompt: Petal Width: 3.4 With 1.7,2.5,0.5,3.4 -- type = 2
歡迎加入咱們Fundebug的全棧BUG監控交流羣: 622902485。
版權聲明:
轉載時請註明做者Fundebug以及本文地址:
https://blog.fundebug.com/2017/07/10/javascript-machine-learning-knn/