House Price Prediction with sklearn Linear Regression

Problem Statement

Build a house price prediction model: perform linear regression and prediction on the data in ex1data1.txt (single feature) and ex1data2.txt (multiple features).

A scatter plot shows that the data is roughly linear, so other forms of regression are not considered here. A quick preliminary check of this kind is sketched below.
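
A minimal sketch of that preliminary check, using the same comma-separated file as the code later in this post:

import numpy as np
import matplotlib.pyplot as plt

# Eyeball the relationship before committing to a linear model
data = np.loadtxt('ex1data1.txt', delimiter=',')
plt.scatter(data[:, 0], data[:, 1])
plt.show()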

Both data files are listed at the end of this post.

Single-Feature Linear Regression

The data in ex1data1.txt has a single feature, so a simple linear regression is enough: \(y = ax + b\)

Depending on whether the data is split, there are two schemes: in Scheme 1, all samples are used for both training and prediction; in Scheme 2, some samples are used for training and the rest to validate the model.

Scheme 1

Linear regression on ex1data1.txt, with all samples used for both training and prediction.

The code:

"""
    對ex1data1.txt中的數據進行線性迴歸,全部樣本都用來訓練和預測
"""
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, r2_score
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用來正常顯示中文標籤
plt.rcParams['axes.unicode_minus'] = False  # 用來正常顯示負號

# 數據格式:城市人口,食品經銷商利潤

# 讀取數據
data = np.loadtxt('ex1data1.txt', delimiter=',')
data_X = data[:, 0]
data_y = data[:, 1]

# 訓練模型
model = LinearRegression()
model.fit(data_X.reshape([-1, 1]), data_y)

# 利用模型進行預測
y_predict = model.predict(data_X.reshape([-1, 1]))

# 結果可視化
plt.scatter(data_X, data_y, color='red')
plt.plot(data_X, y_predict, color='blue', linewidth=3)
plt.xlabel('城市人口')
plt.ylabel('食品經銷商利潤')
plt.title('線性迴歸——城市人口與食品經銷商利潤的關係')
plt.show()

# 模型參數
print(model.coef_)
print(model.intercept_)
# MSE
print(mean_squared_error(data_y, y_predict))
# R^2
print(r2_score(data_y, y_predict))
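
Once fitted, the model can also score inputs it has never seen, not just the training points. A minimal sketch (the value 7.5 is an arbitrary example, not from the original post):

# Predict profit for a new, unseen population value (7.5 is arbitrary)
print(model.predict(np.array([[7.5]])))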

The results:

The output below gives the fitted function and an \(R^2\) of 0.70:

[1.19303364]
-3.89578087831185
8.953942751950358
0.7020315537841397

(Figure: ex1data1_1.png — scatter of the data with the fitted regression line)

Scheme 2

Linear regression on ex1data1.txt, with some samples used for training and the rest for prediction.

The implementation:

"""
    對ex1data1.txt中的數據進行線性迴歸,部分樣本用來訓練,部分樣本用來預測
"""
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用來正常顯示中文標籤
plt.rcParams['axes.unicode_minus'] = False  # 用來正常顯示負號

# 數據格式:城市人口,食品經銷商利潤

# 讀取數據
data = np.loadtxt('ex1data1.txt', delimiter=',')
data_X = data[:, 0]
data_y = data[:, 1]

# 數據分割
X_train, X_test, y_train, y_test = train_test_split(data_X, data_y)

# 訓練模型
model = LinearRegression()
model.fit(X_train.reshape([-1, 1]), y_train)

# 利用模型進行預測
y_predict = model.predict(X_test.reshape([-1, 1]))

# 結果可視化
plt.scatter(X_test, y_test, color='red')  # 測試樣本
plt.plot(X_test, y_predict, color='blue', linewidth=3)
plt.xlabel('城市人口')
plt.ylabel('食品經銷商利潤')
plt.title('線性迴歸——城市人口與食品經銷商利潤的關係')
plt.show()

# 模型參數
print(model.coef_)
print(model.intercept_)
# MSE
print(mean_squared_error(y_test, y_predict))
# R^2
print(r2_score(y_test, y_predict))
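
Note that train_test_split shuffles randomly on each run, so the metrics below differ every time the script is executed. For reproducible experiments, one could pin the split explicitly; a minimal sketch (test_size=0.25 and random_state=42 are illustrative choices, not from the original post):

# Fix the test fraction and the random seed so results are repeatable
X_train, X_test, y_train, y_test = train_test_split(
    data_X, data_y, test_size=0.25, random_state=42)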

The results:

The output below gives the fitted function and an \(R^2\) of about 0.80 (this figure varies from run to run, since the split is random):

[1.21063939]
-4.195481965945055
5.994362667047617
0.8095125123727652

(Figure: ex1data1_2.png — test samples with the fitted regression line)

Multi-Feature Linear Regression

The data in ex1data2.txt has two features, so the simplest multivariate (here, bivariate) linear regression is enough: \(y = a_1x_1 + a_2x_2 + b\)
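
For context, ordinary least squares (what LinearRegression fits) picks the coefficients that minimize the sum of squared errors. Collecting the features, together with a leading column of ones for the intercept, into a design matrix \(X\), the textbook closed-form solution is

\[
\hat{\theta} = (X^\top X)^{-1} X^\top y
\]

(sklearn solves the equivalent least-squares problem numerically rather than inverting \(X^\top X\) directly).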

Linear regression on ex1data2.txt, with all samples used for both training and prediction.

The code:

"""
    對ex1data2.txt中的數據進行線性迴歸,全部樣本都用來訓練和預測
"""
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from mpl_toolkits.mplot3d import Axes3D  # 不要去掉這個import
from sklearn.metrics import mean_squared_error, r2_score
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用來正常顯示中文標籤
plt.rcParams['axes.unicode_minus'] = False  # 用來正常顯示負號

# 數據格式:城市人口,房間數目,房價

# 讀取數據
data = np.loadtxt('ex1data2.txt', delimiter=',')
data_X = data[:, 0:2]
data_y = data[:, 2]

# 訓練模型
model = LinearRegression()
model.fit(data_X, data_y)

# 利用模型進行預測
y_predict = model.predict(data_X)

# 結果可視化
fig = plt.figure()
ax = fig.gca(projection='3d')
ax.scatter(data_X[:, 0], data_X[:, 1], data_y, color='red')
ax.plot(data_X[:, 0], data_X[:, 1], y_predict, color='blue')
ax.set_xlabel('城市人口')
ax.set_ylabel('房間數目')
ax.set_zlabel('房價')
plt.title('線性迴歸——城市人口、房間數目與房價的關係')
plt.show()

# 模型參數
print(model.coef_)
print(model.intercept_)
# MSE
print(mean_squared_error(data_y, y_predict))
# R^2
print(r2_score(data_y, y_predict))
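
With the plane fitted, pricing a new house is again a one-liner. A minimal sketch (1650 square feet and 3 bedrooms are arbitrary illustrative values):

# Predict the price of a hypothetical 1650 sq ft, 3-bedroom house
print(model.predict(np.array([[1650, 3]])))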

The results:

The output below gives the fitted function and an \(R^2\) of 0.73. The MSE looks enormous because prices are on the order of \(10^5\); its square root (the RMSE), roughly 63,900, is on the same scale as the prices and easier to interpret:

[  139.21067402 -8738.01911233]
89597.90954279748
4086560101.205658
0.7329450180289141

(Figure: ex1data2.png — 3D visualization of the data and the fit)

The Two Data Files

ex1data1.txt

6.1101,17.592
5.5277,9.1302
8.5186,13.662
7.0032,11.854
5.8598,6.8233
8.3829,11.886
7.4764,4.3483
8.5781,12
6.4862,6.5987
5.0546,3.8166
5.7107,3.2522
14.164,15.505
5.734,3.1551
8.4084,7.2258
5.6407,0.71618
5.3794,3.5129
6.3654,5.3048
5.1301,0.56077
6.4296,3.6518
7.0708,5.3893
6.1891,3.1386
20.27,21.767
5.4901,4.263
6.3261,5.1875
5.5649,3.0825
18.945,22.638
12.828,13.501
10.957,7.0467
13.176,14.692
22.203,24.147
5.2524,-1.22
6.5894,5.9966
9.2482,12.134
5.8918,1.8495
8.2111,6.5426
7.9334,4.5623
8.0959,4.1164
5.6063,3.3928
12.836,10.117
6.3534,5.4974
5.4069,0.55657
6.8825,3.9115
11.708,5.3854
5.7737,2.4406
7.8247,6.7318
7.0931,1.0463
5.0702,5.1337
5.8014,1.844
11.7,8.0043
5.5416,1.0179
7.5402,6.7504
5.3077,1.8396
7.4239,4.2885
7.6031,4.9981
6.3328,1.4233
6.3589,-1.4211
6.2742,2.4756
5.6397,4.6042
9.3102,3.9624
9.4536,5.4141
8.8254,5.1694
5.1793,-0.74279
21.279,17.929
14.908,12.054
18.959,17.054
7.2182,4.8852
8.2951,5.7442
10.236,7.7754
5.4994,1.0173
20.341,20.992
10.136,6.6799
7.3345,4.0259
6.0062,1.2784
7.2259,3.3411
5.0269,-2.6807
6.5479,0.29678
7.5386,3.8845
5.0365,5.7014
10.274,6.7526
5.1077,2.0576
5.7292,0.47953
5.1884,0.20421
6.3557,0.67861
9.7687,7.5435
6.5159,5.3436
8.5172,4.2415
9.1802,6.7981
6.002,0.92695
5.5204,0.152
5.0594,2.8214
5.7077,1.8451
7.6366,4.2959
5.8707,7.2029
5.3054,1.9869
8.2934,0.14454
13.394,9.0551
5.4369,0.61705

ex1data2.txt

2104,3,399900
1600,3,329900
2400,3,369000
1416,2,232000
3000,4,539900
1985,4,299900
1534,3,314900
1427,3,198999
1380,3,212000
1494,3,242500
1940,4,239999
2000,3,347000
1890,3,329999
4478,5,699900
1268,3,259900
2300,4,449900
1320,2,299900
1236,3,199900
2609,4,499998
3031,4,599000
1767,3,252900
1888,2,255000
1604,3,242900
1962,4,259900
3890,3,573900
1100,3,249900
1458,3,464500
2526,3,469000
2200,3,475000
2637,3,299900
1839,2,349900
1000,1,169900
2040,4,314900
3137,3,579900
1811,4,285900
1437,3,249900
1239,3,229900
2132,4,345000
4215,4,549000
2162,4,287000
1664,2,368500
2238,3,329900
2567,4,314000
1200,3,299000
852,2,179900
1852,4,299900
1203,3,239500

Author: @臭鹹魚

Please credit the source when reposting: https://www.cnblogs.com/chouxianyu/

Discussion and feedback welcome!
