Guide: <a href='https://www.cnblogs.com/BeanHsiang/category/1218714.html' target='_blank'>ML.NET article series</a>
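ML.NET v0.2 has been released, adding clustering trainers and further improving runtime performance. This article introduces a special kind of regression, Poisson regression, and works through an example that predicts NBA game scores.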
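As earlier articles in this series mentioned, regression is used to predict continuous values. Poisson regression is one kind of regression; what sets it apart is that it is only used to predict non-negative integers, typically counts. The Poisson distribution is a discrete distribution, so the feature and label values should come from independent random events observed over the same (or nearly the same) time interval.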
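For reference, the model behind this description can be written in standard textbook notation. The symbols below ($Y$ for the count, $x$ for the feature vector, $w$ and $b$ for the learned weights and bias) are not from the original post, and this is the usual formulation rather than a statement about ML.NET's internals.

```latex
% Poisson distribution: probability of observing exactly k events
% when the expected number of events is \lambda.
P(Y = k) = \frac{\lambda^{k} e^{-\lambda}}{k!}, \qquad k = 0, 1, 2, \dots

% Poisson regression: the expected count depends on the features
% through a log link, which keeps the prediction positive.
\log \lambda(x) = w^{\top} x + b
\quad\Longleftrightarrow\quad
\mathbb{E}[Y \mid x] = \lambda(x) = \exp\!\left(w^{\top} x + b\right)
```

The predicted value is the expected count $\lambda(x)$, which is why the output is a positive real number even though the label itself is an integer.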
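So which scenarios involve counts and are a good fit for Poisson regression? Here are a few examples. Take bike-sharing dispatch: each regional hub records the number of bikes borrowed and returned every hour, and from these counts we can predict how many bikes that hub will need in the next hour to meet demand. Or take employee turnover: a company loses some employees every month, so the HR department can count monthly departures and use Poisson regression to predict next month's attrition, taking action early and preparing a hiring plan.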
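To make the idea of a count concrete, here is a minimal sketch that evaluates the Poisson probability of seeing exactly `k` events when the average is `lambda`. The class and method names (`PoissonSketch`, `Pmf`) are hypothetical and are not part of ML.NET or of the original post.

```csharp
using System;

public static class PoissonSketch
{
    // Probability of observing exactly k events when the expected count is lambda.
    // Computed iteratively as e^(-lambda) * lambda^k / k! so that no large
    // factorial is ever formed.
    public static double Pmf(int k, double lambda)
    {
        if (k < 0) throw new ArgumentOutOfRangeException(nameof(k));
        if (lambda <= 0) throw new ArgumentOutOfRangeException(nameof(lambda));

        double p = Math.Exp(-lambda);
        for (int i = 1; i <= k; i++)
        {
            p *= lambda / i;
        }
        return p;
    }
}
```

For the bike-sharing example, if a hub lends out 10 bikes per hour on average, `PoissonSketch.Pmf(12, 10.0)` gives the probability that exactly 12 bikes are borrowed in the next hour.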
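Starting to get a feel for it? This time we will practice on something many people enjoy: NBA game scores. A game score is also a count, games are played over consecutive intervals of roughly equal length (game durations are broadly similar), and the results are uncertain, so this is a place where Poisson regression can show its strengths. To keep things easy to follow, I will predict a single score per game: the points scored by the team that each row of the dataset describes.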
The data for this example comes from Kaggle.com: <a href='https://www.kaggle.com/ionaskel/nba-games-stats-from-2014-to-2018' target='_blank'>NBA Team Game Stats from 2014 to 2018</a>. The dataset collects NBA games from the most recent 4 years, in a format like the following:
"","Team","Game","Date","Home","Opponent","WINorLOSS","TeamPoints","OpponentPoints","FieldGoals","FieldGoalsAttempted","FieldGoals.","X3PointShots","X3PointShotsAttempted","X3PointShots.","FreeThrows","FreeThrowsAttempted","FreeThrows.","OffRebounds","TotalRebounds","Assists","Steals","Blocks","Turnovers","TotalFouls","Opp.FieldGoals","Opp.FieldGoalsAttempted","Opp.FieldGoals.","Opp.3PointShots","Opp.3PointShotsAttempted","Opp.3PointShots.","Opp.FreeThrows","Opp.FreeThrowsAttempted","Opp.FreeThrows.","Opp.OffRebounds","Opp.TotalRebounds","Opp.Assists","Opp.Steals","Opp.Blocks","Opp.Turnovers","Opp.TotalFouls"
"1","ATL","1",2014-10-29,"Away","TOR","L","102","109","40","80",".500","13","22",".591","9","17",".529","10","42","26","6","8","17","24","37","90",".411","8","26",".308","27","33",".818","16","48","26","13","9","9","22"
"2","ATL","2",2014-11-01,"Home","IND","W","102","92","35","69",".507","7","20",".350","25","33",".758","3","37","26","10","6","12","20","31","81",".383","12","32",".375","18","21",".857","11","44","25","5","5","18","26"
"3","ATL","3",2014-11-05,"Away","SAS","L","92","94","38","92",".413","8","25",".320","8","11",".727","10","37","26","14","5","13","25","31","69",".449","5","17",".294","27","38",".711","11","50","25","7","9","19","15"
"4","ATL","4",2014-11-07,"Away","CHO","L","119","122","43","93",".462","13","33",".394","20","26",".769","7","38","28","8","3","19","33","48","97",".495","6","21",".286","20","27",".741","11","51","31","6","7","19","30"
"5","ATL","5",2014-11-08,"Home","NYK","W","103","96","33","81",".407","9","22",".409","28","36",".778","12","41","18","10","5","8","17","40","84",".476","8","21",".381","8","11",".727","13","44","26","2","6","15","29"
"6","ATL","6",2014-11-10,"Away","NYK","W","91","85","27","71",".380","10","27",".370","27","28",".964","9","38","20","7","3","15","16","36","83",".434","6","26",".231","7","12",".583","11","40","23","4","2","15","26"
"7","ATL","7",2014-11-12,"Home","UTA","W","100","97","39","76",".513","9","20",".450","13","18",".722","13","46","23","8","4","18","12","43","86",".500","5","23",".217","6","12",".500","8","30","28","12","8","11","17"
"8","ATL","8",2014-11-14,"Home","MIA","W","114","103","42","75",".560","11","28",".393","19","23",".826","3","36","33","10","5","13","20","35","74",".473","10","21",".476","23","25",".920","5","32","27","10","3","14","20"
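The fields are as follows:
Basic game information: the team the row describes (Team), the game sequence number (Game), the game date (Date), whether that team played at home or away (Home), the opposing team (Opponent), and whether the team won or lost (WINorLOSS).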
Game statistics for the team and for its opponent: Team Points, Field Goals, Field Goals Attempted, Field Goals Percentage, 3 Point Shots, 3 Point Shots Attempted, 3 Point Shots Percentage, Free Throws, Free Throws Attempted, Free Throws Percentage, Offensive Rebounds, Total Rebounds, Assists, Steals, Blocks, Turnovers, Total Fouls.
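These metrics cover, for both sides, field goal attempts, makes, and percentage; three-point attempts, makes, and percentage; free throw attempts, makes, and percentage; plus rebounds, assists, steals, blocks, turnovers, and fouls. They are the statistics we commonly see when watching the NBA.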
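Because there is only this one dataset, I split it in a 7:2:1 ratio so that the parts can be used for training, evaluation, and prediction respectively.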
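First, define the raw data structure and the prediction data structure. TeamPoints, the score of the team in each row, is the target this example predicts, so it is defined as the label field.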
    public class Match
    {
        [Column(ordinal: "0")] public string Id;
        [Column(ordinal: "1")] public string Team;
        [Column(ordinal: "2")] public string Game;
        [Column(ordinal: "3")] public string Date;
        [Column(ordinal: "4")] public string Home;
        [Column(ordinal: "5")] public string Opponent;
        [Column(ordinal: "6")] public string WINorLOSS;
        [Column(ordinal: "7", name: "Label")] public float TeamPoints;
        [Column(ordinal: "8")] public float OpponentPoints;
        [Column(ordinal: "9")] public float FieldGoals;
        [Column(ordinal: "10")] public float FieldGoalsAttempted;
        [Column(ordinal: "11")] public float FieldGoals_;
        [Column(ordinal: "12")] public float X3PointShots;
        [Column(ordinal: "13")] public float X3PointShotsAttempted;
        [Column(ordinal: "14")] public float X3PointShots_;
        [Column(ordinal: "15")] public float FreeThrows;
        [Column(ordinal: "16")] public float FreeThrowsAttempted;
        [Column(ordinal: "17")] public float FreeThrows_;
        [Column(ordinal: "18")] public float OffRebounds;
        [Column(ordinal: "19")] public float TotalRebounds;
        [Column(ordinal: "20")] public float Assists;
        [Column(ordinal: "21")] public float Steals;
        [Column(ordinal: "22")] public float Blocks;
        [Column(ordinal: "23")] public float Turnovers;
        [Column(ordinal: "24")] public float TotalFouls;
        [Column(ordinal: "25")] public float Opp_FieldGoals;
        [Column(ordinal: "26")] public float Opp_FieldGoalsAttempted;
        [Column(ordinal: "27")] public float Opp_FieldGoals_;
        [Column(ordinal: "28")] public float Opp_3PointShots;
        [Column(ordinal: "29")] public float Opp_3PointShotsAttempted;
        [Column(ordinal: "30")] public float Opp_3PointShots_;
        [Column(ordinal: "31")] public float Opp_FreeThrows;
        [Column(ordinal: "32")] public float Opp_FreeThrowsAttempted;
        [Column(ordinal: "33")] public float Opp_FreeThrows_;
        [Column(ordinal: "34")] public float Opp_OffRebounds;
        [Column(ordinal: "35")] public float Opp_TotalRebounds;
        [Column(ordinal: "36")] public float Opp_Assists;
        [Column(ordinal: "37")] public float Opp_Steals;
        [Column(ordinal: "38")] public float Opp_Blocks;
        [Column(ordinal: "39")] public float Opp_Turnovers;
        [Column(ordinal: "40")] public float Opp_TotalFouls;
    }

    public class MatchPrediction
    {
        [ColumnName("Score")] public float TeamPoints;
    }
Loading the data
    const string DATA_PATH = "data/nba.games.stats.csv";

    static ICollection<Match> LoadData()
    {
        var matches = new List<Match>();
        using (var sr = new StreamReader(File.OpenRead(DATA_PATH)))
        {
            // Skip the header row.
            sr.ReadLine();
            while (!sr.EndOfStream)
            {
                var line = sr.ReadLine();
                // A plain split assumes no field contains an embedded comma,
                // which holds for the sample rows shown earlier.
                var values = line.Split(',');
                var match = new Match
                {
                    Id = values[0].Trim('"'),
                    Team = values[1].Trim('"'),
                    Game = values[2].Trim('"'),
                    Date = values[3].Trim('"'),
                    Home = values[4].Trim('"'),
                    Opponent = values[5].Trim('"'),
                    WINorLOSS = values[6].Trim('"'),
                    TeamPoints = Convert.ToSingle(values[7].Trim('"')),
                    OpponentPoints = Convert.ToSingle(values[8].Trim('"')),
                    FieldGoals = Convert.ToSingle(values[9].Trim('"')),
                    FieldGoalsAttempted = Convert.ToSingle(values[10].Trim('"')),
                    FieldGoals_ = Convert.ToSingle(values[11].Trim('"')),
                    X3PointShots = Convert.ToSingle(values[12].Trim('"')),
                    X3PointShotsAttempted = Convert.ToSingle(values[13].Trim('"')),
                    X3PointShots_ = Convert.ToSingle(values[14].Trim('"')),
                    FreeThrows = Convert.ToSingle(values[15].Trim('"')),
                    FreeThrowsAttempted = Convert.ToSingle(values[16].Trim('"')),
                    FreeThrows_ = Convert.ToSingle(values[17].Trim('"')),
                    OffRebounds = Convert.ToSingle(values[18].Trim('"')),
                    TotalRebounds = Convert.ToSingle(values[19].Trim('"')),
                    Assists = Convert.ToSingle(values[20].Trim('"')),
                    Steals = Convert.ToSingle(values[21].Trim('"')),
                    Blocks = Convert.ToSingle(values[22].Trim('"')),
                    Turnovers = Convert.ToSingle(values[23].Trim('"')),
                    TotalFouls = Convert.ToSingle(values[24].Trim('"')),
                    Opp_FieldGoals = Convert.ToSingle(values[25].Trim('"')),
                    Opp_FieldGoalsAttempted = Convert.ToSingle(values[26].Trim('"')),
                    Opp_FieldGoals_ = Convert.ToSingle(values[27].Trim('"')),
                    Opp_3PointShots = Convert.ToSingle(values[28].Trim('"')),
                    Opp_3PointShotsAttempted = Convert.ToSingle(values[29].Trim('"')),
                    Opp_3PointShots_ = Convert.ToSingle(values[30].Trim('"')),
                    Opp_FreeThrows = Convert.ToSingle(values[31].Trim('"')),
                    Opp_FreeThrowsAttempted = Convert.ToSingle(values[32].Trim('"')),
                    Opp_FreeThrows_ = Convert.ToSingle(values[33].Trim('"')),
                    Opp_OffRebounds = Convert.ToSingle(values[34].Trim('"')),
                    Opp_TotalRebounds = Convert.ToSingle(values[35].Trim('"')),
                    Opp_Assists = Convert.ToSingle(values[36].Trim('"')),
                    Opp_Steals = Convert.ToSingle(values[37].Trim('"')),
                    Opp_Blocks = Convert.ToSingle(values[38].Trim('"')),
                    Opp_Turnovers = Convert.ToSingle(values[39].Trim('"')),
                    Opp_TotalFouls = Convert.ToSingle(values[40].Trim('"'))
                };
                matches.Add(match);
            }
        }
        return matches;
    }
Training, evaluation, and prediction
    static PredictionModel<Match, MatchPrediction> Train(IEnumerable<Match> trainData)
    {
        var pipeline = new LearningPipeline();
        pipeline.Add(CollectionDataSource.Create(trainData));
        // Id is only a row number, so drop it.
        pipeline.Add(new ColumnDropper() { Column = new[] { "Id" } });
        // One-hot encode the string columns.
        pipeline.Add(new CategoricalOneHotVectorizer("Team", "Game", "Date", "Home", "Opponent", "WINorLOSS"));
        // Combine every remaining input column (everything except the label) into Features.
        pipeline.Add(new ColumnConcatenator("Features",
            "Team", "Game", "Date", "Home", "Opponent", "WINorLOSS", "OpponentPoints",
            "FieldGoals", "FieldGoalsAttempted", "FieldGoals_",
            "X3PointShots", "X3PointShotsAttempted", "X3PointShots_",
            "FreeThrows", "FreeThrowsAttempted", "FreeThrows_",
            "OffRebounds", "TotalRebounds", "Assists", "Steals", "Blocks", "Turnovers", "TotalFouls",
            "Opp_FieldGoals", "Opp_FieldGoalsAttempted", "Opp_FieldGoals_",
            "Opp_3PointShots", "Opp_3PointShotsAttempted", "Opp_3PointShots_",
            "Opp_FreeThrows", "Opp_FreeThrowsAttempted", "Opp_FreeThrows_",
            "Opp_OffRebounds", "Opp_TotalRebounds", "Opp_Assists", "Opp_Steals",
            "Opp_Blocks", "Opp_Turnovers", "Opp_TotalFouls"));
        pipeline.Add(new PoissonRegressor());
        var model = pipeline.Train<Match, MatchPrediction>();
        return model;
    }

    static void Evaluate(PredictionModel<Match, MatchPrediction> model, IEnumerable<Match> evaluateData)
    {
        var evaluator = new RegressionEvaluator();
        var metric = evaluator.Evaluate(model, CollectionDataSource.Create(evaluateData));
        Console.WriteLine("LossFn: {0}", metric.LossFn);
        Console.WriteLine("RSquared: {0}", metric.RSquared);
        Console.WriteLine("Rms: {0}", metric.Rms);
    }

    static void Predict(PredictionModel<Match, MatchPrediction> model, IEnumerable<Match> predictData)
    {
        var predicts = model.Predict(predictData);
        // Pair each input row with its prediction so both can be printed side by side.
        var results = predictData.Zip(predicts, (d, p) => (d, p));
        foreach (var result in results)
        {
            Console.WriteLine("Date: {0}, Team: {1} Opponent: {2}, Score: {3}-{4}, Predict Home Score: {5}",
                result.d.Date, result.d.Team, result.d.Opponent,
                result.d.TeamPoints, result.d.OpponentPoints, result.p.TeamPoints);
        }
    }
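To help read the `Rms` and `RSquared` values that `Evaluate` prints, here is a minimal sketch of how these two metrics are conventionally defined. It assumes the textbook formulas (root mean squared error and the coefficient of determination); ML.NET's evaluator may differ in details. The class name `RegressionMetricsSketch` is hypothetical and is not part of ML.NET.

```csharp
using System;
using System.Linq;

public static class RegressionMetricsSketch
{
    // Root mean squared error: sqrt(mean((actual - predicted)^2)).
    // Lower is better; the unit is the same as the label (points).
    public static double Rms(double[] actual, double[] predicted)
    {
        CheckInputs(actual, predicted);
        double sumSquares = actual.Zip(predicted, (a, p) => (a - p) * (a - p)).Sum();
        return Math.Sqrt(sumSquares / actual.Length);
    }

    // Coefficient of determination: 1 - SS_res / SS_tot.
    // Values close to 1 mean the predictions explain most of the variance.
    // The result is undefined when every actual value is identical.
    public static double RSquared(double[] actual, double[] predicted)
    {
        CheckInputs(actual, predicted);
        double mean = actual.Average();
        double ssRes = actual.Zip(predicted, (a, p) => (a - p) * (a - p)).Sum();
        double ssTot = actual.Sum(a => (a - mean) * (a - mean));
        return 1.0 - ssRes / ssTot;
    }

    private static void CheckInputs(double[] actual, double[] predicted)
    {
        if (actual == null || predicted == null)
            throw new ArgumentNullException();
        if (actual.Length == 0 || actual.Length != predicted.Length)
            throw new ArgumentException("Arrays must be non-empty and of equal length.");
    }
}
```

For example, with actual scores of 100, 110, and 90 and predictions of 102, 108, and 91, the squared errors are 4, 4, and 1, so the RMS is the square root of 3 and R² is 1 - 9/200.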
Finally, the Main method that calls everything
    static void Main(string[] args)
    {
        var data = LoadData();
        // Split the rows 7:2:1 into training, evaluation, and prediction sets.
        var trainCount = Convert.ToInt32(data.Count * 0.7);
        var evaluateCount = Convert.ToInt32(data.Count * 0.2);
        var trainData = data.Take(trainCount);
        var evaluateData = data.Skip(trainCount).Take(evaluateCount);
        var predictData = data.Skip(trainCount + evaluateCount);

        var model = Train(trainData);
        Evaluate(model, evaluateData);
        Predict(model, predictData);
    }
Results
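As you can see, when the predicted scores for the most recent NBA games are compared with the actual results, the predictions are already quite close. Because the features are all in-game statistics, when the model is applied to future games they can be updated continuously as the game progresses, and the prediction will move closer and closer to the final result. For fans this is quite a handy tool. The 2018 World Cup is about to start as well, and Paul the Octopus and Achilles the cat are no match for this. I expect you will want to try the ML.NET approach too. Remember to get complete historical data!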
The complete code is as follows:
    using Microsoft.ML;
    using Microsoft.ML.Data;
    using Microsoft.ML.Models;
    using Microsoft.ML.Runtime.Api;
    using Microsoft.ML.Runtime.Learners;
    using Microsoft.ML.Trainers;
    using Microsoft.ML.Transforms;
    using System;
    using System.Collections.Generic;
    using System.IO;
    using System.Linq;

    namespace NBAPrediction
    {
        class Program
        {
            const string DATA_PATH = "data/nba.games.stats.csv";

            static ICollection<Match> LoadData()
            {
                var matches = new List<Match>();
                using (var sr = new StreamReader(File.OpenRead(DATA_PATH)))
                {
                    // Skip the header row.
                    sr.ReadLine();
                    while (!sr.EndOfStream)
                    {
                        var line = sr.ReadLine();
                        // A plain split assumes no field contains an embedded comma,
                        // which holds for the sample rows shown earlier.
                        var values = line.Split(',');
                        var match = new Match
                        {
                            Id = values[0].Trim('"'),
                            Team = values[1].Trim('"'),
                            Game = values[2].Trim('"'),
                            Date = values[3].Trim('"'),
                            Home = values[4].Trim('"'),
                            Opponent = values[5].Trim('"'),
                            WINorLOSS = values[6].Trim('"'),
                            TeamPoints = Convert.ToSingle(values[7].Trim('"')),
                            OpponentPoints = Convert.ToSingle(values[8].Trim('"')),
                            FieldGoals = Convert.ToSingle(values[9].Trim('"')),
                            FieldGoalsAttempted = Convert.ToSingle(values[10].Trim('"')),
                            FieldGoals_ = Convert.ToSingle(values[11].Trim('"')),
                            X3PointShots = Convert.ToSingle(values[12].Trim('"')),
                            X3PointShotsAttempted = Convert.ToSingle(values[13].Trim('"')),
                            X3PointShots_ = Convert.ToSingle(values[14].Trim('"')),
                            FreeThrows = Convert.ToSingle(values[15].Trim('"')),
                            FreeThrowsAttempted = Convert.ToSingle(values[16].Trim('"')),
                            FreeThrows_ = Convert.ToSingle(values[17].Trim('"')),
                            OffRebounds = Convert.ToSingle(values[18].Trim('"')),
                            TotalRebounds = Convert.ToSingle(values[19].Trim('"')),
                            Assists = Convert.ToSingle(values[20].Trim('"')),
                            Steals = Convert.ToSingle(values[21].Trim('"')),
                            Blocks = Convert.ToSingle(values[22].Trim('"')),
                            Turnovers = Convert.ToSingle(values[23].Trim('"')),
                            TotalFouls = Convert.ToSingle(values[24].Trim('"')),
                            Opp_FieldGoals = Convert.ToSingle(values[25].Trim('"')),
                            Opp_FieldGoalsAttempted = Convert.ToSingle(values[26].Trim('"')),
                            Opp_FieldGoals_ = Convert.ToSingle(values[27].Trim('"')),
                            Opp_3PointShots = Convert.ToSingle(values[28].Trim('"')),
                            Opp_3PointShotsAttempted = Convert.ToSingle(values[29].Trim('"')),
                            Opp_3PointShots_ = Convert.ToSingle(values[30].Trim('"')),
                            Opp_FreeThrows = Convert.ToSingle(values[31].Trim('"')),
                            Opp_FreeThrowsAttempted = Convert.ToSingle(values[32].Trim('"')),
                            Opp_FreeThrows_ = Convert.ToSingle(values[33].Trim('"')),
                            Opp_OffRebounds = Convert.ToSingle(values[34].Trim('"')),
                            Opp_TotalRebounds = Convert.ToSingle(values[35].Trim('"')),
                            Opp_Assists = Convert.ToSingle(values[36].Trim('"')),
                            Opp_Steals = Convert.ToSingle(values[37].Trim('"')),
                            Opp_Blocks = Convert.ToSingle(values[38].Trim('"')),
                            Opp_Turnovers = Convert.ToSingle(values[39].Trim('"')),
                            Opp_TotalFouls = Convert.ToSingle(values[40].Trim('"'))
                        };
                        matches.Add(match);
                    }
                }
                return matches;
            }

            static PredictionModel<Match, MatchPrediction> Train(IEnumerable<Match> trainData)
            {
                var pipeline = new LearningPipeline();
                pipeline.Add(CollectionDataSource.Create(trainData));
                // Id is only a row number, so drop it.
                pipeline.Add(new ColumnDropper() { Column = new[] { "Id" } });
                // One-hot encode the string columns.
                pipeline.Add(new CategoricalOneHotVectorizer("Team", "Game", "Date", "Home", "Opponent", "WINorLOSS"));
                // Combine every remaining input column (everything except the label) into Features.
                pipeline.Add(new ColumnConcatenator("Features",
                    "Team", "Game", "Date", "Home", "Opponent", "WINorLOSS", "OpponentPoints",
                    "FieldGoals", "FieldGoalsAttempted", "FieldGoals_",
                    "X3PointShots", "X3PointShotsAttempted", "X3PointShots_",
                    "FreeThrows", "FreeThrowsAttempted", "FreeThrows_",
                    "OffRebounds", "TotalRebounds", "Assists", "Steals", "Blocks", "Turnovers", "TotalFouls",
                    "Opp_FieldGoals", "Opp_FieldGoalsAttempted", "Opp_FieldGoals_",
                    "Opp_3PointShots", "Opp_3PointShotsAttempted", "Opp_3PointShots_",
                    "Opp_FreeThrows", "Opp_FreeThrowsAttempted", "Opp_FreeThrows_",
                    "Opp_OffRebounds", "Opp_TotalRebounds", "Opp_Assists", "Opp_Steals",
                    "Opp_Blocks", "Opp_Turnovers", "Opp_TotalFouls"));
                pipeline.Add(new PoissonRegressor());
                var model = pipeline.Train<Match, MatchPrediction>();
                return model;
            }

            static void Evaluate(PredictionModel<Match, MatchPrediction> model, IEnumerable<Match> evaluateData)
            {
                var evaluator = new RegressionEvaluator();
                var metric = evaluator.Evaluate(model, CollectionDataSource.Create(evaluateData));
                Console.WriteLine("LossFn: {0}", metric.LossFn);
                Console.WriteLine("RSquared: {0}", metric.RSquared);
                Console.WriteLine("Rms: {0}", metric.Rms);
            }

            static void Predict(PredictionModel<Match, MatchPrediction> model, IEnumerable<Match> predictData)
            {
                var predicts = model.Predict(predictData);
                // Pair each input row with its prediction so both can be printed side by side.
                var results = predictData.Zip(predicts, (d, p) => (d, p));
                foreach (var result in results)
                {
                    Console.WriteLine("Date: {0}, Team: {1} Opponent: {2}, Score: {3}-{4}, Predict Home Score: {5}",
                        result.d.Date, result.d.Team, result.d.Opponent,
                        result.d.TeamPoints, result.d.OpponentPoints, result.p.TeamPoints);
                }
            }

            static void Main(string[] args)
            {
                var data = LoadData();
                // Split the rows 7:2:1 into training, evaluation, and prediction sets.
                var trainCount = Convert.ToInt32(data.Count * 0.7);
                var evaluateCount = Convert.ToInt32(data.Count * 0.2);
                var trainData = data.Take(trainCount);
                var evaluateData = data.Skip(trainCount).Take(evaluateCount);
                var predictData = data.Skip(trainCount + evaluateCount);

                var model = Train(trainData);
                Evaluate(model, evaluateData);
                Predict(model, predictData);
            }
        }

        public class Match
        {
            [Column(ordinal: "0")] public string Id;
            [Column(ordinal: "1")] public string Team;
            [Column(ordinal: "2")] public string Game;
            [Column(ordinal: "3")] public string Date;
            [Column(ordinal: "4")] public string Home;
            [Column(ordinal: "5")] public string Opponent;
            [Column(ordinal: "6")] public string WINorLOSS;
            [Column(ordinal: "7", name: "Label")] public float TeamPoints;
            [Column(ordinal: "8")] public float OpponentPoints;
            [Column(ordinal: "9")] public float FieldGoals;
            [Column(ordinal: "10")] public float FieldGoalsAttempted;
            [Column(ordinal: "11")] public float FieldGoals_;
            [Column(ordinal: "12")] public float X3PointShots;
            [Column(ordinal: "13")] public float X3PointShotsAttempted;
            [Column(ordinal: "14")] public float X3PointShots_;
            [Column(ordinal: "15")] public float FreeThrows;
            [Column(ordinal: "16")] public float FreeThrowsAttempted;
            [Column(ordinal: "17")] public float FreeThrows_;
            [Column(ordinal: "18")] public float OffRebounds;
            [Column(ordinal: "19")] public float TotalRebounds;
            [Column(ordinal: "20")] public float Assists;
            [Column(ordinal: "21")] public float Steals;
            [Column(ordinal: "22")] public float Blocks;
            [Column(ordinal: "23")] public float Turnovers;
            [Column(ordinal: "24")] public float TotalFouls;
            [Column(ordinal: "25")] public float Opp_FieldGoals;
            [Column(ordinal: "26")] public float Opp_FieldGoalsAttempted;
            [Column(ordinal: "27")] public float Opp_FieldGoals_;
            [Column(ordinal: "28")] public float Opp_3PointShots;
            [Column(ordinal: "29")] public float Opp_3PointShotsAttempted;
            [Column(ordinal: "30")] public float Opp_3PointShots_;
            [Column(ordinal: "31")] public float Opp_FreeThrows;
            [Column(ordinal: "32")] public float Opp_FreeThrowsAttempted;
            [Column(ordinal: "33")] public float Opp_FreeThrows_;
            [Column(ordinal: "34")] public float Opp_OffRebounds;
            [Column(ordinal: "35")] public float Opp_TotalRebounds;
            [Column(ordinal: "36")] public float Opp_Assists;
            [Column(ordinal: "37")] public float Opp_Steals;
            [Column(ordinal: "38")] public float Opp_Blocks;
            [Column(ordinal: "39")] public float Opp_Turnovers;
            [Column(ordinal: "40")] public float Opp_TotalFouls;
        }

        public class MatchPrediction
        {
            [ColumnName("Score")] public float TeamPoints;
        }
    }