特徵數據標準化指的是對訓練樣本經過利用每一列的統計量將特徵列轉換爲0均值單位方差的數據。 這是很是通用的數據預處理步驟。
例如:RBF核的支持向量機或者基於L1和L2正則化的線性模型在數據標準化以後效果會更好。
數據標準化可以改進優化過程當中數據收斂的速度,也能防止一些方差過大的變量特徵對模型訓練 產生過大的影響。
如何對數據標準化呢?公式也很是簡單:新的列 = (老的列每個值 - 老的列平均值) / (老的列標準差)apache
在標準化以前,Spark必須知道每一列的平均值,方差,具體怎麼知道呢?
想法很簡單,首先給 Spark的 StandardScaler 一批數據,這批數據以 org.apache.spark.mllib.feature.Vector 的形式提供給 StandardScaler。StandardScaler 對輸入的數據進行 fit 即計算每一列的平均值,方差。 調度代碼以下:微信
import org.apache.spark.SparkContext._ import org.apache.spark.mllib.feature.StandardScaler import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLUtils val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") val scaler1 = new StandardScaler().fit(data.map(x => x.features)) val scaler2 = new StandardScaler(withMean = true, withStd = true).fit(data.map(x => x.features))
上面代碼的本質是生成一個包含每一列均值和方差的 StandardScalarModel,具體解釋一下 withMean 和 withStd 的含義:機器學習
下面給出上面 fit 函數的源代碼:ide
/** * 計算數據每一列的平均值標準差,將會用於以後的標準化. * * @param data The data used to compute the mean and variance to build the transformation model. * @return a StandardScalarModel */ @Since("1.1.0") def fit(data: RDD[Vector]): StandardScalerModel = { // TODO: 若是 withMean 和 withStd 都爲false,什麼都不用幹 //計算基本統計 val summary = data.treeAggregate(new MultivariateOnlineSummarizer)( (aggregator, data) => aggregator.add(data), (aggregator1, aggregator2) => aggregator1.merge(aggregator2)) //經過標準差,平均值獲得模型 new StandardScalerModel( Vectors.dense(summary.variance.toArray.map(v => math.sqrt(v))), summary.mean, withStd, withMean) }
從這裏能夠發現,若是你知道每一列的平均值和方差,直接經過 StandardScalarModel 構建模型就能夠了,以下代碼:函數
val scaler3 = new StandardScalerModel(scaler2.std, scaler2.mean)
準備工做作好了,下面真正標準化,調用代碼也很是簡單:學習
al data1 = data.map(x => (x.label, scaler1.transform(x.features)))
用模型對每一行 transform 就能夠了,背後的原理也很是簡單,代碼以下:大數據
// 由於 `shift` 只是在 `withMean` 爲真的分支中才使用, 因此使用了 // `lazy val`. 注意:這裏不想在每一次 `transform` 都計算一遍 shift. private lazy val shift: Array[Double] = mean.toArray /** * Applies standardization transformation on a vector. * * @param vector Vector to be standardized. * @return Standardized vector. If the std of a column is zero, it will return default `0.0` * for the column with zero std. */ @Since("1.1.0") override def transform(vector: Vector): Vector = { require(mean.size == vector.size) if (withMean) { // By default, Scala generates Java methods for member variables. So every time when // the member variables are accessed, `invokespecial` will be called which is expensive. // This can be avoid by having a local reference of `shift`. val localShift = shift vector match { case DenseVector(vs) => val values = vs.clone() val size = values.size if (withStd) { var i = 0 while (i < size) { values(i) = if (std(i) != 0.0) (values(i) - localShift(i)) * (1.0 / std(i)) else 0.0 i += 1 } } else { var i = 0 while (i < size) { values(i) -= localShift(i) i += 1 } } Vectors.dense(values) case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass) } } else if (withStd) { vector match { case DenseVector(vs) => val values = vs.clone() val size = values.size var i = 0 while(i < size) { values(i) *= (if (std(i) != 0.0) 1.0 / std(i) else 0.0) i += 1 } Vectors.dense(values) case SparseVector(size, indices, vs) => // For sparse vector, the `index` array inside sparse vector object will not be changed, // so we can re-use it to save memory. val values = vs.clone() val nnz = values.size var i = 0 while (i < nnz) { values(i) *= (if (std(indices(i)) != 0.0) 1.0 / std(indices(i)) else 0.0) i += 1 } Vectors.sparse(size, indices, values) case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass) } } else { // Note that it's safe since we always assume that the data in RDD should be immutable. vector } }
標準化原理簡單,代碼也簡單,可是做用不能小看。優化
歡迎關注本人微信公衆號,會定時發送關於大數據、機器學習、Java、Linux 等技術的學習文章,並且是一個系列一個系列的發佈,無任何廣告,純屬我的興趣。
ui