Spark MLlib 的官方例子裏面提供的數據大部分是 libsvm 格式的。這實際上是一種非常令人頭疼的文件格式,和常見的二維表格形式相去甚遠,下圖是裏面的一個例子:
libsvm 文件的基本格式如下:
<label> <index1>:<value1> <index2>:<value2>…
label 爲類別標識,index 爲特徵序號,value 爲特徵取值。如上圖中第一行,0 爲標籤,128:51 表示第 128 個特徵取值爲 51。
Spark 當然提供了讀取 libsvm 文件的 API,然而如果想把這些數據放到別的庫(比如 scikit-learn)中使用,就不得不面臨格式轉換的問題了。由於 CSV 文件是廣大人民羣衆喜聞樂見的文件格式,因此我分別用 Python 和 Java 寫了一個程序來進行轉換。我在網上查閱了一下,找到的基本上都是 csv 轉 libsvm,很少有 libsvm 轉 csv 的,唯一的一個是 phraug 庫中的 libsvm2csv.py。但這個實現有兩個缺點:一是需要事先指定維度;二是像上圖中的特徵序號是 128 - 658,這樣轉換完以後 0 - 127 維的特徵全爲 0,就顯得多餘了,比較好的做法是將全爲 0 的特徵列一併去除。下面是基於 Python 的實現:
"""Convert a libsvm file into a dense CSV file.

Each input line has the form ``<label> <index1>:<value1> <index2>:<value2> ...``.
In the output, column 0 holds the label and column ``i`` holds feature ``i``.
Feature columns that are zero in every row are dropped; the label column is
always kept.

Usage: python libsvm2csv.py <input.libsvm> <output.csv>
"""
import sys
import csv
import numpy as np


def _records(input_file):
    """Yield the token list of every non-blank line of ``input_file``.

    ``str.split()`` with no argument splits on runs of whitespace and ignores
    leading/trailing blanks, so ``"0 128:51 "`` yields ``['0', '128:51']``
    instead of a trailing empty token.
    """
    with open(input_file, 'r') as f:
        for raw in f:
            tokens = raw.split()
            if tokens:  # skip blank lines so row counts stay consistent
                yield tokens


def empty_table(input_file):
    """Create a zero-filled table sized for ``input_file``.

    Returns an ndarray of shape ``(rows, max_index + 1)``, where ``rows`` is the
    number of non-blank lines and ``max_index`` the largest feature index.
    Column 0 is reserved for the label.
    """
    max_feature = 0
    count = 0
    for tokens in _records(input_file):
        count += 1
        # tokens[0] is the label, not a feature index; it may be real-valued
        # (e.g. regression data), so it must not go through int().
        for pair in tokens[1:]:
            num = int(pair.split(":")[0])
            if num > max_feature:
                max_feature = num
    return np.zeros((count, max_feature + 1))


def write(input_file, output_file, table):
    """Fill ``table`` from ``input_file`` and write it to ``output_file`` as CSV.

    ``table`` should come from :func:`empty_table` for the same file. Feature
    columns that are zero in every row are removed before writing; column 0
    (the label) is always kept, even when every label is 0.
    """
    for row, tokens in enumerate(_records(input_file)):
        table[row, 0] = float(tokens[0])
        for pair in tokens[1:]:
            index, value = pair.split(":")
            table[row, int(index)] = float(value)
    delete_col = [col for col in range(1, table.shape[1])
                  if not table[:, col].any()]
    table = np.delete(table, delete_col, axis=1)  # drop all-zero feature columns
    # newline='' is required by the csv module; otherwise Windows gets blank lines.
    with open(output_file, 'w', newline='') as f:
        csv.writer(f).writerows(table.tolist())


if __name__ == "__main__":
    if len(sys.argv) < 3:
        sys.exit("usage: %s <input.libsvm> <output.csv>" % sys.argv[0])
    input_file = sys.argv[1]
    output_file = sys.argv[2]
    table = empty_table(input_file)
    write(input_file, output_file, table)
以下基於 Java 來實現。不得不說,Java 由於沒有 NumPy 這類庫,寫起來要繁瑣得多。
import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.regex.Pattern;

/**
 * Converts a libsvm file into a dense CSV file.
 *
 * <p>Each input line has the form {@code <label> <index1>:<value1> <index2>:<value2> ...}.
 * In the output, column 0 holds the label and column {@code i} holds feature {@code i}.
 * Feature columns that are zero in every row are dropped; the label column is always kept.
 *
 * <p>Usage: {@code java LibsvmToCsv <input.libsvm> <output.csv>}
 */
public class LibsvmToCsv {

    /** Separator between tokens: any run of whitespace (tolerates double/trailing spaces). */
    private static final Pattern WHITESPACE = Pattern.compile("\\s+");

    public static void main(String[] args) throws IOException {
        if (args.length < 2) {
            System.err.println("Usage: java LibsvmToCsv <input.libsvm> <output.csv>");
            System.exit(2);
        }
        String src = args[0];
        String dest = args[1];
        double[][] table = EmptyTable(src);
        double[][] newcsv = NewCsv(table, src);
        write(newcsv, dest);
    }

