異常檢測是機器學習領域常見的應用場景,例如金融領域裏的信用卡欺詐,企業安全領域裏的非法入侵,IT運維裏預測設備的維護時間點等。咱們今天就來看看異常檢測的基本概念,算法,而後看看如何利用TensorflowJS來進行異常檢測。javascript
異常點是指數據中和其它點不同的點,異常檢測就是要找到這些點。一般有如下這些不一樣類型的異常:html
從數據維度的角度來看,異常也分爲單變量(univariate)和多變量異常(multivariate)。java
異常檢測的算法主要包括基於統計的算法和基於機器學習的算法。python
利用統計方法來進行異常檢測有兩種,第一種是參數化的,就是假定正常的數據是基於某種參數分佈的,那麼咱們能夠經過訓練數據估計出數據的分佈機率,那麼對於每個要分析的數據點都計算出該數據點在這個機率分佈下生成的機率。這個值越高,說明該數據是正常點的可能性就越大,該數值越低,就說明這個點就越有多是異常點。git
最多見的方式就是ZScore,假定數據符合正態分佈,ZScore計算數據點偏離均值多少個標準差。ZScore越大說明數據偏離均值越遠,那麼它是異常的機率就越高。github
非參數化的方法並不假定數據的先驗分佈,數據的分佈是從訓練數據中學習而來的。算法
其它還有一些統計方法諸如:spring
利用統計方法作異常檢測很是容易理解,計算效率也很好。可是這種方法存在一些挑戰:瀏覽器
從監督學習和非監督學習的角度來看,若是已經有了標記異常點的大量訓練數據,異常檢測能夠簡單的轉化爲分類問題,也就是數據分兩類,正常點和異常點。可是在現實中,每每很難找到大量標記好異常點的訓練數據,因此每每須要非監督學習來進行異常檢測。安全
利用數據的類似度來檢測異常的基本假設是,若是被檢測的數據和已有的數據類似度大,那麼它是正常數據的可能性就大。類似度的學習主要有基於距離的(KNN)和基於密度的(LOF)。
基於聚類的異常檢測的基本假設是,正常數據彙集在一塊兒,異常數據彙集在一塊兒。
DBSCAN是異常檢測經常使用的聚類方法。關於DBSCAN算法的介紹,你們能夠參考個人博客圖解機器學習
如上圖所示,DBSCAN能夠學習出正常聚類的中心點A,邊緣點BC以及異常點N。
可是DBSCAN對於各個超參數的設定很是敏感,利用該方法雖然不須要標記異常點,可是找到合適的超參數並不容易。
支持向量機(SVM)是一種監督學習的分類方法,單類支持向量機(OneClassSVM)是SVM的一種擴展,能夠用於非監督的檢測異常。
該算法能夠學習出正常點和異常點之間的邊界。
隔離森林(isolation forests)是檢測數據中異常值或新穎性的一種有效方法。這是一種基於二元決策樹的方法。
隔離森林的基本原則是異常值不多,並且與其餘觀測結果相差甚遠。爲了構建樹(訓練),算法從特徵空間中隨機選取一個特徵,並在最大值和最小值之間隨機選擇一個隨機分割值。這是針對訓練集中的全部觀察結果。爲了建造森林,樹木總體被平均化爲森林中的全部樹木。
而後,爲了預測,它將觀察與「節點」中的分裂值進行比較,該節點將具備兩個節點子節點,在該子節點上將進行另外一次隨機比較。由算法爲實例作出的「分裂」的數量被命名爲:「路徑長度」。正如預期的那樣,異常值的路徑長度將比其餘觀察值更短。
好了咱們瞭解了異常檢測的基本概念和方法,那麼如何利用深度學習來進行異常檢測呢?
雖然神經網絡的主要應用是監督學習,可是其實也能夠利用它來進行非監督學習,這裏咱們就須要瞭解自編碼器(Autoencoder)了。
自編碼器就是相似上圖的一個網絡,包含編碼和解碼兩個主要的部分,咱們利用訓練數據集對該網絡進行訓練,輸出的目標等於輸入的數據。也就是說咱們訓練了一個能夠重建輸入數據的深度神經網絡。那麼這樣作有什麼用能。
咱們能夠看出編碼的過程其實相似一個PCA的降維過程,就是通過編碼,找到數據中的主要成分,利用該主要成份可以重建原始數據,就好像數據壓縮和解壓縮的過程,用更少的數據來取代原始數據。對於通常的自編碼器的應用,訓練好的自編碼器不會所有用於構建網絡,通常是使用編碼的部分來進行數據的特徵提取,降維,以達到更有效的計算。
利用自編碼器,咱們假定正常數據經過自編碼器應該會還原,也就是輸入和輸出是同樣的,而對於異常數據,還原出來的數據和原始數據存在差別。基本假設就是還原出來的數據和輸入數據差別越小,那麼它是正常數據的可能性就越大,反之它是異常數據的可能性就越大。
下面咱們就來看一個利用自編碼器用tensorflowJS來檢測信用卡欺詐數據的例子。數據集來自Kaggle,考慮到TensorflowJS在瀏覽器中的性能問題,我對原始數據取樣10000條記錄來演示。
該數據通過kaggle處理,包含Time交易時間,Amount交易數額,V1-V28是通過處理後的特徵,Class表示交易的類別,1爲欺詐交易。
async function loadData(path) { return await d3.csv(path); } const dataset = await loadData( "https://cdn.jsdelivr.net/gh/gangtao/datasets@master/csv/creditcard_sample_raw.csv" );
function standarize(val, min, max) { return (val - min) / (max - min); } function prepare(dataset) { const processedDataset = dataset.map(item => { const obj = {}; for (let i = 1; i < 29; i++) { const key = `V${i}`; obj[key] = parseFloat(item[key]); } obj["Class"] = item["Class"]; obj["Time"] = parseFloat(item["Time"]); obj["Amount"] = parseFloat(item["Amount"]); return obj; }); const timeMax = d3.max(processedDataset.map(i => i.Time)); const timeMin = d3.min(processedDataset.map(i => i.Time)); const amountMax = d3.max(processedDataset.map(i => i.Amount)); const amountMin = d3.min(processedDataset.map(i => i.Amount)); processedDataset.forEach(item => { item.stdTime = standarize(item.Time, timeMax, timeMin); item.stdAmount = standarize(item.Amount, amountMax, amountMin); }); return processedDataset; } const preparedDataset = prepare(dataset);
在數據預處理階段咱們對Time和Amount作標準化處理使它的值在(0-1)之間。
function makeTrainData(dataset) { console.log(dataset.length); const normalData = dataset.filter(item => item.Class == "0"); const anomalData = dataset.filter(item => item.Class == "1"); const sliceIndex = normalData.length*0.8; const normalTrainData = normalData.slice(0,sliceIndex); const normalTestData = normalData.slice(sliceIndex+1, normalData.length); console.log(normalData.length); const trainData = { x: [], y: [] }; normalTrainData.forEach(item => { const row = []; for (let i = 1; i < 29; i++) { const key = `V${i}`; row.push(item[key]); } row.push(item["stdAmount"]); row.push(item["stdTime"]); trainData.x.push(row); trainData.y.push(row); }); const testData = normalTestData.map(item => { const row = []; for (let i = 1; i < 29; i++) { const key = `V${i}`; row.push(item[key]); } row.push(item["stdAmount"]); row.push(item["stdTime"]); return row; }); const testAnomalData = anomalData.map(item => { const row = []; for (let i = 1; i < 29; i++) { const key = `V${i}`; row.push(item[key]); } row.push(item["stdAmount"]); row.push(item["stdTime"]); return row; }); return [trainData, testData, testAnomalData]; } const [trainData, testData, testAnomalData] = makeTrainData(preparedDataset);
咱們選擇80%的正常數據作訓練,另外20%的正常交易數據和全部的異常交易數據作測試。
function buildModel() { const model = tf.sequential(); //encoder Layer const encoder = tf.layers.dense({ inputShape: [INPUT_NUM], units: FEATURE_NUM, activation: "tanh" }); model.add(encoder); const encoder_hidden = tf.layers.dense({ inputShape: [FEATURE_NUM], units: HIDDEN_NUM, activation: "relu" }); model.add(encoder_hidden); //decoder Layer const decoder_hidden = tf.layers.dense({ units: HIDDEN_NUM, activation: "tanh" }); model.add(decoder_hidden); //decoder Layer const decoder = tf.layers.dense({ units: INPUT_NUM, activation: "relu" }); model.add(decoder); //compile const adam = tf.train.adam(0.005); model.compile({ optimizer: adam, loss: tf.losses.meanSquaredError }); return model; } async function watchTraining() { const metrics = ["loss", "val_loss", "acc", "val_acc"]; const container = { name: "show.fitCallbacks", tab: "Training", styles: { height: "1000px" } }; const callbacks = tfvis.show.fitCallbacks(container, metrics); return train(model, data, callbacks); } async function trainBatch(data, model) { const metrics = ["loss", "val_loss", "acc", "val_acc"]; const container = { name: "show.fitCallbacks", tab: "Training", styles: { height: "1000px" } }; const callbacks = tfvis.show.fitCallbacks(container, metrics); console.log("training start!"); tfvis.visor(); // Save the model // const saveResults = await model.save('downloads://creditcard-model'); const epochs = config.epochs; const results = []; const xs = tf.tensor2d(data.x); const ys = tf.tensor2d(data.y); const history = await model.fit(xs, ys, { batchSize: config.batchSize, epochs: config.epochs, validationSplit: 0.2, callbacks: callbacks }); console.log("training complete!"); return history; } const model = buildModel(); model.summary(); const history = await trainBatch(trainData, model);
咱們的自編碼器的模型以下:
_________________________________________________________________ Layer (type) Output shape Param # ================================================================= dense_Dense1 (Dense) [null,16] 496 _________________________________________________________________ dense_Dense2 (Dense) [null,8] 136 _________________________________________________________________ dense_Dense3 (Dense) [null,8] 72 _________________________________________________________________ dense_Dense4 (Dense) [null,30] 270 ================================================================= Total params: 974 Trainable params: 974 Non-trainable params: 0
前兩層是編碼,後兩層是解碼。
自編碼器模型訓練好了之後咱們就能夠用它來分析異常,咱們對測試數據的正常交易記錄和異常交易記錄用該模型預測,理論上正常交易的輸出更接近原始值,而異常交易記錄應該偏離原始值比較多,咱們利用歐式距離來分析自編碼器的輸出結果。
async function distance(a, b ){ const axis = 1; const result = tf.pow(tf.sum(tf.pow(a.sub(b), 2), axis), 0.5); return result.data(); } async function predict(model, input) { const prediction = await model.predict(tf.tensor(input)); return prediction; } const predictNormal = await predict(model, testData); const predictAnomal = await predict(model, testAnomalData); const distanceNormal = await distance(tf.tensor(testData), predictNormal); const distanceAnomal = await distance(tf.tensor(testAnomalData), predictAnomal); const resultData = []; distanceNormal.forEach(item => { const obj = {}; obj.type = "normal"; obj.value = item; obj.index = Math.random(); resultData.push(obj); }) distanceAnomal.forEach(item => { const obj = {}; obj.type = "outlier"; obj.value = item; obj.index = Math.random(); resultData.push(obj); })
測試結果以下圖:
上圖綠色是異常交易,藍色是正常交易。由於正常交易的數量較多,咱們可能看不太清楚,咱們分別顯示以下圖:
咱們看到異常交易的自編碼器輸出和原始結果的距離都是大於10的,而絕大部分正常交易集中在10如下的區域,若是咱們以10爲伐值,應該能夠找到大部分的異常交易,固然會有大量的正常交易誤報。也就是該模型是沒法作到徹底的分辨正常和異常交易的。
完整的代碼見個人Codepen
本文介紹了各類異常檢測的主要方法,不管是統計方法,機器學習的方法仍是深度學習的方法,其中主要問題都是對於伐值或者參數的設置。
對於統計方法,須要肯定究竟生成機率多少的事件是異常是百年一遇的洪水是異常,仍是千年一遇的洪水是異常?
對於各類監督學習,咱們每每缺少異常點的標記,而對於非監督學習,調整各類參數會對異常點的判斷有很大的影響。
對於基於自編碼器的方法而言,咱們看到,咱們利用利用自編碼器的輸出和輸入的差別來判斷該事件是否爲異常事件,然而究竟偏離多少來定義爲異常,仍然須要用戶來指定。
咱們但願的徹底經過數據和算法來自動發現異常仍然是一個比較困難的問題。