import org.apache.spark.sql.SQLContext
import org.apache.spark.{SparkConf, SparkContext}

import scala.collection.mutable
import scala.collection.mutable.ListBuffer

/**
 * Character-level cosine similarity between two strings.
 *
 * Each string is turned into a vector of per-character occurrence counts over
 * the combined character set of both strings, and the cosine of the angle
 * between the two vectors is returned.
 */
object test423_cosvec {

  /** Demo entry point: prints the cosine value for two sample sentences. */
  def main(args: Array[String]): Unit = {
    val str1 = "據說菠蘿就是鳳梨"
    val str2 = "鳳梨確定不會是菠蘿"
    val result = textCosine(str1, str2)
    println("兩句話的餘弦距離: " + result)
  }

  /**
   * Euclidean norm (length) of a vector.
   *
   * @param vec the vector components
   * @return the square root of the sum of squared components; 0.0 for an empty vector
   */
  def module(vec: Vector[Double]): Double =
    math.sqrt(vec.map(math.pow(_, 2)).sum)

  /**
   * Inner (dot) product of two vectors.
   *
   * If the vectors differ in length, only the leading components up to the
   * shorter length are multiplied; the remainder of the longer vector is ignored.
   *
   * @param v1 first vector
   * @param v2 second vector
   * @return the sum of the pairwise products; 0.0 if either vector is empty
   */
  def innerProduct(v1: Vector[Double], v2: Vector[Double]): Double =
    v1.zip(v2).map { case (a, b) => a * b }.sum

  /**
   * Cosine of the angle between two vectors.
   *
   * @param v1 first vector
   * @param v2 second vector
   * @return the cosine, clamped to [-1, 1] to absorb floating-point rounding;
   *         0.0 when either vector has zero length, since the angle is undefined
   */
  def cosvec(v1: Vector[Double], v2: Vector[Double]): Double = {
    val denominator = module(v1) * module(v2)
    // Without this guard 0.0 / 0.0 yields NaN, and because every comparison
    // against NaN is false, an upper-bound check alone would report 1.0
    // ("identical") for a zero vector.
    if (denominator == 0.0) 0.0
    else math.max(-1.0, math.min(1.0, innerProduct(v1, v2) / denominator))
  }

  /**
   * Cosine similarity of two strings based on character occurrence counts.
   *
   * Prints the combined character set and both count vectors to standard
   * output as diagnostics.
   *
   * @param str1 first string
   * @param str2 second string
   * @return the cosine of the two count vectors; 0.0 if either string is empty
   */
  def textCosine(str1: String, str2: String): Double = {
    // Collect every distinct character that appears in either string.
    val set = mutable.Set[Char]()
    str1.foreach(set += _)
    str2.foreach(set += _)
    println(set)

    // Sort once so that both vectors share the same component order.
    val vocabulary = set.toList.sorted
    val freq1 = charFrequencies(str1)
    val freq2 = charFrequencies(str2)

    val ints1: Vector[Double] = vocabulary.map(ch => freq1.getOrElse(ch, 0.0)).toVector
    println("===ints1: " + ints1)
    val ints2: Vector[Double] = vocabulary.map(ch => freq2.getOrElse(ch, 0.0)).toVector
    println("===ints2: " + ints2)

    cosvec(ints1, ints2)
  }

  /** Counts how many times each character occurs in `text`, in a single pass. */
  private def charFrequencies(text: String): Map[Char, Double] =
    text.groupBy(identity).map { case (ch, occurrences) => ch -> occurrences.length.toDouble }
}