Spark ML使用一個自定義的Map(ParmaMap類型),其實該類內部使用了mutable.Map容器來存儲數據。html
以下所示其定義:apache
Class ParamMap private[ml] (private val map.mutable.Map[Param[Any],Any])dom |
從上述定義能夠看出,ParamMap是用一個Map來存儲,key爲Param[Any],value爲Any。這裏的value就是用戶設置的參數值,而key是對String的封裝,對用戶來所其實就是字符串。機器學習
如上述的tokenizer類,對調用setInputCol方法來設置輸入DataFrame的輸入列,其內部實現以下所示:ide
Final val inputCol:Param[String] = new Param[String](this,"inputCol","input column name") 函數 def setInputCol(value: String): T = set(inputCol, value).asInstanceOf[T] 學習 final def set[T](param:Param[T],value:T):this.type={ ui set(param->value) this }spa |
Transformer類是一個抽象類,爲了實現從一個DataFrame轉換爲另外一個DataFrame,其子類只須要實現三個方法便可。以下所示的源碼:
abstract class Transformer extends PipelineStage {
/** * Transforms the dataset with optional parameters * @param dataset input dataset * @param firstParamPair the first param pair, overwrite embedded params * @param otherParamPairs other param pairs, overwrite embedded params * @return transformed dataset */ @Since("2.0.0") @varargs def transform( dataset: Dataset[_], firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): DataFrame = { val map = new ParamMap() .put(firstParamPair) .put(otherParamPairs: _*) transform(dataset, map) }
/** * Transforms the dataset with provided parameter map as additional parameters. * @param dataset input dataset * @param paramMap additional parameters, overwrite embedded params * @return transformed dataset */ @Since("2.0.0") def transform(dataset: Dataset[_], paramMap: ParamMap): DataFrame = { this.copy(paramMap).transform(dataset) }
/** * Transforms the input dataset. */ @Since("2.0.0") def transform(dataset: Dataset[_]): DataFrame
override def copy(extra: ParamMap): Transformer } |
HasInputCol和HasOutputCol都是接口,它們定義了一種協議。如有輸入或有輸出參數的Transformer,那麼就須要實現這個接口。
private[ml] trait HasInputCol extends Params {
final val inputCol: Param[String] = new Param[String](this, "inputCol", "input column name") /** @group getParam */ final def getInputCol: String = $(inputCol) } |
private[ml] trait HasOutputCol extends Params {
final val outputCol: Param[String] = new Param[String](this, "outputCol", "output column name")
setDefault(outputCol, uid + "__output") /** @group getParam */ final def getOutputCol: String = $(outputCol) } |
這個類是一元轉換的抽象類,其以一個DataFrame列做爲輸入,而後通過處理後,產生一個新列增長到輸入的DataFrame中。
該類的源碼以下所示:
abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]] extends Transformer with HasInputCol with HasOutputCol with Logging {
/** API method*/ def setInputCol(value: String): T = set(inputCol, value).asInstanceOf[T]
/** API method */ def setOutputCol(value: String): T = set(outputCol, value).asInstanceOf[T]
/** * Creates the transform function using the given param map. The input param map already takes * account of the embedded param map. So the param values should be determined solely by the input * param map. */ protected def createTransformFunc: IN => OUT
/** * Returns the data type of the output column. */ protected def outputDataType: DataType
/** * Validates the input type. Throw an exception if it is invalid. */ protected def validateInputType(inputType: DataType): Unit = {}
override def transformSchema(schema: StructType): StructType = { val inputType = schema($(inputCol)).dataType validateInputType(inputType) if (schema.fieldNames.contains($(outputCol))) { throw new IllegalArgumentException(s"Output column ${$(outputCol)} already exists.") } val outputFields = schema.fields :+ StructField($(outputCol), outputDataType, nullable = false) StructType(outputFields) } /** API method */ override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) val transformUDF = udf(this.createTransformFunc, outputDataType) dataset.withColumn($(outputCol), transformUDF(dataset($(inputCol)))) }
override def copy(extra: ParamMap): T = defaultCopy(extra) } |
該類提供三個API方法,用戶經過使用這些方法來實現轉換功能,以下所示:
Method |
Description |
setInputCol |
指明輸入DataFrame中的哪一列是被處理的,輸入參數是Dataframe中存在的列名 |
setOutputCol |
設置新增長列的名字,及對輸入的列變換後悔產生一個新列,該方法設置增長新列的列名 |
transform |
用戶經過調用該方法實現DataFrame的轉換,其實調用該方法是在原來的DataFrame中增長了一個新列,如何增長一個新列,則由createTransformFunc方法來實現。 |
須要特別說明的是transform方法的最後一條語句,其使用了Dataset的以下方法:
Dataset.withColumn(colName:String, col:Column):DataFrame
該方法的功能是經過在遍歷dataset中的每一行,而後每行都增長一列,列名爲colName,內容爲col。
由於UnaryTransformer類是一個抽象類,其沒有指明一個輸入列如何產生一個新列,這些具體轉換工做須要子類來實現。子類須要實現三個方法:
Method |
Description |
createTransformFunc |
該函數實現瞭如何將一個輸入參數變化後產生一個新數據,便可用將其理解爲map操做,即inàout. |
outputDataType |
子類實現該方法的目的是返回一個輸出列的數據類型; |
validataInputType |
驗證輸入列的類型的合法性。 |
UnaryTransformer抽象類有7個實現類,用戶若是須要自定義轉換操做也能夠繼承該類,而後實現相應的操做便可。以下以Tokenizer類進行介紹,以下所示:
class Tokenizer @Since("1.4.0") (@Since("1.4.0") override val uid: String) extends UnaryTransformer[String, Seq[String], Tokenizer] with DefaultParamsWritable { //1. 在繼承UnaryTransformer類時,指明瞭createTransformFunc函數的輸入參數類型和返回參數類型 @Since("1.2.0") def this() = this(Identifiable.randomUID("tok"))
//2. 實現了一個輸入值如何進行處理,而後將其返回 override protected def createTransformFunc: String => Seq[String] = { _.toLowerCase.split("\\s") }
//3. 驗證輸入參數類型是否合法 override protected def validateInputType(inputType: DataType): Unit = { require(inputType == StringType, s"Input type must be string type but got $inputType.") }
//4.返回DataFrame中新增長列的類型 override protected def outputDataType: DataType = new ArrayType(StringType, true)
@Since("1.4.1") override def copy(extra: ParamMap): Tokenizer = defaultCopy(extra) } |
經過上述前兩節的分析,咱們知道在模型訓練後,Estimator會生成一個Transformer對象。這種Transformer對象就是Model類的子類,其也是Transformer抽象類的子類。
Model類簇都有特別的功能,其是機器學習模型在訓練後的模型,即其可以對輸入的DataFrame進行預測,因此都特別有針對性。
Estimator就是機器學習中的模型,其在Spark ML中有不少實現子類。不一樣的學習模型都有不一樣的實現方式。經過前兩節分析,咱們瞭解到Estimator在訓練後悔產生一個Transformer,這個Transformer實際上是Model類。每種Estimator都對應有一種Model。其類圖如圖 5所示。
圖 5
由於Estimator繼承PipelineStage,因此Estimator的實現類須要實現三個方法: