Spark MLlib 之 StringIndexer、IndexToString使用說明以及源碼剖析

最近在用Spark MLlib進行特徵處理時,對於StringIndexer和IndexToString遇到了點問題,查閱官方文檔也沒有解決疑惑。無奈之下翻看源碼才明白其中一二...這就給你們娓娓道來。html

更多內容參考個人大數據學習之路java

文檔說明

StringIndexer 字符串轉索引

StringIndexer能夠把字符串的列按照出現頻率進行排序,出現次數最高的對應的Index爲0。好比下面的列表進行StringIndexergit

id category
0 a
1 b
2 c
3 a
4 a
5 c

就能夠獲得以下:github

id category categoryIndex
0 a 0.0
1 b 2.0
2 c 1.0
3 a 0.0
4 a 0.0
5 c 1.0

能夠看到出現次數最多的"a",索引爲0;次數最少的"b"索引爲2。sql

針對訓練集中沒有出現的字符串值,spark提供了幾種處理的方法:apache

  • error,直接拋出異常
  • skip,跳過該樣本數據
  • keep,使用一個新的最大索引,來表示全部未出現的值

下面是基於Spark MLlib 2.2.0的代碼樣例:數組

package xingoo.ml.features.tranformer

import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.feature.StringIndexer

object StringIndexerTest {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("string-indexer").getOrCreate()
    spark.sparkContext.setLogLevel("WARN")

    val df = spark.createDataFrame(
      Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c"))
    ).toDF("id", "category")

    val df1 = spark.createDataFrame(
      Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "e"), (5, "f"))
    ).toDF("id", "category")

    val indexer = new StringIndexer()
      .setInputCol("category")
      .setOutputCol("categoryIndex")
      .setHandleInvalid("keep") //skip keep error

    val model = indexer.fit(df)

    val indexed = model.transform(df1)
    indexed.show(false)
  }
}

獲得的結果爲:app

+---+--------+-------------+
|id |category|categoryIndex|
+---+--------+-------------+
|0  |a       |0.0          |
|1  |b       |2.0          |
|2  |c       |1.0          |
|3  |a       |0.0          |
|4  |e       |3.0          |
|5  |f       |3.0          |
+---+--------+-------------+

IndexToString 索引轉字符串

這個索引轉回字符串要搭配前面的StringIndexer一塊兒使用才行:ide

package xingoo.ml.features.tranformer

import org.apache.spark.ml.attribute.Attribute
import org.apache.spark.ml.feature.{IndexToString, StringIndexer}
import org.apache.spark.sql.SparkSession

object IndexToString2 {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("dct").getOrCreate()
    spark.sparkContext.setLogLevel("WARN")

    val df = spark.createDataFrame(Seq(
      (0, "a"),
      (1, "b"),
      (2, "c"),
      (3, "a"),
      (4, "a"),
      (5, "c")
    )).toDF("id", "category")

    val indexer = new StringIndexer()
      .setInputCol("category")
      .setOutputCol("categoryIndex")
      .fit(df)
    val indexed = indexer.transform(df)

    println(s"Transformed string column '${indexer.getInputCol}' " +
      s"to indexed column '${indexer.getOutputCol}'")
    indexed.show()

    val inputColSchema = indexed.schema(indexer.getOutputCol)
    println(s"StringIndexer will store labels in output column metadata: " +
      s"${Attribute.fromStructField(inputColSchema).toString}\n")

    val converter = new IndexToString()
      .setInputCol("categoryIndex")
      .setOutputCol("originalCategory")

    val converted = converter.transform(indexed)

    println(s"Transformed indexed column '${converter.getInputCol}' back to original string " +
      s"column '${converter.getOutputCol}' using labels in metadata")
    converted.select("id", "categoryIndex", "originalCategory").show()
  }
}

獲得的結果以下:學習

Transformed string column 'category' to indexed column 'categoryIndex'
+---+--------+-------------+
| id|category|categoryIndex|
+---+--------+-------------+
|  0|       a|          0.0|
|  1|       b|          2.0|
|  2|       c|          1.0|
|  3|       a|          0.0|
|  4|       a|          0.0|
|  5|       c|          1.0|
+---+--------+-------------+

StringIndexer will store labels in output column metadata: {"vals":["a","c","b"],"type":"nominal","name":"categoryIndex"}

Transformed indexed column 'categoryIndex' back to original string column 'originalCategory' using labels in metadata
+---+-------------+----------------+
| id|categoryIndex|originalCategory|
+---+-------------+----------------+
|  0|          0.0|               a|
|  1|          2.0|               b|
|  2|          1.0|               c|
|  3|          0.0|               a|
|  4|          0.0|               a|
|  5|          1.0|               c|
+---+-------------+----------------+

