Spark Mllib源碼分析

1. Param

  Spark ML使用一個自定義的Map(ParmaMap類型),其實該類內部使用了mutable.Map容器來存儲數據。html

以下所示其定義:apache

Class ParamMap private[ml] (private val map.mutable.Map[Param[Any],Any])dom

  從上述定義能夠看出,ParamMap是用一個Map來存儲,key爲Param[Any],value爲Any。這裏的value就是用戶設置的參數值,而key是對String的封裝,對用戶來所其實就是字符串。機器學習

如上述的tokenizer類,對調用setInputCol方法來設置輸入DataFrame的輸入列,其內部實現以下所示:ide

Final val inputCol:Param[String] = new Param[String](this,"inputCol","input column name") 函數

def setInputCol(value: String): T = set(inputCol, value).asInstanceOf[T] 學習

final def set[T](param:Param[T],value:T):this.type={ ui

set(param->value) this

}spa

 

2. Transformer

  Transformer類是一個抽象類,爲了實現從一個DataFrame轉換爲另外一個DataFrame,其子類只須要實現三個方法便可。以下所示的源碼:

abstract class Transformer extends PipelineStage {

 

/**

* Transforms the dataset with optional parameters

* @param dataset input dataset

* @param firstParamPair the first param pair, overwrite embedded params

* @param otherParamPairs other param pairs, overwrite embedded params

* @return transformed dataset

*/

@Since("2.0.0")

@varargs

def transform(

dataset: Dataset[_],

firstParamPair: ParamPair[_],

otherParamPairs: ParamPair[_]*): DataFrame = {

val map = new ParamMap()

.put(firstParamPair)

.put(otherParamPairs: _*)

transform(dataset, map)

}

 

/**

* Transforms the dataset with provided parameter map as additional parameters.

* @param dataset input dataset

* @param paramMap additional parameters, overwrite embedded params

* @return transformed dataset

*/

@Since("2.0.0")

def transform(dataset: Dataset[_], paramMap: ParamMap): DataFrame = {

this.copy(paramMap).transform(dataset)

}

 

/**

* Transforms the input dataset.

*/

@Since("2.0.0")

def transform(dataset: Dataset[_]): DataFrame

 

override def copy(extra: ParamMap): Transformer

}

 

  • transform():該方法是用戶的API方法,用戶直接調用該方法來實現轉換;
  • copy():該方法複製了一個Transformer對象;
  • transformSchema:因爲Transformer類繼承了PipelineStage接口,該接口有這個方法實現。
2.1 HasInputColHasOutputCol

  HasInputCol和HasOutputCol都是接口,它們定義了一種協議。如有輸入或有輸出參數的Transformer,那麼就須要實現這個接口。

private[ml] trait HasInputCol extends Params {

 

final val inputCol: Param[String] = new Param[String](this, "inputCol", "input column name")

/** @group getParam */

final def getInputCol: String = $(inputCol)

}

private[ml] trait HasOutputCol extends Params {

 

final val outputCol: Param[String] = new Param[String](this, "outputCol", "output column name")

 

setDefault(outputCol, uid + "__output")

/** @group getParam */

final def getOutputCol: String = $(outputCol)

}

 

2.2 UnaryTransformer

  這個類是一元轉換的抽象類,其以一個DataFrame列做爲輸入,而後通過處理後,產生一個新列增長到輸入的DataFrame中。

該類的源碼以下所示:

abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]

extends Transformer with HasInputCol with HasOutputCol with Logging {

 

/** API method*/

def setInputCol(value: String): T = set(inputCol, value).asInstanceOf[T]

 

/** API method */

def setOutputCol(value: String): T = set(outputCol, value).asInstanceOf[T]

 

/**

* Creates the transform function using the given param map. The input param map already takes

* account of the embedded param map. So the param values should be determined solely by the input

* param map.

*/

protected def createTransformFunc: IN => OUT

 

/**

* Returns the data type of the output column.

*/

protected def outputDataType: DataType

 

/**

* Validates the input type. Throw an exception if it is invalid.

*/

protected def validateInputType(inputType: DataType): Unit = {}

 

override def transformSchema(schema: StructType): StructType = {

val inputType = schema($(inputCol)).dataType

validateInputType(inputType)

if (schema.fieldNames.contains($(outputCol))) {

throw new IllegalArgumentException(s"Output column ${$(outputCol)} already exists.")

}

val outputFields = schema.fields :+

StructField($(outputCol), outputDataType, nullable = false)

StructType(outputFields)

}

/** API method */

override def transform(dataset: Dataset[_]): DataFrame = {

transformSchema(dataset.schema, logging = true)

val transformUDF = udf(this.createTransformFunc, outputDataType)

dataset.withColumn($(outputCol), transformUDF(dataset($(inputCol))))

}

 

override def copy(extra: ParamMap): T = defaultCopy(extra)

}

