做者:LogMgit
本文原載於 https://segmentfault.com/u/logm/articles ,不容許轉載~github
FastText 源碼:https://github.com/facebookre...segmentfault
本文對應的源碼版本:Commits on Jun 27 2019, 979d8a9ac99c731d653843890c2364ade0f7d9d3
數組
FastText 論文:函數
[1] P. Bojanowski, E. Grave, A. Joulin, T. Mikolov, Enriching Word Vectors with Subword Informationui
[2] A. Joulin, E. Grave, P. Bojanowski, T. Mikolov, Bag of Tricks for Efficient Text Classificationcode
FastText 的論文寫的比較簡單,有些細節不明白,網上也查不到,所幸直接撕源碼。orm
FastText 的"分類器"功能是用的最多的,因此先從"分類器的predict"開始挖。token
先看程序入口的 main
函數,ok,是調用了 predict
函數。ip
// 文件:src/main.cc // 行數:403 int main(int argc, char** argv) { std::vector<std::string> args(argv, argv + argc); if (args.size() < 2) { printUsage(); exit(EXIT_FAILURE); } std::string command(args[1]); if (command == "skipgram" || command == "cbow" || command == "supervised") { train(args); } else if (command == "test" || command == "test-label") { test(args); } else if (command == "quantize") { quantize(args); } else if (command == "print-word-vectors") { printWordVectors(args); } else if (command == "print-sentence-vectors") { printSentenceVectors(args); } else if (command == "print-ngrams") { printNgrams(args); } else if (command == "nn") { nn(args); } else if (command == "analogies") { analogies(args); } else if (command == "predict" || command == "predict-prob") { predict(args); // 這句是咱們想要的 } else if (command == "dump") { dump(args); } else { printUsage(); exit(EXIT_FAILURE); } return 0; }
再看 predict
函數,預處理的代碼不用管,直接看 predict 的那行,調用了 FastText::predictLine
。這裏注意下,這是個 while
循環,因此FastText::predictLine
這個函數每次只處理一行。
// 文件:src/main.cc // 行數:205 void predict(const std::vector<std::string>& args) { if (args.size() < 4 || args.size() > 6) { printPredictUsage(); exit(EXIT_FAILURE); } int32_t k = 1; real threshold = 0.0; if (args.size() > 4) { k = std::stoi(args[4]); if (args.size() == 6) { threshold = std::stof(args[5]); } } bool printProb = args[1] == "predict-prob"; FastText fasttext; fasttext.loadModel(std::string(args[2])); std::ifstream ifs; std::string infile(args[3]); bool inputIsStdIn = infile == "-"; if (!inputIsStdIn) { ifs.open(infile); if (!inputIsStdIn && !ifs.is_open()) { std::cerr << "Input file cannot be opened!" << std::endl; exit(EXIT_FAILURE); } } std::istream& in = inputIsStdIn ? std::cin : ifs; std::vector<std::pair<real, std::string>> predictions; while (fasttext.predictLine(in, predictions, k, threshold)) { // 這句是重點 printPredictions(predictions, printProb, false); } if (ifs.is_open()) { ifs.close(); } exit(0); }
再看 FastText::predictLine
,注意這邊有兩個重點。
// 文件:src/fasttext.cc // 行數:451 bool FastText::predictLine( std::istream& in, std::vector<std::pair<real, std::string>>& predictions, int32_t k, real threshold) const { predictions.clear(); if (in.peek() == EOF) { return false; } std::vector<int32_t> words, labels; dict_->getLine(in, words, labels); // 這句是第一個重點 Predictions linePredictions; predict(k, words, linePredictions, threshold); // 這句是第二個重點 for (const auto& p : linePredictions) { predictions.push_back( std::make_pair(std::exp(p.first), dict_->getLabel(p.second))); } return true; }
先看第一個重點,getLine
函數實際上是 Dictionary::getLine
,定義在src/dictionary.cc
。
這段代碼的乾貨度仍是很高的,裏面有兩個重點,Dictionary::addSubwords
和 Dictionary::addWordNgrams
,之後會講。這邊只要知道整個函數把讀到的這一行的每一個Id(包括詞語的id,SubWords的Id,WordNgram的Id),存到了數組 words
中。
// 文件:src/dictionary.cc // 行數:378 int32_t Dictionary::getLine( std::istream& in, std::vector<int32_t>& words, std::vector<int32_t>& labels) const { std::vector<int32_t> word_hashes; std::string token; int32_t ntokens = 0; reset(in); words.clear(); labels.clear(); while (readWord(in, token)) { // `token` 是讀到的一個詞語,若是讀到一行的行尾,則返回`EOF` uint32_t h = hash(token); // 找到這個詞語位於哪一個hash桶 int32_t wid = getId(token, h); // 在hash桶中找到這個詞語的Id,若是負數就是沒找到對應的Id entry_type type = wid < 0 ? getType(token) : getType(wid); // 若是沒找到對應Id,則有多是label,`getType`裏會處理 ntokens++; if (type == entry_type::word) { addSubwords(words, token, wid); // 重點1,之後會講 word_hashes.push_back(h); } else if (type == entry_type::label && wid >= 0) { labels.push_back(wid - nwords_); } if (token == EOS) { break; } } addWordNgrams(words, word_hashes, args_->wordNgrams); // 重點2,之後會講 return ntokens; }
再來看第二個重點, FastText::predict
函數,重點是 Model::predict
函數。
// 文件:src/fasttext.cc // 行數:437 void FastText::predict( int32_t k, const std::vector<int32_t>& words, Predictions& predictions, real threshold) const { if (words.empty()) { return; } Model::State state(args_->dim, dict_->nlabels(), 0); if (args_->model != model_name::sup) { throw std::invalid_argument("Model needs to be supervised for prediction!"); } model_->predict(words, k, threshold, predictions, state); // 這句是重點 }
來到 Model::predict
,有兩個重點.
其中 Loss::predict
是將 hidden 層的輸出結果進行 softmax 後獲得最終機率最大的k個類別,"分類器的predict" 用的是經典的softmax,因此代碼也比較簡單。而若是是"分類器的train" 則涉及到 Hierarchical SoftmaxLoss
和 NegativeSamplingLoss
等一些加速手段,比較複雜,之後有機會再講。
// 文件:src/model.cc // 行數:53 void Model::predict( const std::vector<int32_t>& input, int32_t k, real threshold, Predictions& heap, State& state) const { if (k == Model::kUnlimitedPredictions) { k = wo_->size(0); // output size } else if (k <= 0) { throw std::invalid_argument("k needs to be 1 or higher!"); } heap.reserve(k + 1); computeHidden(input, state); // 重點1 loss_->predict(k, threshold, heap, state); // 重點2,之後再講 }
咱們再來看另外一個重點,Model::computeHidden
函數。
Model::computeHidden
函數理解起來比較簡單,注意這裏的 input
就是前面的 words
,是一系列id組成的數組(包括詞語的id,SubWords的Id,WordNgram的Id),把這些求和,而後取平均。
固然有些小夥伴可能有點疑問,Vector::addRow
爲何是求和,這個之後再講吧。
// 文件:src/model.cc // 行數:43 void Model::computeHidden(const std::vector<int32_t>& input, State& state) const { Vector& hidden = state.hidden; hidden.zero(); for (auto it = input.cbegin(); it != input.cend(); ++it) { hidden.addRow(*wi_, *it); // 求和 } hidden.mul(1.0 / input.size()); // 而後取平均 }
至此,FastText裏面的"分類器的predict"的大體流程講完了,其餘的,如"分類器的train"和"詞向量"的源碼也是相似的方法來閱讀。
這裏面有幾段代碼沒有詳細敘述:Dictionary::addSubwords
、Dictionary::addWordNgrams
、Vector::addRow
以及訓練時softmax的加速
,先把坑留着,之後有時間再填。