對於數據分析而言,咱們老是極力找數學模型來描述數據發生的規律, 有的數據咱們在二維空間就能夠描述,有的數據則須要映射到更高維的空間。數據表現出來的分佈多是徹底離散的,也多是彙集成堆的,那麼機器學習的任務就是讓計算機本身在數據中學習到數據的規律。那麼這個規律一般是能夠用一些函數來描述,函數多是線性的,也多是非線性的,怎麼找到這些函數,是機器學習的首要問題。java
本篇博客嘗試用梯度降低法,找到線性函數的參數,來擬合一個數據集。dom
假設咱們有以下函數,其中x是一個三個維度,機器學習
寫一個java程序來,隨機產生100筆數據做爲訓練集。函數
Random random = new Random(); double[] results = new double[100]; double[][] features = new double[100][3]; for (int i = 0; i < 100; i++) { for (int j = 0; j < features[i].length; j++) { features[i][j] = random.nextDouble(); } results[i] = 3 * features[i][0] + 4 * features[i][1] + 5 * features[i][2] + 10; }
上面的程序中results就是函數的值,features的第二維就是隨機產生的3個x。學習
有了訓練集,咱們的任務就變成了如何求出3個各類的係數三、四、5,以及偏移量10,係數和偏移量能夠取任意值,那麼咱們就獲得了一個函數集,任務轉化一下就變成了找出一個函數做用於訓練集以後,與真實值的偏差最小,如何評判偏差的大小呢?咱們須要定義一個函數來評判,那麼給這個函數取一個名字,叫損失函數。這裏,損失函數定義爲,其中爲真實值,問題就轉化爲在訓練集中求以下函數:設計
如何求這個函數的極小值呢?若是咱們計算能力無限大,直接窮舉就完了,可是這不是高效的辦法,這時候就說的了梯度降低法,咱們來看看數學裏對梯度的定義。code
在微積分裏面,對多元函數的參數求∂偏導數,把求得的各個參數的偏導數以向量的形式寫出來,就是梯度。好比函數f(x,y), 分別對x,y求偏導數,求得的梯度向量就是(∂f/∂x, ∂f/∂y)T,簡稱grad f(x,y)或者▽f(x,y)。blog
梯度告訴咱們兩件事情:博客
一、函數增大的方向數據分析
二、咱們走向增大的方向,應該走多大步幅
求極小值,咱們反方向走便可,加個負號,可是這個步幅有個問題,若是過大,參數就直接飛出去了,就很難在找到最小值,若是過小,則頗有可能卡在局部極小值的地方。因此,咱們設計了一個係數來調節步幅,咱們叫它學習速率learningRate。
好了,爲了好描述,咱們把上面的函數泛化一下,表示成以下公式:
損失函數對每一個參數求偏導數,根據偏導數值,固然求導的過程須要用到鏈式法則,,這裏咱們直接給出參數更新公式以下:
對於BGD(批量梯度降低法):
對於SGD(隨機梯度降低法),SGD與BGD不一樣的是每筆數據,咱們都更新一次參數,效率比較低下。公式和上面相似,去掉求和符號和除以N便可。
下面是具體的代碼實現
import java.util.Random; public class LinearRegression { public static void main(String[] args) { // y=3*x1+4*x2+5*x3+10 Random random = new Random(); double[] results = new double[100]; double[][] features = new double[100][3]; for (int i = 0; i < 100; i++) { for (int j = 0; j < features[i].length; j++) { features[i][j] = random.nextDouble(); } results[i] = 3 * features[i][0] + 4 * features[i][1] + 5 * features[i][2] + 10; } double[] parameters = new double[] { 1.0, 1.0, 1.0, 1.0 }; double learningRate = 0.01; for (int i = 0; i < 30; i++) { SGD(features, results, learningRate, parameters); } parameters = new double[] { 1.0, 1.0, 1.0, 1.0 }; System.out.println("=========================="); for (int i = 0; i < 3000; i++) { BGD(features, results, learningRate, parameters); } } private static void SGD(double[][] features, double[] results, double learningRate, double[] parameters) { for (int j = 0; j < results.length; j++) { double gradient = (parameters[0] * features[j][0] + parameters[1] * features[j][1] + parameters[2] * features[j][2] + parameters[3] - results[j]) * features[j][0]; parameters[0] = parameters[0] - 2 * learningRate * gradient; gradient = (parameters[0] * features[j][0] + parameters[1] * features[j][1] + parameters[2] * features[j][2] + parameters[3] - results[j]) * features[j][1]; parameters[1] = parameters[1] - 2 * learningRate * gradient; gradient = (parameters[0] * features[j][0] + parameters[1] * features[j][1] + parameters[2] * features[j][2] + parameters[3] - results[j]) * features[j][2]; parameters[2] = parameters[2] - 2 * learningRate * gradient; gradient = (parameters[0] * features[j][0] + parameters[1] * features[j][1] + parameters[2] * features[j][2] + parameters[3] - results[j]); parameters[3] = parameters[3] - 2 * learningRate * gradient; } double totalLoss = 0; for (int j = 0; j < results.length; j++) { totalLoss = totalLoss + Math.pow((parameters[0] * features[j][0] + parameters[1] * features[j][1] + parameters[2] * features[j][2] + parameters[3] - results[j]), 2); } System.out.println(parameters[0] + " " + parameters[1] + " " + parameters[2] + " " + parameters[3]); System.out.println("totalLoss:" + totalLoss); } private static void BGD(double[][] features, double[] results, double learningRate, double[] parameters) { double sum = 0; for (int j = 0; j < results.length; j++) { sum = sum + (parameters[0] * features[j][0] + parameters[1] * features[j][1] + parameters[2] * features[j][2] + parameters[3] - results[j]) * features[j][0]; } double updateValue = 2 * learningRate * sum / results.length; parameters[0] = parameters[0] - updateValue; sum = 0; for (int j = 0; j < results.length; j++) { sum = sum + (parameters[0] * features[j][0] + parameters[1] * features[j][1] + parameters[2] * features[j][2] + parameters[3] - results[j]) * features[j][1]; } updateValue = 2 * learningRate * sum / results.length; parameters[1] = parameters[1] - updateValue; sum = 0; for (int j = 0; j < results.length; j++) { sum = sum + (parameters[0] * features[j][0] + parameters[1] * features[j][1] + parameters[2] * features[j][2] + parameters[3] - results[j]) * features[j][2]; } updateValue = 2 * learningRate * sum / results.length; parameters[2] = parameters[2] - updateValue; sum = 0; for (int j = 0; j < results.length; j++) { sum = sum + (parameters[0] * features[j][0] + parameters[1] * features[j][1] + parameters[2] * features[j][2] + parameters[3] - results[j]); } updateValue = 2 * learningRate * sum / results.length; parameters[3] = parameters[3] - updateValue; double totalLoss = 0; for (int j = 0; j < results.length; j++) { totalLoss = totalLoss + Math.pow((parameters[0] * features[j][0] + parameters[1] * features[j][1] + parameters[2] * features[j][2] + parameters[3] - results[j]), 2); } System.out.println(parameters[0] + " " + parameters[1] + " " + parameters[2] + " " + parameters[3]); System.out.println("totalLoss:" + totalLoss); } }
運行結果以下:
一樣是更新3000次參數。
一、SGD結果:
參數分別爲:3.087332784857909 、4.075233812033048 、5.0602082834888九、 9.89116046652793
totalLoss:0.13515381461776949
二、BGD結果:
參數分別爲:3.0819123489025344 、4.06414515146140三、5.04686257152001九、 9.899847277313173
totalLoss:0.1050937019067582
能夠看出,BGD有更好的表現。
快樂源於分享。
此博客乃做者原創, 轉載請註明出處