"""
K-means clustering over a Spark RDD of numeric vectors.

Each non-blank line of the input file is one point: whitespace-separated
floating-point coordinates.

Usage: kmeans <file> <k> <convergeDist>
"""
from __future__ import print_function

import sys

import numpy as np


def parseVector(line):
    """Parse a whitespace-separated line of numbers into a float numpy array.

    ``str.split()`` with no argument is used so that repeated or trailing
    whitespace does not yield empty tokens (which ``float('')`` rejects).
    """
    return np.array([float(x) for x in line.split()])


def closestPoint(p, centers):
    """Return the index of the center in *centers* nearest to point *p*.

    Distance is squared Euclidean distance. On ties the lowest index wins.
    If *centers* is empty, 0 is returned.
    """
    bestIndex = 0
    closest = float("+inf")
    for i, center in enumerate(centers):
        tempDist = np.sum((p - center) ** 2)
        # Strict "<" keeps the first center among equally distant ones.
        if tempDist < closest:
            closest = tempDist
            bestIndex = i
    return bestIndex


if __name__ == "__main__":
    if len(sys.argv) != 4:
        print("Usage: kmeans <file> <k> <convergeDist>", file=sys.stderr)
        sys.exit(-1)

    # Number of clusters (the K hyperparameter).
    K = int(sys.argv[2])
    # Convergence threshold: iteration stops once the total squared movement
    # of all centers in one pass is <= this value.
    convergeDist = float(sys.argv[3])

    if K < 1:
        # With no centers every point maps to index 0, and the update step
        # would then fail with an IndexError on kPoints[0].
        print("k must be a positive integer", file=sys.stderr)
        sys.exit(-1)

    # Imported here so the helpers above can be imported (e.g. for testing)
    # without pyspark being installed.
    from pyspark.sql import SparkSession

    spark = SparkSession\
        .builder\
        .appName("PythonKMeans")\
        .getOrCreate()

    try:
        lines = spark.read.text(sys.argv[1]).rdd.map(lambda r: r[0])
        # Blank lines would parse to an empty vector and break the distance
        # computation, so drop them before parsing.
        data = lines.filter(lambda line: line.strip()).map(parseVector).cache()

        # Initial centers: K points sampled without replacement, seed 1.
        kPoints = data.takeSample(False, K, 1)
        # Start at +inf so at least one iteration always runs, whatever
        # convergeDist is (a fixed 1.0 skipped the loop for convergeDist >= 1).
        tempDist = float("inf")

        while tempDist > convergeDist:
            # Map: key each point by the index of its nearest center;
            # value is (point, count of 1).
            closest = data.map(
                lambda p: (closestPoint(p, kPoints), (p, 1)))
            # Reduce: per center index, sum the assigned points and counts.
            pointStats = closest.reduceByKey(
                lambda p1_c1, p2_c2: (p1_c1[0] + p2_c2[0], p1_c1[1] + p2_c2[1]))
            # New center = mean of the points assigned to that center.
            newPoints = pointStats.map(
                lambda st: (st[0], st[1][0] / st[1][1])).collect()

            # Total squared distance the centers moved in this iteration.
            tempDist = sum(np.sum((kPoints[iK] - p) ** 2) for (iK, p) in newPoints)

            # Centers that received no points keep their previous position.
            for (iK, p) in newPoints:
                kPoints[iK] = p

        print("Final centers: " + str(kPoints))
    finally:
        # Release the Spark session even if the job fails.
        spark.stop()