Continuing from the YFCC 100M dataset analysis notes and the post on visualizing the clustering results with the Baidu Maps API, this section builds a recommendation model with the ALS algorithm from Spark MLlib on top of the attraction information clustered from YFCC 100M.
Code for this section: https://github.com/libaoquan95/TRS/tree/master/Analyse/recommend
Dataset: https://github.com/libaoquan95/TRS/tree/master/Analyse/dataset
In the user data (user.csv) and the user-attraction data (user-attraction.csv), both the user identifier and the attraction identifier are strings. The ALS algorithm in Spark MLlib, however, requires both to be integers, so the data must first be preprocessed to convert them.
For userName, join user.csv with user-attraction.csv and replace each userName in user-attraction.csv with the corresponding userId.
For the attraction identifier, the format is provinceCode_indexWithinProvince; for example, HK_100 denotes the 100th attraction clustered from photos taken in Hong Kong. It can simply be encoded as an integer.
The encoding is straightforward: map the province code before the underscore to a two-digit number, then concatenate it with the index after the underscore.
The encoding and decoding code is as follows:
val provinceToCode = Map(
  "LN" -> "10", "ShanX" -> "11", "ZJ" -> "12", "CQ" -> "13", "HLJ" -> "14",
  "AH" -> "15", "SanX" -> "16", "SD" -> "17", "SH" -> "18", "XJ" -> "19",
  "HuN" -> "20", "GS" -> "21", "HeN" -> "22", "BJ" -> "23", "NMG" -> "24",
  "YN" -> "25", "JX" -> "26", "HuB" -> "27", "JL" -> "28", "NX" -> "29",
  "TJ" -> "30", "FJ" -> "31", "SC" -> "32", "TW" -> "33", "GX" -> "34",
  "GD" -> "35", "HeB" -> "36", "HaiN" -> "37", "Macro" -> "38", "XZ" -> "39",
  "GZ" -> "40", "JS" -> "41", "QH" -> "42", "HK" -> "43"
)

val codeToProvince = Map(
  "10" -> "LN", "11" -> "ShanX", "12" -> "ZJ", "13" -> "CQ", "14" -> "HLJ",
  "15" -> "AH", "16" -> "SanX", "17" -> "SD", "18" -> "SH", "19" -> "XJ",
  "20" -> "HuN", "21" -> "GS", "22" -> "HeN", "23" -> "BJ", "24" -> "NMG",
  "25" -> "YN", "26" -> "JX", "27" -> "HuB", "28" -> "JL", "29" -> "NX",
  "30" -> "TJ", "31" -> "FJ", "32" -> "SC", "33" -> "TW", "34" -> "GX",
  "35" -> "GD", "36" -> "HeB", "37" -> "HaiN", "38" -> "Macro", "39" -> "XZ",
  "40" -> "GZ", "41" -> "JS", "42" -> "QH", "43" -> "HK"
)

// Encode: "HK_100" -> "43100"
def codeing(str: String): String = {
  val Array(province, index) = str.split('_')
  provinceToCode(province) + index
}

// Decode: "43100" -> "HK_100"
// The first two characters are the province code; everything after them is the attraction index.
def decodeing(str: String): String = {
  var decode: String = ""
  decode = codeToProvince(str(0).toString + str(1).toString) + "_"
  for (i <- 2 to str.length - 1) {   // start at 2 to skip the two-digit province code
    decode += str(i).toString
  }
  decode
}
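A quick round-trip check of the two helpers, using the HK_100 example from above:

// "HK_100" -> "43100" -> "HK_100"
val encoded = codeing("HK_100")    // "43100"
val decoded = decodeing(encoded)   // "HK_100"
println(encoded + " " + decoded)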
Next, load the user data user.csv and drop the header line.
// sc here refers to the SparkSession (spark.read in a standard spark-shell);
// its implicits are assumed to be imported for the Dataset operations below.
val dataDirBase = "..\\dataset\\"

val userIdToName = sc.read.
  textFile(dataDirBase + "user.csv").
  flatMap { line =>
    val Array(userId, userName) = line.split(',')
    if (userId == "userId") {
      None              // skip the header row
    } else {
      Some((userId, userName))
    }
  }.collect().toMap

val userNameToId = sc.read.
  textFile(dataDirBase + "user.csv").
  flatMap { line =>
    val Array(userId, userName) = line.split(',')
    if (userId == "userId") {
      None              // skip the header row
    } else {
      Some((userName, userId))
    }
  }.collect().toMap
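Judging from the parsing code, user.csv is a comma-separated file with a userId,userName header; the rows below are made-up examples of that layout, not actual data:

userId,userName
1,exampleUserA
2,exampleUserB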
Convert the user-attraction data:
val userAttractionDF = sc.read.
  textFile(dataDirBase + "user-attraction.csv").
  flatMap { line =>
    val Array(userName, attractionId, count, rating) = line.split(',')
    if (userName == "userName") {
      None              // skip the header row
    } else {
      // map userName -> integer userId, encode attractionId as an integer
      Some((userNameToId(userName).toInt, codeing(attractionId).toInt, count.toInt))
    }
  }.toDF("user", "attraction", "count").cache()
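Before training, it is worth confirming that all three columns really are integers; a minimal check (the commented output is what the schema should look like under the conversion above):

userAttractionDF.printSchema()
// root
//  |-- user: integer (nullable = false)
//  |-- attraction: integer (nullable = false)
//  |-- count: integer (nullable = false)
userAttractionDF.show(3)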
The ALS algorithm in Spark MLlib takes the matrix as triples of user id, attraction id, and rating, and the user id and attraction id must be integers.
ALS is short for Alternating Least Squares; it uses matrix factorization to fill in the sparse rating matrix and predict ratings. For details, see the article on matrix factorization in collaborative filtering recommendation (矩陣分解在協同過濾推薦算法中的應用).
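For reference, with setImplicitPrefs(true) (used below) ALS roughly minimizes the implicit-feedback objective of Hu, Koren and Volinsky, where r_ui is the photo count, α corresponds to setAlpha and λ to setRegParam:

\[
\min_{x_u,\,y_i} \sum_{u,i} c_{ui}\,(p_{ui} - x_u^{\top} y_i)^2
  + \lambda\left(\sum_u \lVert x_u\rVert^2 + \sum_i \lVert y_i\rVert^2\right),
\qquad
c_{ui} = 1 + \alpha\, r_{ui},
\quad
p_{ui} = \begin{cases} 1 & r_{ui} > 0 \\ 0 & \text{otherwise} \end{cases}
\]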
After the steps above, userAttractionDF is in the form the ALS algorithm expects. The recommendation model can now be built: split the data into a training set and a test set, and train the model on the training set. The code is as follows:
import scala.util.Random
import org.apache.spark.ml.recommendation.ALS

// 90% of the data for training, 10% held out for evaluation
val Array(trainData, cvData) = userAttractionDF.randomSplit(Array(0.9, 0.1))

val model = new ALS().
  setSeed(Random.nextLong()).
  setImplicitPrefs(true).        // treat the photo counts as implicit feedback
  setRank(10).
  setRegParam(0.01).
  setAlpha(1.0).
  setMaxIter(5).
  setUserCol("user").
  setItemCol("attraction").
  setRatingCol("count").
  setPredictionCol("prediction").
  fit(trainData)
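One caveat: with a random split, some users or attractions in cvData may never appear in trainData, and model.transform then yields NaN predictions for them. If that becomes a problem in the evaluation further below, Spark 2.2+ lets the estimator drop such rows; a minimal sketch, identical to the estimator above except for the extra setting (the variable name is just for illustration):

import org.apache.spark.ml.recommendation.ALS

// same hyperparameters as above, plus a cold-start strategy that drops NaN predictions
val modelDroppingColdStart = new ALS().
  setImplicitPrefs(true).
  setRank(10).
  setRegParam(0.01).
  setAlpha(1.0).
  setMaxIter(5).
  setUserCol("user").
  setItemCol("attraction").
  setRatingCol("count").
  setPredictionCol("prediction").
  setColdStartStrategy("drop").  // rows for unseen users/items are dropped instead of scored as NaN
  fit(trainData)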
With Spark MLlib's ALS, recommendations are generated here for one user at a time; the code is as follows:
import org.apache.spark.sql.functions.lit

// Recommend the topN attractions for one user, returning decoded ids such as "HK_100".
// Requires the SparkSession implicits for the $ column syntax and .as[Int].
def recommendByUser(userId: Int, topN: Int): Array[String] = {
  // pair every attraction with the given user id
  val toRecommend = model.itemFactors.
    select($"id".as("attraction")).
    withColumn("user", lit(userId))

  // score all (user, attraction) pairs and keep the topN by predicted preference
  val topRecommendations = model.transform(toRecommend).
    select("attraction", "prediction").
    orderBy($"prediction".desc).
    limit(topN)

  val recommends = topRecommendations.select("attraction").as[Int].collect()
  recommends.map(line => decodeing(line.toString))
}
The recommendation results are as follows:
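For example, a call along these lines prints the top-5 decoded attraction ids for one user (the user id 42 is a made-up example; any id present in the training data works):

// prints decoded attraction ids, e.g. "HK_100"
recommendByUser(42, 5).foreach(println)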
Validating the precision and recall of the recommendation model:
def testRecommend(): Unit = {
  val topN = 10
  val users = cvData.select($"user").distinct().collect().map(u => u(0))
  var hit = 0.0
  var rec_count = 0.0
  var test_count = 0.0

  for (i <- 0 to users.length - 1) {
    // top-N recommendations for this user
    val recs = recommendByUser(users(i).toString.toInt, topN).toSet
    // attractions this user actually has in the held-out set
    val temp = cvData.select($"attraction").
      where($"user" === users(i).toString.toInt).
      collect().map(a => decodeing(a(0).toString)).
      toSet
    hit += recs.&(temp).size
    rec_count += recs.size
    test_count += temp.size
  }

  println("Precision: " + (hit / rec_count))
  println("Recall: " + (hit / test_count))
}
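The two numbers printed above are the usual set-based precision and recall aggregated over all test users, where R(u) is the top-N recommendation set and T(u) the held-out attraction set of user u:

\[
\text{Precision} = \frac{\sum_u \lvert R(u) \cap T(u)\rvert}{\sum_u \lvert R(u)\rvert},
\qquad
\text{Recall} = \frac{\sum_u \lvert R(u) \cap T(u)\rvert}{\sum_u \lvert T(u)\rvert}
\]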