    /**
     * Creates a zero-filled table sized for the libsvm file {@code src}.
     *
     * @param src path of the libsvm file
     * @return a {@code rows x (maxIndex + 1)} array, where {@code rows} is the number of
     *     non-blank lines and {@code maxIndex} the largest feature index; column 0 is
     *     reserved for the label
     * @throws IOException if the file cannot be read
     * @throws IllegalArgumentException if a feature token is not an {@code index:value} pair
     */
    public static double[][] EmptyTable(String src) throws IOException {
        int maxFeatures = 0;
        int count = 0;
        try (BufferedReader br = newReader(src)) {
            String line;
            while ((line = br.readLine()) != null) {
                String[] fields = tokens(line);
                if (fields.length == 0) {
                    continue; // blank line: not a record
                }
                count++;
                // fields[0] is the label, not a feature index; it may be real-valued.
                for (int k = 1; k < fields.length; k++) {
                    String pair = fields[k];
                    int index = Integer.parseInt(pair.substring(0, colonIndex(pair)));
                    maxFeatures = Math.max(maxFeatures, index);
                }
            }
        }
        return new double[count][maxFeatures + 1];
    }

    /**
     * Fills {@code newTable} from the libsvm file {@code src} and returns a copy of it
     * without the feature columns that are zero in every row.
     *
     * @param newTable table produced by {@link #EmptyTable(String)} for the same file;
     *     it is filled in place
     * @param src path of the libsvm file
     * @return the filled table without all-zero feature columns (column 0, the label,
     *     is always kept)
     * @throws IOException if the file cannot be read
     * @throws IllegalArgumentException if a label, index or value cannot be parsed
     */
    public static double[][] NewCsv(double[][] newTable, String src) throws IOException {
        int row = 0;
        try (BufferedReader br = newReader(src)) {
            String line;
            while ((line = br.readLine()) != null) {
                String[] fields = tokens(line);
                if (fields.length == 0) {
                    continue; // blank line: not a record (matches EmptyTable's row count)
                }
                for (int k = 1; k < fields.length; k++) {
                    String pair = fields[k];
                    int colon = colonIndex(pair);
                    int index = Integer.parseInt(pair.substring(0, colon));
                    newTable[row][index] = Double.parseDouble(pair.substring(colon + 1));
                }
                // Labels may be real-valued (regression data), so parse as double.
                // Written after the features, so the label wins over a feature index 0.
                newTable[row][0] = Double.parseDouble(fields[0]);
                row++;
            }
        }
        return dropZeroFeatureColumns(newTable);
    }

    /**
     * Writes {@code table} to {@code path} as comma-separated rows, one row per line.
     *
     * @param table rows to write
     * @param path destination file; it is created or truncated
     * @throws FileNotFoundException if the file cannot be opened for writing
     * @throws UncheckedIOException if writing or closing the file fails
     */
    public static void write(double[][] table, String path) throws FileNotFoundException {
        try (BufferedWriter bw = new BufferedWriter(
                new OutputStreamWriter(new FileOutputStream(path), StandardCharsets.UTF_8))) {
            for (double[] row : table) {
                for (int k = 0; k < row.length; k++) {
                    if (k > 0) {
                        bw.write(',');
                    }
                    bw.write(String.valueOf(row[k]));
                }
                bw.newLine();
            }
        } catch (FileNotFoundException e) {
            throw e;
        } catch (IOException e) {
            // Do not swallow: a failed write would otherwise leave a silently truncated file.
            throw new UncheckedIOException("Failed to write CSV to " + path, e);
        }
    }

    /** Opens {@code src} as a UTF-8 text reader. */
    private static BufferedReader newReader(String src) throws IOException {
        return new BufferedReader(
                new InputStreamReader(new FileInputStream(src), StandardCharsets.UTF_8));
    }

    /** Splits a line into whitespace-separated tokens; a blank line yields an empty array. */
    private static String[] tokens(String line) {
        String trimmed = line.trim();
        return trimmed.isEmpty() ? new String[0] : WHITESPACE.split(trimmed);
    }

    /** Returns the position of ':' in an {@code index:value} token, rejecting tokens without one. */
    private static int colonIndex(String pair) {
        int colon = pair.indexOf(':');
        if (colon < 0) {
            throw new IllegalArgumentException("Malformed index:value pair: \"" + pair + "\"");
        }
        return colon;
    }

    /** Returns a copy of {@code table} without feature columns that are zero in every row. */
    private static double[][] dropZeroFeatureColumns(double[][] table) {
        if (table.length == 0) {
            return new double[0][0];
        }
        int cols = table[0].length;
        boolean[] keep = new boolean[cols];
        int kept = 0;
        for (int col = 0; col < cols; col++) {
            // Column 0 is the label: keep it even when every label is 0.
            keep[col] = col == 0 || hasNonZero(table, col);
            if (keep[col]) {
                kept++;
            }
        }
        double[][] out = new double[table.length][kept];
        for (int r = 0; r < table.length; r++) {
            int dst = 0;
            for (int col = 0; col < cols; col++) {
                if (keep[col]) {
                    out[r][dst++] = table[r][col];
                }
            }
        }
        return out;
    }

    /** Returns whether any row has a non-zero value in column {@code col}. */
    private static boolean hasNonZero(double[][] table, int col) {
        for (double[] row : table) {
            if (row[col] != 0.0) {
                return true;
            }
        }
        return false;
    }
}