Recently, while doing feature engineering with Spark MLlib, I ran into some trouble with StringIndexer and IndexToString, and the official documentation didn't clear up my confusion. I ended up digging into the source code to figure out what was going on... here is the full story.
For more, see my big data learning series.
StringIndexer sorts the values of a string column by how often they occur; the most frequent value gets index 0. For example, running StringIndexer on the table below:
| id | category |
|----|----------|
| 0  | a        |
| 1  | b        |
| 2  | c        |
| 3  | a        |
| 4  | a        |
| 5  | c        |
produces the following:
| id | category | categoryIndex |
|----|----------|---------------|
| 0  | a        | 0.0           |
| 1  | b        | 2.0           |
| 2  | c        | 1.0           |
| 3  | a        | 0.0           |
| 4  | a        | 0.0           |
| 5  | c        | 1.0           |
You can see that "a", the most frequent value, gets index 0, while "b", the least frequent, gets index 2.
For string values that never appeared in the training set, Spark provides several handling strategies through the handleInvalid parameter:

- `error` (the default): throw an exception when an unseen label is encountered
- `skip`: silently drop the rows that contain unseen labels
- `keep`: map every unseen label to a single extra index, equal to the length of the labels array
Below is a code sample based on Spark MLlib 2.2.0:
```scala
package xingoo.ml.features.tranformer

import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.feature.StringIndexer

object StringIndexerTest {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("string-indexer").getOrCreate()
    spark.sparkContext.setLogLevel("WARN")

    val df = spark.createDataFrame(
      Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c"))
    ).toDF("id", "category")

    val df1 = spark.createDataFrame(
      Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "e"), (5, "f"))
    ).toDF("id", "category")

    val indexer = new StringIndexer()
      .setInputCol("category")
      .setOutputCol("categoryIndex")
      .setHandleInvalid("keep") // options: skip, keep, error

    val model = indexer.fit(df)
    val indexed = model.transform(df1)
    indexed.show(false)
  }
}
```
The result is:
```
+---+--------+-------------+
|id |category|categoryIndex|
+---+--------+-------------+
|0  |a       |0.0          |
|1  |b       |2.0          |
|2  |c       |1.0          |
|3  |a       |0.0          |
|4  |e       |3.0          |
|5  |f       |3.0          |
+---+--------+-------------+
```
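For comparison, here is a minimal sketch (reusing the `df` and `df1` defined above) of what `skip` does instead: rows whose category never appeared during fit are simply filtered out, so `e` and `f` disappear from the output. With `error`, transform would throw an exception instead.

```scala
// A sketch assuming the df/df1 from the example above.
// With handleInvalid("skip"), rows containing unseen labels are dropped
// instead of being mapped to an extra index.
val skipModel = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("categoryIndex")
  .setHandleInvalid("skip")
  .fit(df)

skipModel.transform(df1).show(false) // rows 4 (e) and 5 (f) are filtered out
```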
IndexToString converts the index column back to the original strings, but it only works when paired with the preceding StringIndexer:
```scala
package xingoo.ml.features.tranformer

import org.apache.spark.ml.attribute.Attribute
import org.apache.spark.ml.feature.{IndexToString, StringIndexer}
import org.apache.spark.sql.SparkSession

object IndexToString2 {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("dct").getOrCreate()
    spark.sparkContext.setLogLevel("WARN")

    val df = spark.createDataFrame(Seq(
      (0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")
    )).toDF("id", "category")

    val indexer = new StringIndexer()
      .setInputCol("category")
      .setOutputCol("categoryIndex")
      .fit(df)
    val indexed = indexer.transform(df)

    println(s"Transformed string column '${indexer.getInputCol}' " +
      s"to indexed column '${indexer.getOutputCol}'")
    indexed.show()

    val inputColSchema = indexed.schema(indexer.getOutputCol)
    println(s"StringIndexer will store labels in output column metadata: " +
      s"${Attribute.fromStructField(inputColSchema).toString}\n")

    val converter = new IndexToString()
      .setInputCol("categoryIndex")
      .setOutputCol("originalCategory")

    val converted = converter.transform(indexed)

    println(s"Transformed indexed column '${converter.getInputCol}' back to original string " +
      s"column '${converter.getOutputCol}' using labels in metadata")
    converted.select("id", "categoryIndex", "originalCategory").show()
  }
}
```
The result is as follows:
```
Transformed string column 'category' to indexed column 'categoryIndex'
+---+--------+-------------+
| id|category|categoryIndex|
+---+--------+-------------+
|  0|       a|          0.0|
|  1|       b|          2.0|
|  2|       c|          1.0|
|  3|       a|          0.0|
|  4|       a|          0.0|
|  5|       c|          1.0|
+---+--------+-------------+

StringIndexer will store labels in output column metadata: {"vals":["a","c","b"],"type":"nominal","name":"categoryIndex"}

Transformed indexed column 'categoryIndex' back to original string column 'originalCategory' using labels in metadata
+---+-------------+----------------+
| id|categoryIndex|originalCategory|
+---+-------------+----------------+
|  0|          0.0|               a|
|  1|          2.0|               b|
|  2|          1.0|               c|
|  3|          0.0|               a|
|  4|          0.0|               a|
|  5|          1.0|               c|
+---+-------------+----------------+
```
But what if the processing in between is complex and produces a brand-new DataFrame, and you then want to use IndexToString to convert that DataFrame back to the original strings? Let's try it first:
```scala
package xingoo.ml.features.tranformer

import org.apache.spark.ml.feature.{IndexToString, StringIndexer}
import org.apache.spark.sql.SparkSession

object IndexToString3 {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("dct").getOrCreate()
    spark.sparkContext.setLogLevel("WARN")

    val df = spark.createDataFrame(Seq(
      (0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")
    )).toDF("id", "category")

    val df2 = spark.createDataFrame(Seq(
      (0, 2.0), (1, 1.0), (2, 1.0), (3, 0.0)
    )).toDF("id", "index")

    val indexer = new StringIndexer()
      .setInputCol("category")
      .setOutputCol("categoryIndex")
      .fit(df)
    val indexed = indexer.transform(df)

    val converter = new IndexToString()
      .setInputCol("categoryIndex")
      .setOutputCol("originalCategory")

    val converted = converter.transform(df2)
    converted.show()
  }
}
```
Running it fails with an exception:
```
18/07/05 20:20:32 INFO StateStoreCoordinatorRef: Registered StateStoreCoordinator endpoint
Exception in thread "main" java.lang.IllegalArgumentException: Field "categoryIndex" does not exist.
	at org.apache.spark.sql.types.StructType$$anonfun$apply$1.apply(StructType.scala:266)
	at org.apache.spark.sql.types.StructType$$anonfun$apply$1.apply(StructType.scala:266)
	at scala.collection.MapLike$class.getOrElse(MapLike.scala:128)
	at scala.collection.AbstractMap.getOrElse(Map.scala:59)
	at org.apache.spark.sql.types.StructType.apply(StructType.scala:265)
	at org.apache.spark.ml.feature.IndexToString.transformSchema(StringIndexer.scala:338)
	at org.apache.spark.ml.PipelineStage.transformSchema(Pipeline.scala:74)
	at org.apache.spark.ml.feature.IndexToString.transform(StringIndexer.scala:352)
	at xingoo.ml.features.tranformer.IndexToString3$.main(IndexToString3.scala:37)
	at xingoo.ml.features.tranformer.IndexToString3.main(IndexToString3.scala)
```
Why? The immediate cause is that df2 has no column named categoryIndex, but as we'll see, simply renaming the column would not be enough either. Let's follow the source code!
First we create a DataFrame with the original data:
```scala
val df = spark.createDataFrame(Seq(
  (0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")
)).toDF("id", "category")
```
Then create the corresponding StringIndexer:
```scala
val indexer = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("categoryIndex")
  .setHandleInvalid("skip")
  .fit(df)
```
The fit call here trains the transformer. Stepping into fit():
```scala
override def fit(dataset: Dataset[_]): StringIndexerModel = {
  transformSchema(dataset.schema, logging = true)
  // Cast the input column to string, then count the occurrences of each value
  val counts = dataset.na.drop(Array($(inputCol))).select(col($(inputCol)).cast(StringType))
    .rdd
    .map(_.getString(0))
    .countByValue()
  // counts is a map; for our data it is {a->3, b->1, c->2}
  val labels = counts.toSeq.sortBy(-_._2).map(_._1).toArray
  // Sort by count descending and return an array: [a, c, b]
  // Save the labels and return the model (MLlib models all follow
  // this pattern, borrowed from sklearn)
  copyValues(new StringIndexerModel(uid, labels).setParent(this))
}
```
This yields an array containing [a, c, b]. Then transform is called to perform the conversion:
```scala
val indexed = indexer.transform(df)
```
As you'd expect, transform uses this array to convert the column in every row, but it also does something else:
```scala
override def transform(dataset: Dataset[_]): DataFrame = {
  ...
  // --------
  // Build a Metadata object from the labels: this is the crucial part!
  // metadata is essentially a map whose content is:
  // {"ml_attr":{"vals":["a","c","b"],"type":"nominal","name":"categoryIndex"}}
  // --------
  val metadata = NominalAttribute.defaultAttr
    .withName($(outputCol)).withValues(filteredLabels).toMetadata()
  // If handleInvalid is "skip", filter out the invalid rows
  ...
  // Below, the column is converted according to the setting; the logic is straightforward
  val indexer = udf { label: String =>
    ...
    if (labelToIndex.contains(label)) {
      labelToIndex(label) // a known label: convert normally
    } else if (keepInvalid) {
      labels.length // "keep": return the largest index, i.e. the length of the labels array
    } else {
      ... // "error": throw an exception
    }
  }
  // Keep all existing columns, add a new field, and set the Metadata
  // on that field's StructField!!! (this is the key)
  filteredDataset.select(col("*"),
    indexer(dataset($(inputCol)).cast(StringType)).as($(outputCol), metadata))
}
```
See it? The key is right here: the StructField of the newly added column gets a Metadata attached. This Metadata is normally an empty {}, but here it is populated with the labels array.
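You can check this yourself: the `Metadata` type exposes a `json` method, so a quick sketch (assuming the `indexed` DataFrame from the example above) is enough to see the stored labels:

```scala
// Print the metadata that StringIndexer attached to the output column.
println(indexed.schema("categoryIndex").metadata.json)
// {"ml_attr":{"vals":["a","c","b"],"type":"nominal","name":"categoryIndex"}}
```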
Next, let's see how IndexToString uses it. Since IndexToString is a Transformer, it only has a transform method:
```scala
override def transform(dataset: Dataset[_]): DataFrame = {
  transformSchema(dataset.schema, logging = true)
  val inputColSchema = dataset.schema($(inputCol))
  // If the labels array is empty use column metadata
  // Here is the key:
  // if labels was set on the IndexToString, use it directly;
  // otherwise, read the Metadata from the input DataFrame's StructField
  val values = if (!isDefined(labels) || $(labels).isEmpty) {
    Attribute.fromStructField(inputColSchema)
      .asInstanceOf[NominalAttribute].values.get
  } else {
    $(labels)
  }
  // Use this values array to map each index back to its label
  val indexer = udf { index: Double =>
    val idx = index.toInt
    if (0 <= idx && idx < values.length) {
      values(idx)
    } else {
      throw new SparkException(s"Unseen index: $index ??")
    }
  }
  val outputColName = $(outputCol)
  dataset.select(col("*"),
    indexer(dataset($(inputCol)).cast(DoubleType)).as(outputColName))
}
```
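To make the lookup concrete, here is a minimal sketch (again assuming the `indexed` DataFrame from earlier) that pulls the labels out of the column's StructField the same way IndexToString does internally:

```scala
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}

// Recover the labels array from the column metadata, mirroring IndexToString.
val attr = Attribute.fromStructField(indexed.schema("categoryIndex"))
val recovered = attr.asInstanceOf[NominalAttribute].values.get
// recovered: Array(a, c, b)
```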
Once you understand how StringIndexer and IndexToString work, the workarounds follow naturally. The first is to attach the metadata saved on the indexed column to the new DataFrame's column:
```scala
val df2 = spark.createDataFrame(Seq(
  (0, 2.0), (1, 1.0), (2, 1.0), (3, 0.0)
)).toDF("id", "index")
  .select(col("*"), col("index").as("formated_index", indexed.schema("categoryIndex").metadata))

val converter = new IndexToString()
  .setInputCol("formated_index")
  .setOutputCol("origin_col")

val converted = converter.transform(df2)
converted.show(false)
```
```
+---+-----+--------------+----------+
|id |index|formated_index|origin_col|
+---+-----+--------------+----------+
|0  |2.0  |2.0           |b         |
|1  |1.0  |1.0           |c         |
|2  |1.0  |1.0           |c         |
|3  |0.0  |0.0           |a         |
+---+-----+--------------+----------+
```
The second is to extract the labels from the metadata and pass them explicitly via setLabels:

```scala
val df3 = spark.createDataFrame(Seq(
  (0, 2.0), (1, 1.0), (2, 1.0), (3, 0.0)
)).toDF("id", "index")

val converter2 = new IndexToString()
  .setInputCol("index")
  .setOutputCol("origin_col")
  .setLabels(indexed.schema("categoryIndex").metadata.getMetadata("ml_attr").getStringArray("vals"))

val converted2 = converter2.transform(df3)
converted2.show(false)
```
```
+---+-----+----------+
|id |index|origin_col|
+---+-----+----------+
|0  |2.0  |b         |
|1  |1.0  |c         |
|2  |1.0  |c         |
|3  |0.0  |a         |
+---+-----+----------+
```
Both approaches produce the correct output.
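As a side note, if you still hold a reference to the fitted StringIndexerModel, its `labels` member exposes the same array directly, so a third sketch (assuming the fitted `indexer` and the `df3` from above) avoids touching the metadata at all:

```scala
// StringIndexerModel.labels is the same array that ends up in the metadata,
// so it can be passed to setLabels directly.
val converter3 = new IndexToString()
  .setInputCol("index")
  .setOutputCol("origin_col")
  .setLabels(indexer.labels)

converter3.transform(df3).show(false)
```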
The complete code is available on GitHub:
In the end I still recommend reading the official documentation carefully, but it is honestly a bit thin; to really understand the underlying mechanics, you have to sit down and read the source.