Training Chinese Word Vectors with DL4J
[TOC]
1 Preprocessing
Preprocessing a Chinese corpus mainly involves word segmentation, stop-word removal, and a few rules tailored to the actual use case.
```java
package ai.mole.test;

import org.ansj.domain.Term;
import org.ansj.splitWord.analysis.ToAnalysis;
import org.nlpcn.commons.lang.tire.domain.Forest;
import org.nlpcn.commons.lang.tire.library.Library;
import org.nlpcn.commons.lang.util.IOUtil; // assumed import: IOUtil from the nlp-lang library that ships with ansj

import java.io.*;
import java.util.LinkedList;
import java.util.List;
import java.util.regex.Pattern;

public class Preprocess {
    private static final Pattern NUMERIC_PATTERN = Pattern.compile("^[.\\d]+$");
    private static final Pattern ENGLISH_WORD_PATTERN = Pattern.compile("^[a-z]+$");

    public static void main(String[] args) {
        String inPath1 = "D:\\MyData\\XUGP3\\Desktop\\測試分詞\\test1.txt";
        String inPath2 = "D:\\MyData\\XUGP3\\Desktop\\測試分詞\\stop_words.txt";
        String outPath = "D:\\MyData\\XUGP3\\Desktop\\測試分詞\\result1.txt";
        String encoding = "utf-8";

        PrintWriter writer = null;
        try {
            writer = new PrintWriter(new OutputStreamWriter(new FileOutputStream(outPath), encoding));
            // Load the ansj user dictionary (the original referenced a non-existent Test class here)
            Forest forest = Library.makeForest(Preprocess.class.getResourceAsStream("/library/userLibrary.dic"));

            List<String> lineList = IOUtil.readLines(new FileInputStream(inPath1), encoding);
            List<String> stopWordList = IOUtil.readLines(new FileInputStream(inPath2), encoding);

            for (String line : lineList) {
                String[] cols = line.split("\\t", -1);
                if (cols.length < 2) {
                    continue;
                }
                String text = cols[0].trim().toLowerCase() + " " + cols[1].trim().toLowerCase();

                // Word segmentation with the user dictionary
                List<Term> termList = ToAnalysis.parse(text, forest).getTerms();
                List<String> wordList = new LinkedList<>();
                for (Term term : termList) {
                    String word = term.getName();
                    // Drop single characters, stop words, pure numbers, and pure English tokens
                    if (word.length() < 2 || stopWordList.contains(word)
                            || isNumeric(word) || isEnglishWord(word)) {
                        continue;
                    }
                    wordList.add(word);
                }
                // Keep only documents with more than 5 remaining tokens
                if (wordList.size() > 5) {
                    writer.println(listToLine(wordList));
                }
            }
        } catch (FileNotFoundException e) {
            System.out.println("The file does not exist or the path is not correct!!!");
            System.exit(-1);
        } catch (UnsupportedEncodingException e) {
            System.out.println("Does not support the current character set!!!");
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            if (writer != null) {
                writer.close();
            }
        }
    }

    private static boolean isNumeric(String text) {
        return NUMERIC_PATTERN.matcher(text).matches();
    }

    private static boolean isEnglishWord(String text) {
        return ENGLISH_WORD_PATTERN.matcher(text).matches();
    }

    private static String listToLine(List<String> list) {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < list.size(); i++) {
            sb.append(list.get(i));
            if (i != list.size() - 1) {
                sb.append(" ");
            }
        }
        return sb.toString();
    }
}
```
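The output file ends up with one preprocessed document per line, tokens separated by single spaces (the format `listToLine` produces), e.g. a hypothetical line such as `比亞迪 新能源 汽車 銷量 增長 明顯`. This one-sentence-per-line layout is exactly what `BasicLineIterator` consumes in the training step below.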
2 Training
The training code is very simple; you can follow the tutorial on the official DL4J site directly. For the theory behind word2vec, see 皮提果's blog posts.
```java
package ai.mole.test;

import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.word2vec.Word2Vec;
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.io.IOException;
import java.util.Collection;

public class TrainWord2VecModel {
    private static final Logger log = LoggerFactory.getLogger(TrainWord2VecModel.class);

    public static void main(String[] args) throws IOException {
        String corpusPath = "/data/analyze/xgp/words.txt";
        String vectorsPath = "/data/analyze/xgp/word_vectors.txt";

        log.info("Start Training...");
        long st = System.currentTimeMillis();

        log.info("Load & vectorize sentences...");
        // One preprocessed document per line; tokens are already space-separated
        SentenceIterator iter = new BasicLineIterator(new File(corpusPath));
        TokenizerFactory t = new DefaultTokenizerFactory();
        // The corpus is already lower-cased and cleaned, so no token pre-processor is needed
        // t.setTokenPreProcessor(new CommonPreprocessor());

        log.info("Building model...");
        Word2Vec vec = new Word2Vec.Builder()
                .minWordFrequency(50)   // ignore words occurring fewer than 50 times
                .iterations(1)          // iterations per minibatch
                .epochs(100)            // passes over the whole corpus
                .layerSize(500)         // dimensionality of the word vectors
                .seed(42)
                .windowSize(5)
                .iterate(iter)
                .tokenizerFactory(t)
                .build();

        log.info("Fitting word2vec model...");
        vec.fit();

        log.info("Writing word vectors to text file...");
        // WordVectorSerializer.writeWord2VecModel(vec, vectorsPath);
        WordVectorSerializer.writeWordVectors(vec, vectorsPath);

        log.info("Closest words:");
        Collection<String> bydWordList = vec.wordsNearest("比亞迪", 10);
        Collection<String> changanWordList = vec.wordsNearest("長安", 10);
        log.info("10 words closest to '比亞迪': {}", bydWordList);
        log.info("10 words closest to '長安': {}", changanWordList);

        long et = System.currentTimeMillis();
        log.info("Training is completed, and the time taken is " + (et - st) + " ms.");
    }
}
```
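Note that `writeWordVectors` saves only the vectors in plain-text format. If you intend to reload the model with `readWord2VecModel` as in the next section, the paired `writeWord2VecModel` call (commented out above) is the more robust choice; both methods exist on DL4J's `WordVectorSerializer`. A minimal sketch, assuming the same `vec` and `vectorsPath` as above:

```java
// Persist the full model so it can be reloaded with readWord2VecModel later
WordVectorSerializer.writeWord2VecModel(vec, vectorsPath);

// Reload it elsewhere (same file written above)
Word2Vec restored = WordVectorSerializer.readWord2VecModel(vectorsPath);
```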
3 Usage
Using the trained word vectors is just as simple: call the static method `readWord2VecModel` of the `WordVectorSerializer` class, passing the path of the trained vector file as the input argument.
```java
// Load the trained vectors and query the nearest neighbours of two words
Word2Vec word2Vec = WordVectorSerializer.readWord2VecModel("D:\\MyData\\XUGP3\\Desktop\\測試分詞\\vectors.txt");
Collection<String> bydWordList = word2Vec.wordsNearest("比亞迪", 10);
Collection<String> changanWordList = word2Vec.wordsNearest("長安", 10);
System.out.println(bydWordList);
System.out.println(changanWordList);
```
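Besides `wordsNearest`, the loaded model exposes a few other handy lookups. A minimal sketch using calls available on DL4J's `Word2Vec`:

```java
// Cosine similarity between two words (closer to 1.0 = more similar)
double sim = word2Vec.similarity("比亞迪", "長安");
System.out.println("similarity(比亞迪, 長安) = " + sim);

// Raw word vector (its length equals the layerSize used in training, 500 here)
double[] vector = word2Vec.getWordVector("比亞迪");
System.out.println("vector length = " + vector.length);
```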
Appendix - Maven dependencies
```xml
<dependencies>
    <dependency>
        <groupId>org.apdplat</groupId>
        <artifactId>word</artifactId>
        <version>1.3</version>
    </dependency>
    <!-- ND4J backend. You need one in every DL4J project. Normally define artifactId as
         either "nd4j-native-platform" or "nd4j-cuda-7.5-platform" -->
    <dependency>
        <groupId>org.nd4j</groupId>
        <artifactId>${nd4j.backend}</artifactId>
        <version>${nd4j.version}</version>
    </dependency>
    <!-- Core DL4J functionality -->
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-core</artifactId>
        <version>${dl4j.version}</version>
    </dependency>
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-nlp</artifactId>
        <version>${dl4j.version}</version>
    </dependency>
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-zoo</artifactId>
        <version>${dl4j.version}</version>
    </dependency>
    <!-- deeplearning4j-ui is used for visualization: see http://deeplearning4j.org/visualization -->
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-ui_${scala.binary.version}</artifactId>
        <version>${dl4j.version}</version>
    </dependency>
    <!-- ParallelWrapper & ParallelInference live here -->
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-parallel-wrapper_${scala.binary.version}</artifactId>
        <version>${dl4j.version}</version>
    </dependency>
    <!-- Next 2: used for MapFileConversion Example. Note you need *both* together -->
    <dependency>
        <groupId>org.datavec</groupId>
        <artifactId>datavec-hadoop</artifactId>
        <version>${datavec.version}</version>
    </dependency>
    <dependency>
        <groupId>org.apache.hadoop</groupId>
        <artifactId>hadoop-common</artifactId>
        <version>${hadoop.version}</version>
    </dependency>
    <!-- Arbiter - used for hyperparameter optimization (grid/random search) -->
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>arbiter-deeplearning4j</artifactId>
        <version>${arbiter.version}</version>
    </dependency>
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>arbiter-ui_2.11</artifactId>
        <version>${arbiter.version}</version>
    </dependency>
    <!-- datavec-data-codec: used only in video example for loading video data -->
    <dependency>
        <groupId>org.datavec</groupId>
        <artifactId>datavec-data-codec</artifactId>
        <version>${datavec.version}</version>
    </dependency>
</dependencies>
```
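The POM above references version properties (`${dl4j.version}`, `${nd4j.version}`, and so on) that must be defined in the project's `<properties>` section. A minimal sketch; the version numbers shown are assumptions from roughly this era of DL4J and should be replaced with whatever matches your setup:

```xml
<properties>
    <!-- Assumed versions; adjust to your environment -->
    <dl4j.version>0.9.1</dl4j.version>
    <nd4j.version>0.9.1</nd4j.version>
    <datavec.version>0.9.1</datavec.version>
    <arbiter.version>0.9.1</arbiter.version>
    <nd4j.backend>nd4j-native-platform</nd4j.backend>
    <scala.binary.version>2.11</scala.binary.version>
    <hadoop.version>2.7.3</hadoop.version>
</properties>
```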