使用問題

假如處理的過程很複雜,從新生成了一個DataFrame,此時想要把這個DataFrame基於IndexToString轉回原來的字符串怎麼辦呢? 先來試試看:

package xingoo.ml.features.tranformer

import org.apache.spark.ml.feature.{IndexToString, StringIndexer}
import org.apache.spark.sql.SparkSession

object IndexToString3 {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("dct").getOrCreate()
    spark.sparkContext.setLogLevel("WARN")

    val df = spark.createDataFrame(Seq(
      (0, "a"),
      (1, "b"),
      (2, "c"),
      (3, "a"),
      (4, "a"),
      (5, "c")
    )).toDF("id", "category")

    val df2 = spark.createDataFrame(Seq(
      (0, 2.0),
      (1, 1.0),
      (2, 1.0),
      (3, 0.0)
    )).toDF("id", "index")

    val indexer = new StringIndexer()
      .setInputCol("category")
      .setOutputCol("categoryIndex")
      .fit(df)
    val indexed = indexer.transform(df)

    val converter = new IndexToString()
      .setInputCol("categoryIndex")
      .setOutputCol("originalCategory")

    val converted = converter.transform(df2)
    converted.show()
  }
}

運行後發現異常:

18/07/05 20:20:32 INFO StateStoreCoordinatorRef: Registered StateStoreCoordinator endpoint
Exception in thread "main" java.lang.IllegalArgumentException: Field "categoryIndex" does not exist.
    at org.apache.spark.sql.types.StructType$$anonfun$apply$1.apply(StructType.scala:266)
    at org.apache.spark.sql.types.StructType$$anonfun$apply$1.apply(StructType.scala:266)
    at scala.collection.MapLike$class.getOrElse(MapLike.scala:128)
    at scala.collection.AbstractMap.getOrElse(Map.scala:59)
    at org.apache.spark.sql.types.StructType.apply(StructType.scala:265)
    at org.apache.spark.ml.feature.IndexToString.transformSchema(StringIndexer.scala:338)
    at org.apache.spark.ml.PipelineStage.transformSchema(Pipeline.scala:74)
    at org.apache.spark.ml.feature.IndexToString.transform(StringIndexer.scala:352)
    at xingoo.ml.features.tranformer.IndexToString3$.main(IndexToString3.scala:37)
    at xingoo.ml.features.tranformer.IndexToString3.main(IndexToString3.scala)

這是爲何呢?跟隨源碼來看吧!

源碼剖析

首先咱們建立一個DataFrame,得到原始數據:

val df = spark.createDataFrame(Seq(
      (0, "a"),
      (1, "b"),
      (2, "c"),
      (3, "a"),
      (4, "a"),
      (5, "c")
    )).toDF("id", "category")

而後建立對應的StringIndexer:

val indexer = new StringIndexer()
      .setInputCol("category")
      .setOutputCol("categoryIndex")
      .setHandleInvalid("skip")
      .fit(df)

這裏面的fit就是在訓練轉換器了,進入fit():

override def fit(dataset: Dataset[_]): StringIndexerModel = {
    transformSchema(dataset.schema, logging = true)
    // 這裏針對須要轉換的列先強制轉換成字符串,而後遍歷統計每一個字符串出現的次數
    val counts = dataset.na.drop(Array($(inputCol))).select(col($(inputCol)).cast(StringType))
      .rdd
      .map(_.getString(0))
      .countByValue()
    // counts是一個map,裏面的內容爲{a->3, b->1, c->2}
    val labels = counts.toSeq.sortBy(-_._2).map(_._1).toArray
    // 按照個數大小排序,返回數組,[a, c, b]
    // 把這個label保存起來,並返回對應的model(mllib裏邊的模型都是這個套路,跟sklearn學的)
    copyValues(new StringIndexerModel(uid, labels).setParent(this))
  }

這樣就獲得了一個列表,列表裏面的內容是[a, c, b],而後執行transform來進行轉換:

val indexed = indexer.transform(df)

這個transform可想而知就是用這個數組對每一行的該列進行轉換,可是它其實還作了其餘的事情:

