LibSVM文件轉換爲csv格式


Spark MLlib 的官方例子裏面提供的數據大部分是 libsvm 格式的。這實際上是一種很是蛋疼的文件格式,和常見的二維表格形式相去甚遠,下圖是裏面的一個例子:java


libsvm 文件的基本格式以下:python

<label> <index1>:<value1> <index2>:<value2>…git

label 爲類別標識,index 爲特徵序號,value 爲特徵取值。如上圖中第一行中 0 爲標籤,128:51 表示第 128 個特徵取值爲 51 。github

Spark 當然提供了讀取 libsvm 文件的API,然而若是想把這些數據放到別的庫 (好比scikit-learn) 中使用,就不得不面臨一個格式轉換的問題了。因爲 CSV 文件是廣大人民羣衆喜聞樂見的文件格式,所以分別用 Python 和Java 寫一個程序來進行轉換。我在網上查閱了一下,基本上全是 csv 轉 libsvm,不多有 libsvm 轉 csv 的,惟一的一個是 phraug庫中的libsvm2csv.py 。但這個實現有兩個缺點: 一個是須要事先指定維度; 另外一個是像上圖中的特徵序號是 128 - 658 ,這樣轉換完以後 0 - 127 維的特徵全爲 0,就顯得多餘了,而比較好的作法是將全爲 0 的特徵列一併去除。下面是基於 Python 的實現:app


import sys
import csv
import numpy as np

def empty_table(input_file):  # 創建空表格, 維數爲原數據集中最大特徵維數
    max_feature = 0
    count = 0
    with open(input_file, 'r', newline='') as f:
        reader = csv.reader(f, delimiter=" ")
        for line in reader:
            count += 1
            for i in line:
                num = int(i.split(":")[0])
                if num > max_feature:
                    max_feature = num
                    
    return np.zeros((count, max_feature + 1))

def write(input_file, output_file, table):
    with open(input_file, 'r', newline='') as f:
        reader = csv.reader(f, delimiter=" ")
        for c, line in enumerate(reader):
            label = line.pop(0)
            table[c, 0] = label
            if line[-1].strip() == '':
                line.pop(-1)

            line = map(lambda x : tuple(x.split(":")), line)
            for i, v in line:
                i = int(i)
                table[c, i] = v

    delete_col = []
    for col in range(table.shape[1]):
        if not any(table[:, col]):
            delete_col.append(col)
    
    table = np.delete(table, delete_col, axis=1)  # 刪除全 0 列
    with open(output_file, 'w') as f:
        writer = csv.writer(f)
        for line in table:
            writer.writerow(line)


if __name__ == "__main__":
    input_file = sys.argv[1]
    output_file = sys.argv[2]
    table = empty_table(input_file)
    write(input_file, output_file, table)


如下基於 Java 來實現,不得不說 Java 因爲沒有 Numpy 這類庫的存在,寫起來要繁瑣得多。spa

import java.io.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class LibsvmToCsv {
    public static void main(String[] args) throws IOException {

        String src = args[0];
        String dest = args[1];

        double[][] table = EmptyTable(src);
        double[][] newcsv = NewCsv(table, src);
        write(newcsv, dest);
    }

    // 創建空表格, 維數爲原數據集中最大特徵維數
    public static double[][] EmptyTable(String src) throws IOException {
        int maxFeatures = 0, count = 0;
        File f = new File(src);
        BufferedReader br = new BufferedReader(new FileReader(f));
        String temp = null;
        while ((temp = br.readLine()) != null){
            count++;
            for (String pair : temp.split(" ")){
                int num = Integer.parseInt(pair.split(":")[0]);
                if (num > maxFeatures){
                    maxFeatures = num;
                }
            }
        }
        double[][] emptyTable = new double[count][maxFeatures + 1];
        return emptyTable;
    }

    public static double[][] NewCsv(double[][] newTable, String src) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(new FileInputStream(src)));
        String temp = null;
        int count = 0;
        while ((temp = br.readLine()) != null){
            String[] array = temp.split(" ");
            double label = Integer.parseInt(array[0]);
            for (String pair : Arrays.copyOfRange(array, 1, array.length)){
                String[] pairs = pair.split(":");
                int index = Integer.parseInt(pairs[0]);
                double value = Double.parseDouble(pairs[1]);
                newTable[count][index] = value;
            }
            newTable[count][0] = label;
            count++;
        }

        List<Integer> deleteCol = new ArrayList<>();  // 要刪除的全 0 列
        int deleteColNum = 0;

        coll:
        for (int col = 0; col < newTable[0].length; col++){
            int zeroCount = 0;
            for (int row = 0; row < newTable.length; row++){
                if (newTable[row][col] != 0.0){
                    continue coll;  // 如有一個值不爲 0, 繼續判斷下一列
                } else {
                    zeroCount++;
                }
            }

            if (zeroCount == newTable.length){
                deleteCol.add(col);
                deleteColNum++;
            }
        }

        int newColNum =  newTable[0].length - deleteColNum;
        double[][] newCsv = new double[count][newColNum];  // 新的不帶全 0 列的空表格
        int newCol = 0;

        colll:
        for (int col = 0; col < newTable[0].length; col++){
            for (int dCol : deleteCol){
                if (col == dCol){
                    continue colll;
                }
            }

            for (int row = 0; row < newTable.length; row++){
                newCsv[row][newCol] = newTable[row][col];
            }
            newCol++;
        }
        return newCsv;
    }

    public static void write(double[][] table, String path) throws FileNotFoundException {
        BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(path)));
        try{
            for (double[] row : table){
                int countComma = 0;
                for (double c : row){
                    countComma ++;
                    bw.write(String.valueOf(c));
                    if (countComma <= row.length - 1){
                        bw.append(',');
                    }
                }
                bw.flush();
                bw.newLine();
            }
        } catch (IOException e){
            e.printStackTrace();
        } finally {
            try{
                if (bw != null){
                    bw.close();
                }
            } catch (IOException e){
                e.printStackTrace();
            }
        }
    }
}





/code

相關文章
相關標籤/搜索