在Ignite中使用k-均值聚類算法

在本系列前面的文章中,簡單介紹了一下Ignite的k-最近鄰(k-NN)分類算法,下面會嘗試另外一個機器學習算法,即便用泰坦尼克數據集介紹k-均值聚類算法。正好,Kaggle提供了CSV格式的數據集,而要分析的是兩個分類:即乘客是否倖存。java

爲了將數據轉換爲Ignite支持的格式,前期須要作一些清理和格式化的工做,CSV文件中包含若干個列,以下:算法

  • 乘客Id
  • 倖存(0:否,1:是)
  • 船票席別(1:一,2:二,3:三)
  • 乘客姓名
  • 性別
  • 年齡
  • 泰坦尼克號上的兄弟/姐妹數
  • 泰坦尼克號上的父母/子女數
  • 船票號碼
  • 票價
  • 客艙號碼
  • 登船港口(C=瑟堡,Q=皇后鎮,S=南安普頓)

所以首先要作的是,刪除任何和特定乘客有關的、和生存無關的列,以下:app

  • 乘客Id
  • 乘客姓名
  • 船票號碼
  • 客艙號碼

接下來會刪除任何數據有缺失的行,好比年齡或者登船港口,能夠對這些值進行歸類,可是爲了進行初步的分析,會刪除缺失的值。機器學習

最後會將部分字段轉換爲數值類型,好比性別會被轉換爲:ide

  • 0:女
  • 1:男

登船港口會被轉換爲:學習

  • 0:Q(皇后鎮)
  • 1:C(瑟堡)
  • 2:S(南安普頓)

最終的數據集由以下的列組成:測試

  • 船票席別
  • 性別
  • 年齡
  • 泰坦尼克號上的兄弟/姐妹數
  • 泰坦尼克號上的父母/子女數
  • 票價
  • 登船港口
  • 倖存

能夠看到,倖存列已被移到最後。idea

下一步會將數據拆分爲訓練數據(80%)和測試數據(20%),和前文同樣,仍是使用Scikit-learn來執行這個拆分任務。.net

準備好訓練和測試數據後,就能夠編寫應用了,本文的算法是:code

  1. 讀取訓練數據和測試數據;
  2. 在Ignite中保存訓練數據和測試數據;
  3. 使用訓練數據擬合k-均值聚類模型;
  4. 將模型應用於測試數據;
  5. 肯定含混矩陣和模型的準確性。

讀取訓練數據和測試數據

經過下面的代碼,能夠從CSV文件中讀取數據:

private static void loadData(String fileName, IgniteCache<Integer, TitanicObservation> cache)
        throws FileNotFoundException {

   Scanner scanner = new Scanner(new File(fileName));

   int cnt = 0;
   while (scanner.hasNextLine()) {
      String row = scanner.nextLine();
      String[] cells = row.split(",");
      double[] features = new double[cells.length - 1];

      for (int i = 0; i < cells.length - 1; i++)
         features[i] = Double.valueOf(cells[i]);
      double survivedClass = Double.valueOf(cells[cells.length - 1]);

      cache.put(cnt++, new TitanicObservation(features, survivedClass));
   }
}

該代碼簡單地一行行的讀取數據,而後對於每一行,使用CSV的分隔符拆分出字段,每一個字段以後將轉換成double類型而且存入Ignite。

將訓練數據和測試數據存入Ignite

前面的代碼將數據存入Ignite,要使用這個代碼,首先要建立Ignite存儲,以下:

IgniteCache<Integer, TitanicObservation> trainData = getCache(ignite, "TITANIC_TRAIN");
IgniteCache<Integer, TitanicObservation> testData = getCache(ignite, "TITANIC_TEST");
loadData("src/main/resources/titanic-train.csv", trainData);
loadData("src/main/resources/titanic-test.csv", testData);

getCache()的實現以下:

private static IgniteCache<Integer, TitanicObservation> getCache(Ignite ignite, String cacheName) {

   CacheConfiguration<Integer, TitanicObservation> cacheConfiguration = new CacheConfiguration<>();
   cacheConfiguration.setName(cacheName);
   cacheConfiguration.setAffinity(new RendezvousAffinityFunction(false, 10));

   IgniteCache<Integer, TitanicObservation> cache = ignite.createCache(cacheConfiguration);

   return cache;
}

使用訓練數據擬合k-NN分類模型

數據存儲以後,能夠像下面這樣建立訓練器:

KMeansTrainer trainer = new KMeansTrainer()
        .withK(2)
        .withDistance(new EuclideanDistance())
        .withSeed(123L);

這裏k的值配置爲2,表示有2個簇(倖存和未倖存),對於距離測量,能夠有多個選擇,好比歐幾里得、海明或曼哈頓,在本例中會使用歐幾里得,另外,種子值賦值爲123。

而後擬合訓練數據,以下:

KMeansModel mdl = trainer.fit(
        ignite,
        trainData,
        (k, v) -> v.getFeatures(),        
// Feature extractor.

        (k, v) -> v.getSurvivedClass()    
// Label extractor.

);

Ignite將數據保存爲鍵-值(K-V)格式,所以上面的代碼使用了值部分,目標值是Survived類,特徵在其它列中。

將模型應用於測試數據

下一步,就能夠用訓練好的分類模型測試測試數據了,能夠這樣作:

int amountOfErrors = 0;
int totalAmount = 0;
int[][] confusionMtx = {{0, 0}, {0, 0}};

try (QueryCursor<Cache.Entry<Integer, TitanicObservation>> cursor = testData.query(new ScanQuery<>())) {
   for (Cache.Entry<Integer, TitanicObservation> testEntry : cursor) {
      TitanicObservation observation = testEntry.getValue();

      double groundTruth = observation.getSurvivedClass();
      double prediction = mdl.apply(new DenseLocalOnHeapVector(observation.getFeatures()));

      totalAmount++;
      if ((int) groundTruth != (int) prediction)
         amountOfErrors++;

      int idx1 = (int) prediction;
      int idx2 = (int) groundTruth;

      confusionMtx[idx1][idx2]++;

      System.out.printf(">>> | %.4f\t | %.0f\t\t\t|\n", prediction, groundTruth);
   }
}

肯定含混矩陣和模型的準確性

下面,就能夠經過對測試數據中的真實分類和模型進行的分類進行對比,來確認模型的真確性。

代碼運行以後,輸出以下:

>>> Absolute amount of errors 56

>>> Accuracy 0.6084

>>> Precision 0.5865

>>> Recall 0.9873

>>> Confusion matrix is [[78, 55], [1, 9]]

這個初步的結果可不能夠改進?能夠嘗試的是對特徵的衡量,在Ignite和Scikit-learn中,可使用MinMaxScaler(),而後會給出以下的輸出:

>>> Absolute amount of errors 29

>>> Accuracy 0.7972

>>> Precision 0.8205

>>> Recall 0.8101

>>> Confusion matrix is [[64, 14], [15, 50]]

做爲進一步分析的一部分,還應該研究倖存與否和年齡和性別之間的關係。

總結

一般來講,k-均值聚類並不適合監督學習任務,可是若是分類很容易,這個方法仍是有效的。對於本例來講,關注的就是是否倖存。

相關文章
相關標籤/搜索