1) API method

  該類提供三個API方法,用戶經過使用這些方法來實現轉換功能,以下所示:

Method

Description

setInputCol

指明輸入DataFrame中的哪一列是被處理的,輸入參數是Dataframe中存在的列名

setOutputCol

設置新增長列的名字,及對輸入的列變換後悔產生一個新列,該方法設置增長新列的列名

transform

用戶經過調用該方法實現DataFrame的轉換,其實調用該方法是在原來的DataFrame中增長了一個新列,如何增長一個新列,則由createTransformFunc方法來實現。

須要特別說明的是transform方法的最後一條語句,其使用了Dataset的以下方法:

    Dataset.withColumn(colName:String, col:Column):DataFrame

該方法的功能是經過在遍歷dataset中的每一行,而後每行都增長一列,列名爲colName,內容爲col。

2) Implement method

  由於UnaryTransformer類是一個抽象類,其沒有指明一個輸入列如何產生一個新列,這些具體轉換工做須要子類來實現。子類須要實現三個方法:

Method

Description

createTransformFunc

該函數實現瞭如何將一個輸入參數變化後產生一個新數據,便可用將其理解爲map操做,即inàout.

outputDataType

子類實現該方法的目的是返回一個輸出列的數據類型;

validataInputType

驗證輸入列的類型的合法性。

 

    UnaryTransformer抽象類有7個實現類,用戶若是須要自定義轉換操做也能夠繼承該類,而後實現相應的操做便可。以下以Tokenizer類進行介紹,以下所示:

class Tokenizer @Since("1.4.0") (@Since("1.4.0") override val uid: String)

extends UnaryTransformer[String, Seq[String], Tokenizer] with DefaultParamsWritable {

//1. 在繼承UnaryTransformer類時,指明瞭createTransformFunc函數的輸入參數類型和返回參數類型

@Since("1.2.0")

def this() = this(Identifiable.randomUID("tok"))

 

//2. 實現了一個輸入值如何進行處理,而後將其返回

override protected def createTransformFunc: String => Seq[String] = {

_.toLowerCase.split("\\s")

}

 

//3. 驗證輸入參數類型是否合法

override protected def validateInputType(inputType: DataType): Unit = {

require(inputType == StringType, s"Input type must be string type but got $inputType.")

}

 

//4.返回DataFrame中新增長列的類型

override protected def outputDataType: DataType = new ArrayType(StringType, true)

 

@Since("1.4.1")

override def copy(extra: ParamMap): Tokenizer = defaultCopy(extra)

}

2.3 Model

  經過上述前兩節的分析,咱們知道在模型訓練後,Estimator會生成一個Transformer對象。這種Transformer對象就是Model類的子類,其也是Transformer抽象類的子類。

Model類簇都有特別的功能,其是機器學習模型在訓練後的模型,即其可以對輸入的DataFrame進行預測,因此都特別有針對性。

3. Estimator

  Estimator就是機器學習中的模型,其在Spark ML中有不少實現子類。不一樣的學習模型都有不一樣的實現方式。經過前兩節分析,咱們瞭解到Estimator在訓練後悔產生一個Transformer,這個Transformer實際上是Model類。每種Estimator都對應有一種Model。其類圖如圖 5所示。

圖 5

由於Estimator繼承PipelineStage,因此Estimator的實現類須要實現三個方法:

  • copy(extra:ParamMap):實現模型拷貝操做;
  • transformSchema(schema:StructType):實現DataFrame結構的轉換;
  • fit(dataset:Dataset[_]):實現模型訓練,這個很是重要,是用戶的API方法,該方法會返回一個Model實現類。

4. 參考文獻

[1]. Spark MLlib
相關文章
相關標籤/搜索