機器學習:weka中Evaluation類源碼解析及輸出AUC及交叉驗證介紹

  在機器學習分類結果的評估中,ROC曲線下的面積AOC是一個很是重要的指標。下面是調用weka類,輸出AOC的源碼:git

try {
// 1.讀入數據集

                Instances data = new Instances(
                                      new BufferedReader(
                                        new FileReader("E:\\Develop/Weka-3-6/data/contact-lenses.arff")));

                data.setClassIndex(data.numAttributes() - 1);

// 2.訓練分類器並用十字交叉驗證法來得到Evaluation對象
// 注意這裏的方法與咱們在上幾節中使用的驗證法是不一樣。
                Classifier cl = new NaiveBayes();
                Evaluation eval = new Evaluation(data);
                eval.crossValidateModel(cl, data, 10, new Random(1));

         
// 3.生成用於獲得ROC曲面和AUC值的Instances對象
       System.out.println(eval.toClassDetailsString());
            System.out.println(eval.toSummaryString());
            System.out.println(eval.toMatrixString()); }
catch (Exception e) { e.printStackTrace(); }

 

  接着說一下交叉驗證;github

  若是沒有分開訓練集和測試集,能夠使用Cross Validation方法,Evaluation中crossValidateModel方法的四個參數分別爲,第一個是分類器,第二個是在某個數據集上評價的數據集,第三個參數是交叉檢驗的次數(10是比較常見的),第四個是一個隨機數對象。dom

  注意:使用crossValidateModel時,分類器不須要先訓練,不然buildClassifier方法會初始化分類器,交叉驗證的配置結果就沒有用了。機器學習

  類crossValidateModel的源碼以下:學習

 public void crossValidateModel(Classifier classifier, Instances data,
    int numFolds, Random random, Object... forPredictionsPrinting)
    throws Exception {

    // Make a copy of the data we can reorder
    data = new Instances(data);
    data.randomize(random);
    if (data.classAttribute().isNominal()) {
      data.stratify(numFolds);
    }

    // We assume that the first element is a StringBuffer, the second a Range
    // (attributes
    // to output) and the third a Boolean (whether or not to output a
    // distribution instead
    // of just a classification)
    if (forPredictionsPrinting.length > 0) {
      // print the header first
      StringBuffer buff = (StringBuffer) forPredictionsPrinting[0];
      Range attsToOutput = (Range) forPredictionsPrinting[1];
      boolean printDist = ((Boolean) forPredictionsPrinting[2]).booleanValue();
      printClassificationsHeader(data, attsToOutput, printDist, buff);
    }

    // Do the folds
    for (int i = 0; i < numFolds; i++) {
      Instances train = data.trainCV(numFolds, i, random);
      setPriors(train);
      Classifier copiedClassifier = Classifier.makeCopy(classifier);
      copiedClassifier.buildClassifier(train);
      Instances test = data.testCV(numFolds, i);
      evaluateModel(copiedClassifier, test, forPredictionsPrinting);
    }
    m_NumFolds = numFolds;
  }

 

輸出結果截圖:測試

更新中。。。ui

 

 

 

libsvm 下載地址 https://github.com/cjlin1/libsvmlua

    github地址   https://github.com/cjlin1/libsvmspa

相關文章
相關標籤/搜索