Introduction to ALS
ALS stands for alternating least squares. ALS-WR stands for alternating-least-squares with weighted-λ-regularization, that is, alternating least squares with weighted regularization. The method is widely used in recommender systems based on matrix factorization.

For example, the matrix of ratings that users give to items is factorized into two matrices. One matrix describes each user's preference for the items' latent features. The other describes the latent features each item contains. During this factorization the missing ratings get filled in, so we can then recommend items to users based on the filled-in ratings.
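The sketch below is a minimal plain-Java illustration of "filling in missing ratings". It uses no Spark. It assumes we already have a user factor matrix and an item factor matrix with k = 2 latent features. The class name `LatentFactorDemo` and all numbers are made up for illustration. A user's predicted rating for an item is the dot product of their two factor vectors. Computing it for every (user, item) pair fills every cell of the rating matrix, including the cells that were never rated.

```java
import java.util.Arrays;

/** Hypothetical illustration: rebuild a full rating matrix from two factor matrices. */
public class LatentFactorDemo {

    /** Predicted rating = dot product of a user's and an item's latent factor vectors. */
    public static double predict(double[] userFactors, double[] itemFactors) {
        double sum = 0.0;
        for (int f = 0; f < userFactors.length; f++) {
            sum += userFactors[f] * itemFactors[f];
        }
        return sum;
    }

    /** Fills every (user, item) cell, including the ones without an observed rating. */
    public static double[][] predictAll(double[][] users, double[][] items) {
        double[][] full = new double[users.length][items.length];
        for (int u = 0; u < users.length; u++) {
            for (int i = 0; i < items.length; i++) {
                full[u][i] = predict(users[u], items[i]);
            }
        }
        return full;
    }

    public static void main(String[] args) {
        // Made-up factors with k = 2 latent features
        double[][] users = {{1.0, 2.0}, {2.0, 0.5}};
        double[][] items = {{2.0, 1.0}, {0.5, 1.5}, {1.0, 0.0}};
        for (double[] row : predictAll(users, items)) {
            System.out.println(Arrays.toString(row));
        }
    }
}
```

With these made-up numbers, `main` prints `[4.0, 3.5, 1.0]` and `[4.5, 1.75, 2.0]`. Whichever of these cells had no real rating now carries a score that can be used for ranking.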
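In short, the ALS-WR algorithm works like this:
(data format: userId, itemId, rating, timestamp)
1 For each userId, randomly initialize N (here 10) factor values; these values determine that user's weights.
2 For each itemId, likewise randomly initialize N (10) factor values.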
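3 Fix the user factors and solve for the itemFactors matrix from the userFactors matrix and the rating matrix. Informally, [Item Factors Matrix] = [User Factors Matrix]^-1 * [Rating Matrix]. Because the factor matrix is not square and most ratings are missing, this is really one regularized least-squares solve per item. Let $U$ be the user factor matrix (rows $u_i$), $V$ the item factor matrix (rows $v_j$), $R$ the rating matrix, and $E$ the $k \times k$ identity. If item $j$ was rated by the set of users $I_j$ ($n_{v_j} = |I_j|$ ratings), then $v_j = \left(U_{I_j}^\top U_{I_j} + \lambda\, n_{v_j} E\right)^{-1} U_{I_j}^\top R_{I_j,\,j}$.
4 Fix the item factors and solve for the userFactors matrix from the itemFactors matrix and the rating matrix in the same way. Informally, [User Factors Matrix] = [Item Factors Matrix]^-1 * [Rating Matrix]. If user $i$ rated the set of items $I_i$ ($n_{u_i} = |I_i|$ ratings), then $u_i = \left(V_{I_i}^\top V_{I_i} + \lambda\, n_{u_i} E\right)^{-1} V_{I_i}^\top R_{i,\,I_i}^\top$.
Scaling $\lambda$ by each row's number of ratings $n$ is the "weighted-λ" in ALS-WR. Steps 3 and 4 each exactly minimize $\sum_{(i,j)\ \text{observed}} (r_{ij} - u_i^\top v_j)^2 + \lambda \left(\sum_i n_{u_i} \lVert u_i \rVert^2 + \sum_j n_{v_j} \lVert v_j \rVert^2\right)$ with the other matrix held fixed.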
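5 Repeat steps 3 and 4 until userFactors and itemFactors converge to stable values.
6 To predict a user's rating for an item, take the dot product of the user's factor vector and the item's factor vector: $\hat{r}_{ij} = u_i^\top v_j$. Predictions for users and for items are the same computation.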
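The steps above can be sketched in plain Java, without Spark. This is a minimal, non-authoritative illustration that assumes a small dense rating matrix in which 0 means "not rated". The class name `AlsWrSketch`, the rank, λ and the iteration count are arbitrary choices, and this is not how Spark implements ALS internally. Each half-step builds and solves the regularized normal equations from steps 3 and 4 for every item (or user), with λ multiplied by that row's number of ratings. A small Gaussian-elimination solver stands in for a linear-algebra library.

```java
import java.util.Random;

/** Hypothetical minimal ALS-WR on a small dense matrix (0 = missing rating). */
public class AlsWrSketch {

    static double dot(double[] x, double[] y) {
        double s = 0;
        for (int i = 0; i < x.length; i++) s += x[i] * y[i];
        return s;
    }

    /** Solves A x = b for a k x k matrix A by Gaussian elimination with partial pivoting. */
    static double[] solve(double[][] a, double[] b) {
        int k = b.length;
        double[][] m = new double[k][k + 1];
        for (int r = 0; r < k; r++) {
            System.arraycopy(a[r], 0, m[r], 0, k);
            m[r][k] = b[r];
        }
        for (int c = 0; c < k; c++) {
            int p = c; // pick the largest pivot for numerical stability
            for (int r = c + 1; r < k; r++) if (Math.abs(m[r][c]) > Math.abs(m[p][c])) p = r;
            double[] tmp = m[c]; m[c] = m[p]; m[p] = tmp;
            for (int r = c + 1; r < k; r++) {
                double f = m[r][c] / m[c][c];
                for (int j = c; j <= k; j++) m[r][j] -= f * m[c][j];
            }
        }
        double[] x = new double[k];
        for (int r = k - 1; r >= 0; r--) {
            double s = m[r][k];
            for (int j = r + 1; j < k; j++) s -= m[r][j] * x[j];
            x[r] = s / m[r][r];
        }
        return x;
    }

    /** Holding `fixed` constant, re-solves each row t of `target`; ratings[t][f] links row t to fixed row f. */
    static void update(double[][] ratings, double[][] fixed, double[][] target, double lambda) {
        int k = fixed[0].length;
        for (int t = 0; t < target.length; t++) {
            double[][] a = new double[k][k];
            double[] b = new double[k];
            int n = 0; // number of observed ratings in this row
            for (int f = 0; f < fixed.length; f++) {
                if (ratings[t][f] == 0) continue; // skip missing entries
                n++;
                for (int p = 0; p < k; p++) {
                    b[p] += ratings[t][f] * fixed[f][p];
                    for (int q = 0; q < k; q++) a[p][q] += fixed[f][p] * fixed[f][q];
                }
            }
            if (n == 0) continue; // no data for this row: keep its current factors
            for (int p = 0; p < k; p++) a[p][p] += lambda * n; // weighted-λ: λ scaled by n
            target[t] = solve(a, b);
        }
    }

    /** Runs ALS-WR and returns {userFactors, itemFactors}. */
    public static double[][][] train(double[][] r, int rank, double lambda, int iterations, long seed) {
        Random rnd = new Random(seed);
        int numUsers = r.length, numItems = r[0].length;
        double[][] users = new double[numUsers][rank];
        double[][] items = new double[numItems][rank];
        for (double[] row : users) for (int f = 0; f < rank; f++) row[f] = rnd.nextDouble(); // step 1
        for (double[] row : items) for (int f = 0; f < rank; f++) row[f] = rnd.nextDouble(); // step 2
        double[][] rT = new double[numItems][numUsers]; // item-major view of the ratings
        for (int u = 0; u < numUsers; u++) for (int i = 0; i < numItems; i++) rT[i][u] = r[u][i];
        for (int it = 0; it < iterations; it++) { // step 5: alternate
            update(rT, users, items, lambda); // step 3: fix users, solve item factors
            update(r, items, users, lambda);  // step 4: fix items, solve user factors
        }
        return new double[][][]{users, items};
    }

    /** RMSE over observed (non-zero) ratings, predicting users[u] · items[i] (step 6). */
    public static double rmse(double[][] r, double[][] users, double[][] items) {
        double sum = 0;
        int n = 0;
        for (int u = 0; u < r.length; u++) {
            for (int i = 0; i < r[u].length; i++) {
                if (r[u][i] == 0) continue;
                double e = r[u][i] - dot(users[u], items[i]);
                sum += e * e;
                n++;
            }
        }
        return Math.sqrt(sum / n);
    }

    public static void main(String[] args) {
        double[][] r = {{5, 3, 0, 1}, {4, 0, 0, 1}, {1, 1, 0, 5}, {1, 0, 0, 4}, {0, 1, 5, 4}};
        double[][][] model = train(r, 2, 0.01, 10, 42L);
        System.out.printf("training RMSE = %.4f%n", rmse(r, model[0], model[1]));
        // Filled-in score for a cell that had no rating (user 1, item 2)
        System.out.printf("predicted r[1][2] = %.4f%n", dot(model[0][1], model[1][2]));
    }
}
```

Each half-step is a closed-form solve, so the regularized objective never increases from one step to the next. That is why the loop can simply run a fixed number of iterations, much like `setMaxIter` / `numIterations` in the Spark code below.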
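Spark provides two machine learning libraries: ML (spark.ml, built on DataFrames) and MLlib (spark.mllib, built on RDDs). ML is the officially recommended one because it is more complete and flexible, and future development will focus mainly on ML.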
ALS recommendation with ML:
import java.io.Serializable;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.recommendation.ALS;
import org.apache.spark.ml.recommendation.ALSModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.DataTypes;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * @author huangyueran
 * @category ALS-WR
 */
public class JavaALSExampleByMl {

    private static final Logger log = LoggerFactory.getLogger(JavaALSExampleByMl.class);

    public static class Rating implements Serializable {
        // Sample line: 0::2::3::1424380312
        private int userId; // 0
        private int movieId; // 2
        private float rating; // 3
        private long timestamp; // 1424380312

        public Rating() {
        }

        public Rating(int userId, int movieId, float rating, long timestamp) {
            this.userId = userId;
            this.movieId = movieId;
            this.rating = rating;
            this.timestamp = timestamp;
        }

        public int getUserId() {
            return userId;
        }

        public int getMovieId() {
            return movieId;
        }

        public float getRating() {
            return rating;
        }

        public long getTimestamp() {
            return timestamp;
        }

        public static Rating parseRating(String str) {
            String[] fields = str.split("::");
            if (fields.length != 4) {
                throw new IllegalArgumentException("Each line must contain 4 fields");
            }
            int userId = Integer.parseInt(fields[0]);
            int movieId = Integer.parseInt(fields[1]);
            float rating = Float.parseFloat(fields[2]);
            long timestamp = Long.parseLong(fields[3]);
            return new Rating(userId, movieId, rating, timestamp);
        }
    }

    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("JavaALSExample").setMaster("local");
        JavaSparkContext jsc = new JavaSparkContext(conf);
        SQLContext sqlContext = new SQLContext(jsc);

        JavaRDD<Rating> ratingsRDD = jsc.textFile("data/sample_movielens_ratings.txt")
                .map(new Function<String, Rating>() {
                    public Rating call(String str) {
                        return Rating.parseRating(str);
                    }
                });
        Dataset<Row> ratings = sqlContext.createDataFrame(ratingsRDD, Rating.class);
        // Split the data: 80% for training, the remaining 20% for testing
        Dataset<Row>[] splits = ratings.randomSplit(new double[]{0.8, 0.2});
        Dataset<Row> training = splits[0];
        Dataset<Row> test = splits[1];

        // Build the recommendation model using ALS on the training data
        ALS als = new ALS()
                .setMaxIter(5) // number of iterations
                .setRegParam(0.01) // regularization parameter; smooths each iteration (0.1 seemed to give a somewhat lower error on this dataset)
                .setUserCol("userId")
                .setItemCol("movieId")
                .setRatingCol("rating")
                // Users/items that appear only in the test split get a NaN prediction, which would make
                // the RMSE NaN; "drop" removes those rows before evaluation (available since Spark 2.2)
                .setColdStartStrategy("drop");
        ALSModel model = als.fit(training); // train the model

        Dataset<Row> itemFactors = model.itemFactors();
        itemFactors.show(1500);
        Dataset<Row> userFactors = model.userFactors();
        userFactors.show();

        // Evaluate the model by computing the RMSE on the test data
        Dataset<Row> rawPredictions = model.transform(test); // predict ratings for the test data
        Dataset<Row> predictions = rawPredictions
                .withColumn("rating", rawPredictions.col("rating").cast(DataTypes.DoubleType))
                .withColumn("prediction", rawPredictions.col("prediction").cast(DataTypes.DoubleType));

        RegressionEvaluator evaluator = new RegressionEvaluator()
                .setMetricName("rmse")
                .setLabelCol("rating")
                .setPredictionCol("prediction");
        double rmse = evaluator.evaluate(predictions);
        log.info("Root-mean-square error = {}", rmse);

        jsc.stop();
    }
}
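`RegressionEvaluator` with metric `"rmse"` reports the root-mean-square error between the label column and the prediction column. Spark's ALS predicts NaN by default for users or items it did not see during training. The hypothetical helper below (`RmseSketch`, with made-up sample numbers) is a rough plain-Java sketch of what that number means and why the cold-start setting matters. It computes RMSE over (rating, prediction) pairs and skips NaN predictions, which mimics the effect of dropping cold-start rows.

```java
/** Hypothetical helper: RMSE over (actual, predicted) pairs, ignoring NaN predictions. */
public class RmseSketch {

    public static double rmse(double[] actual, double[] predicted) {
        if (actual.length != predicted.length) {
            throw new IllegalArgumentException("arrays must have the same length");
        }
        double sumSquares = 0.0;
        int count = 0;
        for (int i = 0; i < actual.length; i++) {
            if (Double.isNaN(predicted[i])) {
                continue; // e.g. a cold-start user/item: drop the row instead of poisoning the mean
            }
            double err = actual[i] - predicted[i];
            sumSquares += err * err;
            count++;
        }
        if (count == 0) {
            throw new IllegalArgumentException("no usable predictions");
        }
        return Math.sqrt(sumSquares / count);
    }

    public static void main(String[] args) {
        double[] actual = {3.0, 1.0, 4.0, 2.0};
        double[] predicted = {2.5, 1.5, Double.NaN, 2.0};
        // errors 0.5, -0.5, (skipped), 0.0 -> sqrt((0.25 + 0.25 + 0.0) / 3)
        System.out.println(rmse(actual, predicted));
    }
}
```

Without the NaN check, a single NaN would turn the whole sum, and therefore the result, into NaN.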
ALS recommendation with MLlib:
import java.io.File;
import java.io.IOException;
import java.util.Arrays;

import org.apache.commons.io.FileUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaDoubleRDD;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.recommendation.ALS;
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
import org.apache.spark.mllib.recommendation.Rating;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import scala.Tuple2;

/**
 * @category ALS
 * @author huangyueran
 */
public class JavaALSExampleByMlLib {

    private static final Logger log = LoggerFactory.getLogger(JavaALSExampleByMlLib.class);

    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("JavaALSExample").setMaster("local[4]");
        JavaSparkContext jsc = new JavaSparkContext(conf);

        // Each line looks like userId::movieId::rating::timestamp; the timestamp is not used
        JavaRDD<String> data = jsc.textFile("data/sample_movielens_ratings.txt");
        JavaRDD<Rating> ratings = data.map(new Function<String, Rating>() {
            public Rating call(String s) {
                String[] sarray = StringUtils.split(StringUtils.trim(s), "::");
                return new Rating(Integer.parseInt(sarray[0]), Integer.parseInt(sarray[1]),
                        Double.parseDouble(sarray[2]));
            }
        });

        // Build the recommendation model using ALS
        int rank = 10; // number of latent factors
        int numIterations = 6;
        MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(ratings), rank, numIterations, 0.01);

        // Evaluate the model on the same rating data it was trained on (in-sample error)
        JavaRDD<Tuple2<Object, Object>> userProducts = ratings.map(
                new Function<Rating, Tuple2<Object, Object>>() {
                    public Tuple2<Object, Object> call(Rating r) {
                        return new Tuple2<Object, Object>(r.user(), r.product());
                    }
                });

        // Predicted ratings, keyed by (user, product)
        JavaPairRDD<Tuple2<Integer, Integer>, Double> predictions = JavaPairRDD.fromJavaRDD(
                model.predict(JavaRDD.toRDD(userProducts)).toJavaRDD().map(
                        new Function<Rating, Tuple2<Tuple2<Integer, Integer>, Double>>() {
                            public Tuple2<Tuple2<Integer, Integer>, Double> call(Rating r) {
                                return new Tuple2<Tuple2<Integer, Integer>, Double>(
                                        new Tuple2<Integer, Integer>(r.user(), r.product()), r.rating());
                            }
                        }));

        // Join actual and predicted ratings: (user, product) -> (actual, predicted)
        JavaPairRDD<Tuple2<Integer, Integer>, Tuple2<Double, Double>> ratesAndPreds = JavaPairRDD.fromJavaRDD(
                ratings.map(new Function<Rating, Tuple2<Tuple2<Integer, Integer>, Double>>() {
                    public Tuple2<Tuple2<Integer, Integer>, Double> call(Rating r) {
                        return new Tuple2<Tuple2<Integer, Integer>, Double>(
                                new Tuple2<Integer, Integer>(r.user(), r.product()), r.rating());
                    }
                })).join(predictions);

        // Predicted ratings sorted by user ID (descending); key: user ID, value: (product ID, predicted rating)
        JavaPairRDD<Integer, Tuple2<Integer, Double>> fromJavaRDD = JavaPairRDD.fromJavaRDD(ratesAndPreds.map(
                new Function<Tuple2<Tuple2<Integer, Integer>, Tuple2<Double, Double>>, Tuple2<Integer, Tuple2<Integer, Double>>>() {
                    public Tuple2<Integer, Tuple2<Integer, Double>> call(
                            Tuple2<Tuple2<Integer, Integer>, Tuple2<Double, Double>> t) throws Exception {
                        return new Tuple2<Integer, Tuple2<Integer, Double>>(t._1._1,
                                new Tuple2<Integer, Double>(t._1._2, t._2._2));
                    }
                })).sortByKey(false);
        // List<Tuple2<Integer, Tuple2<Integer, Double>>> list = fromJavaRDD.collect();
        // for (Tuple2<Integer, Tuple2<Integer, Double>> t : list) {
        //     System.out.println(t._1 + ":" + t._2._1 + "====" + t._2._2);
        // }

        // Mean squared error between actual and predicted ratings
        JavaRDD<Tuple2<Double, Double>> ratesAndPredsValues = ratesAndPreds.values();
        double MSE = JavaDoubleRDD.fromRDD(ratesAndPredsValues.map(new Function<Tuple2<Double, Double>, Object>() {
            public Object call(Tuple2<Double, Double> pair) {
                Double err = pair._1() - pair._2();
                return err * err;
            }
        }).rdd()).mean();

        // Remove output from a previous run; saveAsTextFile fails if the target directory already exists
        try {
            FileUtils.deleteDirectory(new File("result"));
        } catch (IOException e) {
            e.printStackTrace();
        }
        ratesAndPreds.repartition(1).saveAsTextFile("result/ratesAndPreds");

        // Recommend 10 products (movies) for the user with ID 2
        Rating[] recommendProducts = model.recommendProducts(2, 10);
        log.info("get recommend result:{}", Arrays.toString(recommendProducts));

        // Recommend the top N products for every user
        // model.recommendProductsForUsers(10);
        // Recommend the top N users for every product
        // model.recommendUsersForProducts(10);

        model.userFeatures().saveAsTextFile("result/userFea");
        model.productFeatures().saveAsTextFile("result/productFea");
        log.info("Mean Squared Error = {}", MSE);

        jsc.stop();
    }
}
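`model.recommendProducts(2, 10)` asks the model for the 10 products with the highest predicted rating for user 2. The sketch below is a minimal plain-Java version of that ranking step. It assumes the user's factor vector and all item factor vectors are already available; the class name `TopNSketch` and the data are hypothetical. It scores every item by the dot product and keeps the best N in a min-heap of size N. It does not filter out items the user has already rated, so add that check if your use case needs it.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;

/** Hypothetical top-N ranking by predicted score (dot product of factor vectors). */
public class TopNSketch {

    /** Returns the indices of the n items with the highest predicted score, best first. */
    public static List<Integer> recommend(double[] user, double[][] items, int n) {
        // Min-heap ordered by score: the root is the weakest of the current top n
        PriorityQueue<double[]> heap = new PriorityQueue<>(Comparator.comparingDouble((double[] e) -> e[1]));
        for (int i = 0; i < items.length; i++) {
            double score = 0.0;
            for (int f = 0; f < user.length; f++) {
                score += user[f] * items[i][f];
            }
            heap.offer(new double[]{i, score});
            if (heap.size() > n) {
                heap.poll(); // evict the lowest score
            }
        }
        List<Integer> result = new ArrayList<>();
        while (!heap.isEmpty()) {
            result.add(0, (int) heap.poll()[0]); // polled in ascending order, so prepend
        }
        return result;
    }

    public static void main(String[] args) {
        double[] user = {1.0, 2.0};
        double[][] items = {{2.0, 1.0}, {0.5, 1.5}, {1.0, 0.0}, {0.0, 3.0}};
        System.out.println(recommend(user, items, 2)); // prints [3, 0]
    }
}
```

The heap keeps memory at O(N) no matter how many items there are, which matters when the catalogue is much larger than the recommendation list.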
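The two approaches above both compute ALS recommendations offline with Spark. A third approach uses Spark Streaming to compute recommendations online (in real time) over live data buffered in a message queue such as Kafka.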
Real-time ALS recommendation with Spark Streaming:
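Computing ALS recommendations with Spark Streaming is only one link in the chain; a real project involves many other technologies.
For example, the application is instrumented to emit user-behavior logs, which Flume monitors, collects and stores in HDFS, and Kafka consumes and buffers the large volume of behavior data.
Beyond that, the model trained by Spark's machine learning library is stored offline, and the web tier fetches and caches the model to serve recommendations to users, and so on.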
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;

import kafka.serializer.StringDecoder;
import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.VoidFunction;
import org.apache.spark.mllib.recommendation.ALS;
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
import org.apache.spark.mllib.recommendation.Rating;
import org.apache.spark.rdd.RDD;
import org.apache.spark.streaming.Durations;
import org.apache.spark.streaming.api.java.JavaDStream;
import org.apache.spark.streaming.api.java.JavaPairInputDStream;
import org.apache.spark.streaming.api.java.JavaStreamingContext;
import org.apache.spark.streaming.kafka.KafkaUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import scala.Tuple2;

/**
 * @author huangyueran
 * @category Real-time recommendation template DEMO based on Spark Streaming and Kafka. The original system
 *           also includes a mall project, logback, Flume and Hadoop.
 */
public final class SparkALSByStreaming {

    private static final Logger log = LoggerFactory.getLogger(SparkALSByStreaming.class);

    private static final String KAFKA_ADDR = "middleware:9092";
    private static final String TOPIC = "RECOMMEND_TOPIC";
    private static final String HDFS_ADDR = "hdfs://middleware:9000";

    private static final String MODEL_PATH = "/spark-als/model";

    // Real-time recommendation system DEMO based on Hadoop, Flume, Kafka, Spark Streaming, logback and a mall system
    // Data format collected by the mall system:
    // UserID,ItemId,Rating,TimeStamp
    // 53,1286513,9,1508221762
    // 53,1172348420,9,1508221762
    // 53,1179495514,12,1508221762
    // 53,1184890730,3,1508221762
    // 53,1210793742,159,1508221762
    // 53,1215837445,9,1508221762

    public static void main(String[] args) {
        System.setProperty("HADOOP_USER_NAME", "root"); // HDFS user to act as (for permissions)
        SparkConf sparkConf = new SparkConf().setAppName("JavaKafkaDirectWordCount").setMaster("local[1]");
        final JavaStreamingContext jssc = new JavaStreamingContext(sparkConf, Durations.seconds(6)); // 6-second batches

        // Kafka consumer configuration
        Map<String, String> kafkaParams = new HashMap<String, String>();
        kafkaParams.put("metadata.broker.list", KAFKA_ADDR); // where the brokers are
        HashSet<String> topicsSet = new HashSet<String>();
        topicsSet.add(TOPIC); // topic to consume

        // Create a direct Kafka stream with brokers and topics
        JavaPairInputDStream<String, String> messages = KafkaUtils.createDirectStream(jssc, String.class, String.class,
                StringDecoder.class, StringDecoder.class, kafkaParams, topicsSet);

        // Keep only the message value (one CSV line per message)
        JavaDStream<String> lines = messages.map(new Function<Tuple2<String, String>, String>() {
            public String call(Tuple2<String, String> tuple2) {
                return tuple2._2();
            }
        });

        JavaDStream<Rating> ratingsStream = lines.map(new Function<String, Rating>() {
            public Rating call(String s) {
                String[] sarray = StringUtils.split(StringUtils.trim(s), ",");
                return new Rating(Integer.parseInt(sarray[0]), Integer.parseInt(sarray[1]),
                        Double.parseDouble(sarray[2]));
            }
        });

        // Recompute recommendations for every micro-batch
        ratingsStream.foreachRDD(new VoidFunction<JavaRDD<Rating>>() {
            public void call(JavaRDD<Rating> ratings) throws Exception {
                // Load the historical data set
                SparkContext sc = ratings.context();

                RDD<String> textFileRDD = sc.textFile(HDFS_ADDR + "/flume/logs", 3); // read the raw log files
                JavaRDD<String> originalTextFile = textFileRDD.toJavaRDD();

                final JavaRDD<Rating> originaldatas = originalTextFile.map(new Function<String, Rating>() {
                    public Rating call(String s) {
                        String[] sarray = StringUtils.split(StringUtils.trim(s), ",");
                        return new Rating(Integer.parseInt(sarray[0]), Integer.parseInt(sarray[1]),
                                Double.parseDouble(sarray[2]));
                    }
                });
                log.info("========================================");
                // Raw user-behavior log records already stored in HDFS
                log.info("Original TextFile Count:{}", originalTextFile.count());
                log.info("========================================");

                // Merge the historical data with the new user-behavior data
                JavaRDD<Rating> calculations = originaldatas.union(ratings);

                log.info("Calc Count:{}", calculations.count());

                // Build the recommendation model using ALS
                int rank = 10; // number of latent factors in the model
                int numIterations = 6; // number of training iterations

                // Train a new model only if this batch contains user-behavior data
                if (!ratings.isEmpty()) {
                    MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(calculations), rank, numIterations, 0.01);

                    // If the model directory already exists, delete it first
                    Configuration hadoopConfiguration = sc.hadoopConfiguration();
                    hadoopConfiguration.set("fs.defaultFS", HDFS_ADDR);
                    FileSystem fs = FileSystem.get(hadoopConfiguration);
                    Path outpath = new Path(MODEL_PATH);
                    if (fs.exists(outpath)) {
                        log.info("########### Deleting " + outpath.getName() + " ###########");
                        fs.delete(outpath, true);
                    }

                    // Save the model
                    model.save(sc, HDFS_ADDR + MODEL_PATH);
                    // Load the model back
                    MatrixFactorizationModel modelLoad = MatrixFactorizationModel.load(sc, HDFS_ADDR + MODEL_PATH);
                    // Recommend 10 products for each user; assumes user IDs 0..29 exist,
                    // as in streaming_sample_movielens_ratings.txt
                    for (int userId = 0; userId < 30; userId++) {
                        Rating[] recommendProducts = modelLoad.recommendProducts(userId, 10);
                        log.info("get recommend result:{}", Arrays.toString(recommendProducts));
                    }
                }
            }
        });

        jssc.start();
        try {
            jssc.awaitTermination();
        } catch (InterruptedException e) {
            e.printStackTrace();
        }

        // Local mode: keep the driver alive
        try {
            Thread.sleep(10000000);
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        // jssc.stop();
        // jssc.close();
    }
}
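Before each `model.save`, the streaming job checks `fs.exists(outpath)` and calls `fs.delete(outpath, true)`, because the save is not expected to overwrite an existing model directory. The sketch below is a hypothetical local-filesystem analogue of that recursive delete, using only `java.nio.file`; it does not touch HDFS, and the class and method names are made up. It walks the directory tree and deletes the deepest paths first, so every directory is already empty by the time it is removed.

```java
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Comparator;
import java.util.stream.Stream;

/** Hypothetical local analogue of fs.exists(path) followed by fs.delete(path, true). */
public class ModelDirSketch {

    /** Recursively deletes `dir` if it exists; returns true if something was deleted. */
    public static boolean deleteIfExists(Path dir) throws IOException {
        if (!Files.exists(dir)) {
            return false;
        }
        try (Stream<Path> paths = Files.walk(dir)) {
            // Reverse lexicographic order puts children before their parent directories
            paths.sorted(Comparator.reverseOrder()).forEach(p -> {
                try {
                    Files.delete(p);
                } catch (IOException e) {
                    throw new UncheckedIOException(e);
                }
            });
        }
        return true;
    }

    public static void main(String[] args) throws IOException {
        Path model = Files.createTempDirectory("spark-als-model");
        Files.createDirectories(model.resolve("data"));
        Files.write(model.resolve("data/part-00000"), "factors".getBytes());
        System.out.println(deleteIfExists(model)); // prints true
        System.out.println(Files.exists(model));   // prints false
    }
}
```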
User behavior dataset
Data format of the dataset collected by the mall system (user ID, item ID, rating derived from user behavior, timestamp):
UserID,ItemId,Rating,TimeStamp
53,1286513,9,1508221762
53,1172348420,9,1508221762
53,1179495514,12,1508221762
53,1184890730,3,1508221762
53,1210793742,159,1508221762
53,1215837445,9,1508221762
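Each record is one comma-separated line in the order UserID,ItemId,Rating,TimeStamp. The streaming job splits on commas and uses only the first three fields. The hypothetical sketch below (class `BehaviorRecord`, names illustrative) parses all four fields and rejects malformed lines with a clear message instead of an `ArrayIndexOutOfBoundsException`. Item IDs such as 1215837445 still fit in a Java `int` (maximum 2147483647), which the `Rating(int, int, double)` constructor used above relies on.

```java
/** Hypothetical parser for one "UserID,ItemId,Rating,TimeStamp" line. */
public class BehaviorRecord {
    public final int userId;
    public final int itemId;
    public final double rating;
    public final long timestamp; // appears to be Unix time in seconds, e.g. 1508221762

    public BehaviorRecord(int userId, int itemId, double rating, long timestamp) {
        this.userId = userId;
        this.itemId = itemId;
        this.rating = rating;
        this.timestamp = timestamp;
    }

    public static BehaviorRecord parse(String line) {
        String[] f = line.trim().split(",");
        if (f.length != 4) {
            throw new IllegalArgumentException("expected 4 fields but got " + f.length + ": " + line);
        }
        // NumberFormatException (a subclass of IllegalArgumentException) signals a non-numeric field
        return new BehaviorRecord(Integer.parseInt(f[0].trim()), Integer.parseInt(f[1].trim()),
                Double.parseDouble(f[2].trim()), Long.parseLong(f[3].trim()));
    }

    public static void main(String[] args) {
        BehaviorRecord r = parse("53,1210793742,159,1508221762");
        System.out.println(r.userId + " " + r.itemId + " " + r.rating + " " + r.timestamp);
        // prints 53 1210793742 159.0 1508221762
    }
}
```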
Maven dependencies
<!-- https://mvnrepository.com/artifact/org.apache.spark/spark-core_2.10 -->
<dependency>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-core_2.10</artifactId>
    <version>2.2.0</version>
</dependency>
<!-- https://mvnrepository.com/artifact/org.apache.spark/spark-mllib_2.10 -->
<dependency>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-mllib_2.10</artifactId>
    <version>2.2.0</version>
</dependency>
<!-- https://mvnrepository.com/artifact/org.apache.spark/spark-sql_2.10 -->
<dependency>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-sql_2.10</artifactId>
    <version>2.2.0</version>
</dependency>
<!-- https://mvnrepository.com/artifact/org.apache.spark/spark-streaming_2.10 -->
<dependency>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-streaming_2.10</artifactId>
    <version>2.2.0</version>
</dependency>
<!-- Kafka 0.8 integration built for Spark 2.2.x; provides org.apache.spark.streaming.kafka.KafkaUtils -->
<dependency>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-streaming-kafka-0-8_2.10</artifactId>
    <version>2.2.0</version>
</dependency>
<!-- https://mvnrepository.com/artifact/log4j/log4j -->
<dependency>
    <groupId>log4j</groupId>
    <artifactId>log4j</artifactId>
    <version>1.2.17</version>
</dependency>
<dependency>
    <groupId>org.slf4j</groupId>
    <artifactId>slf4j-api</artifactId>
    <version>1.7.12</version>
</dependency>
<dependency>
    <groupId>org.slf4j</groupId>
    <artifactId>slf4j-log4j12</artifactId>
    <version>1.7.12</version>
</dependency>
The code and datasets above can be found in the project on GitHub.