在本教程中,你能夠學習梯度降低法(gradient descent algorithm)是如何工做的,並使用 python 從頭開始實現它。首先咱們看看線性迴歸(linear regression)是什麼,而後定義損失函數(loss function)。咱們學習梯度降低算法是如何工做的,最後咱們將在給定的數據集上實現它並進行預測。git
本文是 這個視頻 的文字版,若是你更喜歡視頻就觀看它吧!github
在統計學裏,線性迴歸是一種對一個因變量和一個或多個自變量之間的關係進行建模的線性方法。設 X 是自變量,Y 是因變量。咱們將爲這兩個變量定義以下的線性關係:算法
這是你在高中時就學過的直線的方程。m 是斜率,c 是 y 軸的截距。今天,咱們將用給定的數據集使用這個方程來訓練咱們的模型,並對任意給定的 X 預測對應的 Y。咱們今天的挑戰是肯定 m 和 c 的值,使得這些值對應的直線是最佳擬合線或者偏差是最小的。c#
損失(loss)是咱們預測的 m 和 c 的偏差。咱們的目標是最小化這個偏差以得到 m 和 c 的最精確的值。咱們將使用均方偏差函數(the Mean Squared Error function)來計算損失。這個函數有三個步驟:機器學習
這裏 是實際值,
是預測值。讓咱們替換
的值:函數
也就是說咱們把偏差平方而後求出了平均值,也就是均方差這個名字的由來。如今咱們已經定義了損失函數,讓咱們進入最有意思的部分:最小化它並找到 m 和 c。學習
梯度降低是一種尋找函數最小值的迭代優化算法。在這裏,這個函數就是咱們的損失函數。優化
想象有一個山谷和一個想要到達山谷的底部但沒有方向感的人。他沿着斜坡走下去,當斜坡很陡時,他走一大步,當斜坡不太陡時,他走一小步。他根據他目前的位置來決定他的下一個位置,而且當他到達谷底,也就是他的目標時就會停下來。ui
讓咱們嘗試將梯度降低應用到 m 和 c,並一步一步逼近它:
是對 m 的偏導數的值。一樣,讓咱們求出對 c 的偏導數
:
如今回到咱們的類比,m 能夠看做是人的當前位置。D 至關於坡度,L 是他移動的速度。如今咱們用上面的等式計算的 m 的新值將是他的下一個位置, 是他的下一步的大小。當坡度越陡(D 越大)時,他走更長一步;當坡度越小(D越小)時,他走更短一步。最後他到達了谷底,對應於咱們的損失 = 0。
如今有了 m 和 c 的最佳值,咱們的模型就能夠進行預測了!
如今讓咱們把上面的一切都轉換成代碼,看看咱們的模型起做用!
# Making the imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = (12.0, 9.0)
# Preprocessing Input data
data = pd.read_csv('data.csv')
X = data.iloc[:, 0]
Y = data.iloc[:, 1]
plt.scatter(X, Y)
plt.show()
複製代碼
# Building the model
m = 0
c = 0
L = 0.0001 # The learning Rate
epochs = 1000 # The number of iterations to perform gradient descent
n = float(len(X)) # Number of elements in X
# Performing Gradient Descent
for i in range(epochs):
Y_pred = m*X + c # The current predicted value of Y
D_m = (-2/n) * sum(X * (Y - Y_pred)) # Derivative wrt m
D_c = (-2/n) * sum(Y - Y_pred) # Derivative wrt c
m = m - L * D_m # Update m
c = c - L * D_c # Update c
print (m, c)
複製代碼
輸出 1.4796491688889395 0.10148121494753726
# Making predictions
Y_pred = m*X + c
plt.scatter(X, Y)
plt.plot([min(X), max(X)], [min(Y_pred), max(Y_pred)], color='red') # regression line
plt.show()
複製代碼
梯度降低是機器學習中最簡單、應用最普遍的算法之一,主要是由於它能夠應用到任何函數去優化它。學習它爲掌握機器學習奠基基礎。
數據集和代碼在這裏: 02 Linear Regression using Gradient Descent