環境
spark-1.6
python3.5html
1、邏輯迴歸
邏輯迴歸又叫logistic迴歸分析,是一種廣義的線性迴歸分析模型。線性迴歸要求因變量必須是連續性的數據變量,邏輯迴歸要求因變量必須是分類變量,能夠是二分類或者多分類(多分類均可以歸結到二分類問題),邏輯迴歸的輸出是0~1之間的機率。好比要分析年齡,性別,身高,飲食習慣對於體重的影響,若是體重是實際的重量,那麼就要使用線性迴歸。若是將體重分類,分紅了高,中,低三類,就要使用邏輯迴歸進行分類。
(1)邏輯迴歸公式:, 其中,e是天然對數,無限不循環小數,
,
java
邏輯迴歸公式又叫邏輯函數(Logistic function)或者S形函數(Sigmoid function)。邏輯迴歸公式的圖像以下:
即當z=0時, =0.5,當z
時, 趨近於1,當
時, 趨近於0。node
邏輯迴歸的輸出 就是位於0~1之間的機率,假設如今判斷病人是否生病,獲得的z=2對應的
=0.7,咱們能夠歸結爲生病,若是z=-2對應的
=0.1咱們就能夠認爲不生病。當z=0時,
=0.5是決策的邊界。python
(2)假設 都大於零,那麼當z=0時,
=0.5。也就是
時,
=0.5,當z=0時使用圖像來表達兩個自變量的關係爲:redis
圖中A、B、C、D、E點都表示有 ,
兩個維度的數據,現要使用邏輯迴歸對五個點分紅兩類:I和II類:sql
A點和B點都在直線上,對應的z值是0,那麼邏輯迴歸結果 =0.5,將A,B兩點劃分爲I類和II類均可以,假設咱們規定當z>=0屬於I類,z<0屬於II類,那麼A,B屬於I類。apache
C點位於直線的上方,對應的z值要大於零,反映到S形函數上對應的 >0.5屬於I類,同理,E也屬於I類,D點屬於II類。json
圖中E點的z值遠大於C點的z值,反映到S形函數中,E點屬於I類的機率比C點屬於I類的機率要大的多。api
訓練邏輯迴歸模型,在這裏就是訓練出一條直線將兩個類別的點隔開。若是維度是3,那麼訓練邏輯迴歸模型就是訓練一個平面將兩個類別的點隔開。若是維度大於3,那麼訓練邏輯迴歸模型就是訓練一個超平面將兩個類別的點隔開。數組
2、案例:音樂分類(使用python的sklearn包)
一、概念
時域分析:
對一個信號來講,信號強度隨時間的變化的規律就是時域特性,例如一個信號的時域波形能夠表達信號隨着時間的變化。
頻域分析:
對一個信號來講,在對其進行分析時,分析信號和頻率有關的部分,而不是和時間相關的部分,和時域相對。也就是信號是由哪些單一頻率的的信號合成的就是頻域特性。頻域中有一個重要的規則是正弦波是頻域中惟一存在的波。即正弦波是對頻域的描述,由於時域中的任何波形均可用正弦波合成。
傅里葉變換:
通常來講,時域的表示較爲形象直觀,頻域分析則簡練。傅里葉變換是貫穿時域和頻域的方法之一,傅里葉變換就是將難以處理的時域信號轉換成了易於分析的頻域信號。
傅里葉原理:任何連續測量的時序信號,均可以表示爲不一樣頻率的正弦波信號的無限疊加。
二、音樂分類的步驟:
(1)經過傅里葉變換將不一樣7類裏面全部原始wav格式音樂文件轉換爲特徵,並取前1000個特徵,存入文件以便後續訓練使用
(2)讀入以上7類特徵向量數據做爲訓練集
(3)使用sklearn包中LogisticRegression的fit方法計算出分類模型
(4)讀入黑豹樂隊歌曲」無地自容」並進行傅里葉變換一樣取前1000維做爲特徵向量
(5)調用模型的predict方法對音樂進行分類,結果分爲rock即搖滾類
首先來看一下單個音樂文件的頻譜圖:
# -*- coding:utf-8 -*- from scipy import fft from scipy.io import wavfile from matplotlib.pyplot import specgram import matplotlib.pyplot as plt # 能夠先把一個wav文件讀入python,而後繪製它的頻譜圖(spectrogram)來看看是什麼樣的 #畫框設置 #figsize=(10, 4)寬度和高度的英寸 # dpi=80 分辨率 # plt.figure(figsize=(10, 4),dpi=80) # # (sample_rate, X) = wavfile.read("E:/genres/metal/converted/metal.00065.au.wav") # print sample_rate, X.shape # specgram(X, Fs=sample_rate, xextent=(0,30)) # plt.xlabel("time") # plt.ylabel("frequency") ##線的形狀和顏色 # plt.grid(True, linestyle='-', color='0.75') ##tight緊湊一點 # plt.savefig("E:/metal.00065.au.wav5.png", bbox_inches="tight") # 固然,咱們也能夠把每一種的音樂都抽一些出來打印頻譜圖以便比較,以下圖: # def plotSpec(g,n): # sample_rate, X = wavfile.read("E:/genres/"+g+"/converted/"+g+"."+n+".au.wav") # specgram(X, Fs=sample_rate, xextent=(0,30)) # plt.title(g+"_"+n[-1]) # # plt.figure(num=None, figsize=(18, 9), dpi=80, facecolor='w', edgecolor='k') # plt.subplot(6,3,1);plotSpec("classical","00001");plt.subplot(6,3,2);plotSpec("classical","00002") # plt.subplot(6,3,3);plotSpec("classical","00003");plt.subplot(6,3,4);plotSpec("jazz","00001") # plt.subplot(6,3,5);plotSpec("jazz","00002");plt.subplot(6,3,6);plotSpec("jazz","00003") # plt.subplot(6,3,7);plotSpec("country","00001");plt.subplot(6,3,8);plotSpec("country","00002") # plt.subplot(6,3,9);plotSpec("country","00003");plt.subplot(6,3,10);plotSpec("pop","00001") # plt.subplot(6,3,11);plotSpec("pop","00002");plt.subplot(6,3,12);plotSpec("pop","00003") # plt.subplot(6,3,13);plotSpec("rock","00001");plt.subplot(6,3,14);plotSpec("rock","00002") # plt.subplot(6,3,15);plotSpec("rock","00003");plt.subplot(6,3,16);plotSpec("metal","00001") # plt.subplot(6,3,17);plotSpec("metal","00002");plt.subplot(6,3,18);plotSpec("metal","00003") # plt.tight_layout(pad=0.4, w_pad=0, h_pad=1.0) # plt.savefig("D:/compare.au.wav.png", bbox_inches="tight") # 對單首音樂進行傅里葉變換 #畫框設置figsize=(9, 6)寬度和高度的英寸,dpi=80是分辨率 plt.figure(figsize=(9, 6), dpi=80) #sample_rate表明每秒樣本的採樣率,X表明讀取文件的全部信息 音軌信息,這裏全是單音軌數據 是個數組【雙音軌是個二維數組,左聲道和右聲道】 #採樣率:每秒從連續信號中提取並組成離散信號的採樣個數,它用赫茲(Hz)來表示 sample_rate, X = wavfile.read("../../data/genres/jazz/converted/jazz.00002.au.wav") print(sample_rate,X,type(X),len(X)) plt.subplot(211) #畫wav文件時頻分析的函數 specgram(X, Fs=sample_rate) plt.xlabel("time") plt.ylabel("frequency") plt.subplot(212) #fft 快速傅里葉變換 fft(X)獲得振幅 即當前採樣下頻率的振幅 fft_X = abs(fft(X)) print("fft_x",fft_X,len(fft_X)) #畫頻域分析圖 注意Python3裏要求把NFFT、noverlap、Fs默認參數寫上 可能會報錯 python2不用 specgram(fft_X,NFFT=256,noverlap=128,Fs=2) plt.xlabel("frequency") plt.ylabel("amplitude") plt.savefig("../../data/jazz.00000.au.wav.fft.png") plt.show()
結果:
# 22050 [ 110 161 124 ... 1865 1683 1248] <class 'numpy.ndarray'> 661794 # fft_x [40496. 42167.35454671 31547.4214156 ... 45633.12023717 31547.4214156 42167.35454671] 661794
訓練模型:
# -*- coding:utf-8 -*- """ 使用logistic regression處理音樂數據,音樂數據訓練樣本的得到和使用快速傅里葉變換(FFT)預處理的方法須要事先準備好 1. 把訓練集擴大到每類100個首歌,類別仍然是六類:jazz,classical,country, pop, rock, metal 2. 同時使用logistic迴歸訓練模型 3. 引入一些評價的標準來比較Logistic測試集上的表現 """ from scipy import fft from scipy.io import wavfile import numpy as np # 準備音樂數據 def create_fft(g,n): rad="../../data/genres/"+g+"/converted/"+g+"."+str(n).zfill(5)+".au.wav" #sample_rate 音頻的採樣率,X表明讀取文件的全部信息 (sample_rate, X) = wavfile.read(rad) print(sample_rate) #取1000個頻率特徵 也就是振幅 fft_features = abs(fft(X)[:1000]) #zfill(5) 字符串不足5位,前面補0 sad="../../data/trainset/"+g+"."+str(n).zfill(5)+ ".fft" np.save(sad, fft_features) #-------create fft 構建訓練集-------------- genre_list = ["classical", "jazz", "country", "pop", "rock", "metal","hiphop"] for g in genre_list: for n in range(100): create_fft(g,n) print('running...') print('finished')
使用訓練特徵對音樂文件分類:
# -*- coding:utf-8 -*- from scipy import fft from scipy.io import wavfile from sklearn.linear_model import LogisticRegression import numpy as np #========================================================================================= # 加載訓練集數據,分割訓練集以及測試集,進行分類器的訓練 # 構造訓練集!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! #-------read fft-------------- genre_list = ["classical", "jazz", "country", "pop", "rock", "metal","hiphop"] X=[] Y=[] for g in genre_list: for n in range(100): rad="../../data/trainset/"+g+"."+str(n).zfill(5)+ ".fft"+".npy" #加載文件 fft_features = np.load(rad) X.append(fft_features) #genre_list.index(g) 返回匹配上類別的索引號 Y.append(genre_list.index(g)) #構建的訓練集 X=np.array(X) #構建的訓練集對應的類別 Y=np.array(Y) # 接下來,咱們使用sklearn,來構造和訓練咱們的兩種分類器 #------train logistic classifier-------------- #建立LogisticRegression 要制定solver 不然會告警:FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to model = LogisticRegression(solver='liblinear',multi_class='ovr',max_iter=1000) #須要numpy.array類型參數 model.fit(X, Y) print('Starting read wavfile...') #prepare test data------------------- # sample_rate, test = wavfile.read("i:/classical.00007.au.wav") sample_rate, test = wavfile.read("../../data/heibao-wudizirong-remix.wav") print(sample_rate,test) testdata_fft_features = abs(fft(test))[:1000] #model.predict(testdata_fft_features) 預測爲一個數組,array([類別]) #testdata_fft_features若是不使用reshape(1, -1)處理一下 可能報錯ValueError: Expected 2D array, got 1D array instead: #新版本sklearn中全部東西都必須是一個2D矩陣,即便是一個簡單的column或row) 使用array.reshape(-1, 1)從新調整你的數據 testdata=testdata_fft_features.reshape(1, -1) # print(testdata) type_index = model.predict(testdata)[0] print(type_index) print(genre_list[type_index])
結果:
Starting read wavfile... 44100 [0 0 0 ... 2 2 0] 0 classical
3、案例:道路擁堵預測(使用spark的MLLIB)
一、道路擁堵訓練集
每條道路的擁堵狀況不只和當前道路前一個時間點擁堵狀況有關係,還和與這條道路臨近的其餘道路的擁堵狀況有關。甚至還和昨天當前時間點當前道路是否擁堵有關聯。咱們能夠根據這個規律,構建訓練集,預測一條道路擁堵狀況。
假設如今要訓練一個模型:使用某條道路最近三分鐘擁堵的狀況,預測該條道路下一分鐘的擁堵狀況。如何構建訓練集?
特色:構建的訓練集有什麼樣的特色,依靠訓練集訓練的模型就具有什麼樣的功能。
二、步驟:
(1)計算道路每分鐘通過的車輛數和速度總和,能夠獲得道路實時擁堵狀況
數據收集:每一個路口的攝像頭獲取數據,將數據發送到kafka裏,而後使用sparkstreaming接收並計算每分鐘數據,將數據保存到Redis裏
發送數據:
package com.ic.traffic.streaming import java.sql.Timestamp import java.util.Properties import kafka.javaapi.producer.Producer import kafka.producer.{KeyedMessage, ProducerConfig} import org.apache.spark.{SparkContext, SparkConf} import org.codehaus.jettison.json.JSONObject import scala.util.Random //向kafka car_events中生產數據 object KafkaEventProducer { def main(args: Array[String]): Unit = { val topic = "car_events" val brokers = "node1:9092,node2:9092,node3:9092" val props = new Properties() props.put("metadata.broker.list", brokers) props.put("serializer.class", "kafka.serializer.StringEncoder") val kafkaConfig = new ProducerConfig(props) val producer = new Producer[String, String](kafkaConfig) val sparkConf = new SparkConf().setAppName("traffic data").setMaster("local[4]") val sc = new SparkContext(sparkConf) val filePath = "./data/2014082013_all_column_test.txt" val records = sc.textFile(filePath) .filter(!_.startsWith(";")) .map(_.split(",")).collect() for (i <- 1 to 100) { for (record <- records) { // prepare event data val event = new JSONObject() event.put("camera_id", record(0)) .put("car_id", record(2)) .put("event_time", record(4)) .put("speed", record(6)) .put("road_id", record(13)) // produce event message producer.send(new KeyedMessage[String, String](topic,event.toString)) println("Message sent: " + event) Thread.sleep(200) } } sc.stop } }
接收、彙總、保存數據:
package com.ic.traffic.streaming import java.text.SimpleDateFormat import java.util.Calendar import kafka.serializer.StringDecoder import net.sf.json.JSONObject import org.apache.spark.SparkConf import org.apache.spark.streaming._ import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.kafka._ import org.apache.spark.streaming.dstream.InputDStream /** * 將每一個卡扣的總速度_車輛數 存入redis中 * 【yyyyMMdd_Monitor_id,HHmm,SpeedTotal_CarCount】 */ object CarEventCountAnalytics { def main(args: Array[String]): Unit = { // Create a StreamingContext with the given master URL val conf = new SparkConf().setAppName("CarEventCountAnalytics") if (args.length == 0) { conf.setMaster("local[*]") } val ssc = new StreamingContext(conf, Seconds(5)) // ssc.checkpoint(".") // Kafka configurations val topics = Set("car_events") val brokers = "node1:9092,node2:9092,node3:9092" val kafkaParams = Map[String, String]( "metadata.broker.list" -> brokers, "serializer.class" -> "kafka.serializer.StringEncoder") val dbIndex = 1 // Create a direct stream val kafkaStream: InputDStream[(String, String)] = KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder](ssc, kafkaParams, topics) val events: DStream[JSONObject] = kafkaStream.map(line => { //JSONObject.fromObject 將string 轉換成jsonObject val data = JSONObject.fromObject(line._2) println(data) data }) /** * carSpeed K:monitor_id * V:(speedCount,carCount) */ val carSpeed = events.map(jb => (jb.getString("camera_id"),jb.getInt("speed"))) .mapValues((speed:Int)=>(speed,1)) //(camera_id, (speed, 1) ) => (camera_id , (total_speed , total_count)) .reduceByKeyAndWindow((a:Tuple2[Int,Int], b:Tuple2[Int,Int]) => {(a._1 + b._1, a._2 + b._2)},Seconds(60),Seconds(10)) // .reduceByKeyAndWindow((a:Tuple2[Int,Int], b:Tuple2[Int,Int]) => {(a._1 + b._1, a._2 + b._2)},(a:Tuple2[Int,Int], b:Tuple2[Int,Int]) => {(a._1 - b._1, a._2 - b._2)},Seconds(20),Seconds(10)) carSpeed.foreachRDD(rdd => { rdd.foreachPartition(partitionOfRecords => { val jedis = RedisClient.pool.getResource partitionOfRecords.foreach(pair => { val camera_id = pair._1 val speedTotal = pair._2._1 val CarCount = pair._2._2 val now = Calendar.getInstance().getTime() // create the date/time formatters val minuteFormat = new SimpleDateFormat("HHmm") val dayFormat = new SimpleDateFormat("yyyyMMdd") val time = minuteFormat.format(now) val day = dayFormat.format(now) if(CarCount!=0){ jedis.select(dbIndex) jedis.hset(day + "_" + camera_id, time , speedTotal + "_" + CarCount) } }) RedisClient.pool.returnResource(jedis) }) }) println("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx") ssc.start() ssc.awaitTermination() } }
(2)預測道路的擁堵狀況受當前道路附近道路擁堵的狀況,受這幾個道路過去幾分鐘道路擁堵的狀況。預測道路擁堵狀況能夠根據附近每條道路和當前道路前3分鐘道路擁堵的狀況來預測。用附近每條道路和當前道路前3分鐘道路的擁堵狀況來當作維度。統計這些道路過去5個小時內每分鐘的前3分鐘擁堵狀況構建數據集。
(3)訓練邏輯迴歸模型
(4)保存模型
package com.ic.traffic.streaming import java.text.SimpleDateFormat import java.util import java.util.{Date} import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.{SparkContext, SparkConf} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLUtils import scala.collection.mutable.ArrayBuffer import scala.Array import scala.collection.mutable.ArrayBuffer import org.apache.spark.mllib.classification.LogisticRegressionModel /** * 訓練模型 */ object TrainLRwithLBFGS { val sparkConf = new SparkConf().setAppName("train traffic model").setMaster("local[*]") val sc = new SparkContext(sparkConf) // create the date/time formatters val dayFormat = new SimpleDateFormat("yyyyMMdd") val minuteFormat = new SimpleDateFormat("HHmm") def main(args: Array[String]) { // fetch data from redis val jedis = RedisClient.pool.getResource jedis.select(1) // find relative road monitors for specified road // val camera_ids = List("310999003001","310999003102","310999000106","310999000205","310999007204") val camera_ids = List("310999003001","310999003102") val camera_relations:Map[String,Array[String]] = Map[String,Array[String]]( "310999003001" -> Array("310999003001","310999003102","310999000106","310999000205","310999007204"), "310999003102" -> Array("310999003001","310999003102","310999000106","310999000205","310999007204") ) val temp = camera_ids.map({ camera_id => val hours = 5 val nowtimelong = System.currentTimeMillis(); val now = new Date(nowtimelong) val day = dayFormat.format(now)//yyyyMMdd val array = camera_relations.get(camera_id).get /** * relations中存儲了每個卡扣在day這一天每一分鐘的平均速度 */ val relations = array.map({ camera_id => // println(camera_id) // fetch records of one camera for three hours ago val minute_speed_car_map = jedis.hgetAll(day + "_'" + camera_id+"'") (camera_id, minute_speed_car_map) }) // relations.foreach(println) // organize above records per minute to train data set format (MLUtils.loadLibSVMFile) val dataSet = ArrayBuffer[LabeledPoint]() // start begin at index 3 //Range 從300到1 遞減 不包含0 for(i <- Range(60*hours,0,-1)){ val features = ArrayBuffer[Double]() val labels = ArrayBuffer[Double]() // get current minute and recent two minutes for(index <- 0 to 2){ //當前時刻過去的時間那一分鐘 val tempOne = nowtimelong - 60 * 1000 * (i-index) val d = new Date(tempOne) val tempMinute = minuteFormat.format(d)//HHmm //下一分鐘 val tempNext = tempOne - 60 * 1000 * (-1) val dNext = new Date(tempNext) val tempMinuteNext = minuteFormat.format(dNext)//HHmm for((k,v) <- relations){ val map = v //map -- k:HHmm v:Speed if(index == 2 && k == camera_id){ if (map.containsKey(tempMinuteNext)) { val info = map.get(tempMinuteNext).split("_") val f = info(0).toFloat / info(1).toFloat labels += f } } if (map.containsKey(tempMinute)){ val info = map.get(tempMinute).split("_") val f = info(0).toFloat / info(1).toFloat features += f } else{ features += -1.0 } } } if(labels.toArray.length == 1 ){ //array.head 返回數組第一個元素 val label = (labels.toArray).head val record = LabeledPoint(if ((label.toInt/10)<10) (label.toInt/10) else 10.0, Vectors.dense(features.toArray)) dataSet += record } } // dataSet.foreach(println) // println(dataSet.length) val data = sc.parallelize(dataSet) // Split data into training (80%) and test (20%). //將data這個RDD隨機分紅 8:2兩個RDD val splits = data.randomSplit(Array(0.8, 0.2)) //構建訓練集 val training = splits(0) /** * 測試集的重要性: * 測試模型的準確度,防止模型出現過擬合的問題 */ val test = splits(1) if(!data.isEmpty()){ // 訓練邏輯迴歸模型 val model = new LogisticRegressionWithLBFGS() .setNumClasses(11) .setIntercept(true) .run(training) // 測試集測試模型 val predictionAndLabels = test.map { case LabeledPoint(label, features) => val prediction = model.predict(features) (prediction, label) } predictionAndLabels.foreach(x=> println("預測類別:"+x._1+",真實類別:"+x._2)) // Get evaluation metrics. 獲得評價指標 val metrics: MulticlassMetrics = new MulticlassMetrics(predictionAndLabels) val precision = metrics.precision// 準確率 println("Precision = " + precision) if(precision > 0.8){ val path = "hdfs://node1:9000/model/model_"+camera_id+"_"+nowtimelong model.save(sc, path) println("saved model to "+ path) jedis.hset("model", camera_id , path) } } }) RedisClient.pool.returnResource(jedis) } }
(5)使用模型預測道路的擁堵狀況
package com.ic.traffic.streaming import java.text.SimpleDateFormat import java.util.Date import org.apache.spark.mllib.classification.{ LogisticRegressionModel, LogisticRegressionWithLBFGS } import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.{ SparkConf, SparkContext } import scala.collection.mutable.ArrayBuffer object PredictLRwithLBFGS { val sparkConf = new SparkConf().setAppName("predict traffic").setMaster("local[4]") val sc = new SparkContext(sparkConf) // create the date/time formatters val dayFormat = new SimpleDateFormat("yyyyMMdd") val minuteFormat = new SimpleDateFormat("HHmm") val sdf = new SimpleDateFormat("yyyy-MM-dd_HH:mm:ss") def main(args: Array[String]) { val input = "2019-05-20_17:25:00" val date = sdf.parse(input) val inputTimeLong = date.getTime() // val inputTime = new Date(inputTimeLong) val day = dayFormat.format(date)//yyyyMMdd // fetch data from redis val jedis = RedisClient.pool.getResource jedis.select(1) // find relative road monitors for specified road // val camera_ids = List("310999003001","310999003102","310999000106","310999000205","310999007204") val camera_ids = List("310999003001", "310999003102") val camera_relations: Map[String, Array[String]] = Map[String, Array[String]]( "310999003001" -> Array("310999003001", "310999003102", "310999000106", "310999000205", "310999007204"), "310999003102" -> Array("310999003001", "310999003102", "310999000106", "310999000205", "310999007204")) val temp = camera_ids.map({ camera_id => val list = camera_relations.get(camera_id).get val relations = list.map({ camera_id => // fetch records of one camera for three hours ago (camera_id, jedis.hgetAll(day + "_'" + camera_id + "'")) }) // relations.foreach(println) // organize above records per minute to train data set format (MLUtils.loadLibSVMFile) val aaa = ArrayBuffer[Double]() // get current minute and recent two minutes for (index <- 3 to (1,-1)) { //拿到過去 一分鐘,兩分鐘,過去三分鐘的時間戳 val tempOne = inputTimeLong - 60 * 1000 * index val currentOneTime = new Date(tempOne) //獲取輸入時間的 "HHmm" val tempMinute = minuteFormat.format(currentOneTime) println("inputtime ====="+currentOneTime) for ((k, v) <- relations) { // k->camera_id ; v->speed val map = v if (map.containsKey(tempMinute)) { val info = map.get(tempMinute).split("_") val f = info(0).toFloat / info(1).toFloat aaa += f } else { aaa += -1.0 } } } // Run training algorithm to build the model val path = jedis.hget("model", camera_id) if(path!=null){ val model = LogisticRegressionModel.load(sc, path) // Compute raw scores on the test set. val prediction = model.predict(Vectors.dense(aaa.toArray)) println(input + "\t" + camera_id + "\t" + prediction + "\t") // jedis.hset(input, camera_id, prediction.toString) } }) RedisClient.pool.returnResource(jedis) } }
注意:提升模型的分類數,會提升模型的抗干擾能力。好比道路擁堵狀況就分爲兩類:「暢通」、「擁堵」,若是模型針對一條原本屬於「暢通」分類的數據預測錯了,那麼預測結果只能就是「擁堵」,那麼就發生了質的改變。若是咱們將道路擁堵狀況分爲四類:「暢通」,「比較暢通」,「比較擁堵」,「擁堵」。若是模型針對一條原本數據「暢通」分類的數據預測錯了,那麼預測結果錯的狀況下就不是隻有「擁堵」這個狀況,有多是其餘三類的一種,也有必定的機率預測分類爲「比較暢通」,那麼就至關於提升了模型的抗干擾能力。
關於label和feature:
label是分類,你要預測的東西,而feature則是特徵(好比你經過一些特徵黃色,圓,得出是月亮)。若是你訓練出feature和label的關係,以後你能夠經過feature得出label。
參考: