[Scala] NDCG 的 Scala 實現

1、關於 NDCG

[LTR] 信息檢索評價指標(RP/MAP/DCG/NDCG/RR/ERR)

2、代碼實現

一、訓練數據的加載解析

import scala.io.Source

/*
* 訓練行數據
* */
case class TrainDataRow(target: Int, qid: Int, features: Array[Double])

object TrainDataRow {
  // 加載文件數據
  // 格式:
  // <line> .=. <target> qid:<qid> <feature>:<value> <feature>:<value> ... <feature>:<value> # <info>
  // <target> .=. <positive integer>
  // <qid> .=. <positive integer>
  // <feature> .=. <positive integer>
  // <value> .=. <float>
  // <info> .=. <string>
  def loadFile(file: String): List[TrainDataRow] = {
    Source.fromFile(file).getLines.toList.par.map(x => {
      val strArray = x.split(' ')
      val label = strArray(0).toInt
      val qid = strArray(1).split(':')(1).toInt
      val fValArray = strArray.drop(2).map(x => x.split(':')(1).toDouble)
      new TrainDataRow(label, qid, fValArray)
    }).toList
  }
}

二、NDCG 的實現

object NDCG {
  /*
  * 計算 NDCG 分值
  * */
  def score(rows: List[TrainDataRow], k: Int): Double = {
    val size = k.min(rows.length - 1)
    // 理想 DCG
    var idealDcg: Double = 0
    val sortedList = rows.sortWith((x, y) => x.target > y.target)
    for (i <- 0 to size) {
      // 計算累計效益
      val gain = (1 << sortedList(i).target) - 1
      // 計算折扣因子
      val discount = 1.0 / (Math.log(i + 2) / Math.log(2))
      idealDcg += gain * discount
    }
    if (idealDcg > 0) {
      var dcg: Double = 0
      for (i <- 0 to size) {
        // 計算累計效益
        val gain = (1 << rows(i).target) - 1
        // 計算折扣因子
        val discount = 1.0 / (Math.log(i + 2) / Math.log(2))
        dcg += gain * discount
      }
      dcg / idealDcg
    }
    else 0
  }
}

三、訓練數據集的 NDCG 計算

def calcNDCG(trainDataFile: String, k: Int): Double = {
  println("開始計算...")
  val start = System.nanoTime()
  val data = TrainDataRow.loadFile(trainDataFile) // 加載訓練數據文件
  println("數據量:" + data.length + ",用時:" + (System.nanoTime() - start) / 1000000 + " ms")
  val grpData: Map[Int, List[TrainDataRow]] = data.groupBy(_.qid) // 根據 qid 分組
  val resultNDCG = grpData.map(x => NDCG.score(x._2, k)).sum / grpData.size
  println(s"NDCG@$k: $resultNDCG")
  val end = System.nanoTime()
  println("計算運行時間:" + (end - start) / 1000000 + " ms")
  resultNDCG
}

 

by. Memento
html

相關文章
相關標籤/搜索