override def transform(dataset: Dataset[_]): DataFrame = {
    ...
    // --------
    // 經過label生成一個Metadata,這個很關鍵!!!
    // metadata實際上是一個map,內容爲:
    // {"ml_attr":{"vals":["a","c","b"],"type":"nominal","name":"categoryIndex"}}
    // --------
    val metadata = NominalAttribute.defaultAttr
      .withName($(outputCol)).withValues(filteredLabels).toMetadata()
    
    // 若是是skip則過濾一些數據
    ...
    
    // 下面是針對不一樣的狀況處理轉換的列,邏輯很簡單
    val indexer = udf { label: String =>
      ...
      if (labelToIndex.contains(label)) {
          labelToIndex(label) //若是正常,就進行轉換
        } else if (keepInvalid) {
          labels.length // 若是是keep,就返回索引的最大值(即數組的長度)
        } else {
          ... // 若是是error,就拋出異常
        }
    }

    // 保留以前全部的列,新增一個字段,並設置字段的StructField中的Metadata!!!!
    // 並設置字段的StructField中的Metadata!!!!
    // 並設置字段的StructField中的Metadata!!!!
    // 並設置字段的StructField中的Metadata!!!!
    
    filteredDataset.select(col("*"),
      indexer(dataset($(inputCol)).cast(StringType)).as($(outputCol), metadata))
  }

看到了嗎!關鍵的地方在這裏,給新增長的字段的類型StructField設置了一個Metadata。這個Metadata正常都是空的{},可是這裏設置了metadata以後,裏面包含了label數組的信息。

接下來看看IndexToString是怎麼用的,因爲IndexToString是一個Transformer,所以只有一個trasform方法:

override def transform(dataset: Dataset[_]): DataFrame = {
    transformSchema(dataset.schema, logging = true)
    val inputColSchema = dataset.schema($(inputCol))
    
    // If the labels array is empty use column metadata
    // 關鍵是這裏:
    // 若是IndexToString設置了labels數組,就直接返回;
    // 不然,就讀取了傳入的DataFrame的StructField中的Metadata
    val values = if (!isDefined(labels) || $(labels).isEmpty) {
      Attribute.fromStructField(inputColSchema)
        .asInstanceOf[NominalAttribute].values.get
    } else {
      $(labels)
    }

    // 基於這個values把index轉成對應的值
    val indexer = udf { index: Double =>
      val idx = index.toInt
      if (0 <= idx && idx < values.length) {
        values(idx)
      } else {
        throw new SparkException(s"Unseen index: $index ??")
      }
    }
    val outputColName = $(outputCol)
    dataset.select(col("*"),
      indexer(dataset($(inputCol)).cast(DoubleType)).as(outputColName))
  }

瞭解StringIndexer和IndexToString的原理機制後,就能夠做出以下的應對策略了。

1 增長StructField的MetaData信息

val df2 = spark.createDataFrame(Seq(
      (0, 2.0),
      (1, 1.0),
      (2, 1.0),
      (3, 0.0)
    )).toDF("id", "index").select(col("*"),col("index").as("formated_index", indexed.schema("categoryIndex").metadata))

    val converter = new IndexToString()
      .setInputCol("formated_index")
      .setOutputCol("origin_col")

    val converted = converter.transform(df2)
    converted.show(false)
+---+-----+--------------+----------+
|id |index|formated_index|origin_col|
+---+-----+--------------+----------+
|0  |2.0  |2.0           |b         |
|1  |1.0  |1.0           |c         |
|2  |1.0  |1.0           |c         |
|3  |0.0  |0.0           |a         |
+---+-----+--------------+----------+

2 獲取以前StringIndexer後的DataFrame中的Label信息

val df3 = spark.createDataFrame(Seq(
      (0, 2.0),
      (1, 1.0),
      (2, 1.0),
      (3, 0.0)
    )).toDF("id", "index")

    val converter2 = new IndexToString()
      .setInputCol("index")
      .setOutputCol("origin_col")
      .setLabels(indexed.schema("categoryIndex").metadata.getMetadata("ml_attr").getStringArray("vals"))

    val converted2 = converter2.transform(df3)
    converted2.show(false)
+---+-----+----------+
|id |index|origin_col|
+---+-----+----------+
|0  |2.0  |b         |
|1  |1.0  |c         |
|2  |1.0  |c         |
|3  |0.0  |a         |
+---+-----+----------+

兩種方法都能獲得正確的輸出。

完整的代碼能夠參考github連接:

https://github.com/xinghalo/spark-in-action/blob/master/src/xingoo/ml/features/tranformer/IndexToStringTest.scala

最終仍是推薦詳細閱讀官方文檔,不過官方文檔真心有些粗糙,想要了解其中的原理,仍是得靜下心來看看源碼。

相關文章
相關標籤/搜索