代碼放在github上:click mejava
數據集爲英文語料集,一共包含20種類別的郵件,除了類別soc.religion.christian的郵件數爲997之外每一個類別的郵件數都是1000。每份郵件內部包含發送者,接受者,正文等信息。git
數據預處理階段採用了幾種方案進行測試github
直接將郵件內容按空格分詞shell
使用stanford corenlp進行分詞,而後使用停詞表過濾分詞結果windows
使用stanford corenlp進行分詞,並根據詞性和停詞表過濾分詞結果網絡
綜合上面三種方案,測試結果最好的是方案二的預處理方式。將全部的郵件預處理以後寫入一個文件中,文件每行對應一封郵件,形式如"類別\t按空格分隔的郵件分詞"maven
comp.os.ms-windows.misc I search Ms-Windows logo picture start Windows misc.forsale ITEMS for SALEI offer item I reserve the right refuse offer Howard Miller comp.sys.ibm.pc.hardware I hd bad suggest inadequate power supply how wattage
基於spark的pipeline構建端到端的分類模型oop
將數據預處理後獲得的文件上傳到hdfs上,spark從hdfs上讀取文本數據並轉換成DataFrame學習
爲DataFrame的郵件類別列創建索引,而後將DataFrame做爲Word2Vec的輸入獲取句子的向量表示測試
句子向量輸入到含有2層隱藏層的多層感知機(MLP)中進行分類學習
將預測結果的索引列轉換成可讀的郵件類別標籤
將數據集隨機劃分紅8:2,80%的數據做爲訓練集,20%的數據做爲測試集。通過合理的調參,在測試集上的accuracy和F1 score能夠達到90.5%左右,關鍵參數設置以下
// Word2Vec超參 final val W2V_MAX_ITER = 5 // Word2Vec迭代次數 final val EMBEDDING_SIZE = 128 // 詞向量長度 final val MIN_COUNT = 1 // default: 5, 詞彙表閾值即至少出現min_count次才放入詞彙表中 final val WINDOW_SIZE = 5 // default: 5, 上下文窗口大小[-WINDOW_SIZE,WINDOW_SIZE] // MLP超參 final val MLP_MAX_ITER = 300 // MLP迭代次數 final val SEED = 1234L // 隨機數種子,初始化網絡權重用 final val HIDDEN1_SIZE = 64 // 第一層隱藏層節點數 final val HIDDEN2_SIZE = 32 // 第二層隱藏層節點數 final val LABEL_SIZE = 20 // 輸出層節點數
郵件預測結果輸出在hdfs上,文件內容每行的de左邊是真實label,右邊是預測label
hadoop-2.7.5
spark-2.3.0
stanford corenlp 3.9.2
Maven項目文件結構以下
src/main/scala下爲源代碼,其中Segment.java和EnglishSegment.java用於英文分詞,DataPreprocess.scala基於分詞做數據預處理,MailClassifier.scala對應郵件分類模型。input下爲數據集,output下爲數據預處理結果MailCollection和預測結果prediction,target下爲maven打好的jar包Mail.jar以及運行腳本submit.sh,pom.xml爲maven配置。
將數據集20_newsgroup放在input目錄下,確保pom.xml中的依賴包都知足之後運行DataPreprocess獲得預處理的結果MailCollection輸出到output目錄下。啓動hadoop的hdfs,將MailCollection上傳到hdfs上以便spark讀取。而後啓動spark,命令行下進入到target路徑下運行./submit.sh提交任務,submit.sh內容以下
spark-submit --class MailClassifier --master spark://master:7077 --conf spark.driver.memory=10g --conf spark.executor.memory=4g --conf spark.executor.cores=2 --conf spark.kryoserializer.buffer=512m --conf spark.kryoserializer.buffer.max=1g Mail.jar input/MailCollection output
運行MailClassifier須要兩個命令行參數,其中input/MailCollection爲上傳到hdfs上的路徑名,output爲預測結果輸出到hdfs上的路徑名,提交任務前確保輸出路徑在hdfs上不存在,不然程序會刪除輸出輸出路徑以確保程序正確運行。