chainbin/nnet3-chain-train.cc
int main(int argc, char *argv[]) {
  ...
  Nnet nnet;
  ReadKaldiObject(nnet_rxfilename, &nnet);
  bool ok;
  {
    fst::StdVectorFst den_fst;
    ReadFstKaldi(den_fst_rxfilename, &den_fst);
    // NnetChainTrainer takes the training options opts, the denominator FST
    // den_fst and the neural network nnet.
    NnetChainTrainer trainer(opts, den_fst, &nnet);
    // SequentialNnetChainExampleReader reads the examples one utterance
    // (example) at a time.
    SequentialNnetChainExampleReader example_reader(examples_rspecifier);
    for (; !example_reader.Done(); example_reader.Next())
      // train on one example at a time
      trainer.Train(example_reader.Value());
    ok = trainer.PrintTotalStats();
  }
  ...
  WriteKaldiObject(nnet, nnet_wxfilename, binary_write);
  ...
}
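For orientation, an illustrative invocation of this binary (the file names are placeholders; the options actually used come from the chain training scripts):

nnet3-chain-train --use-gpu=yes 0.raw den.fst "ark:cegs.1.ark" 1.raw

The four positional arguments correspond to nnet_rxfilename, den_fst_rxfilename, examples_rspecifier and nnet_wxfilename in the code above.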
nnet3/nnet-chain-training.cc
void NnetChainTrainer::Train(const NnetChainExample &chain_eg) {
  bool need_model_derivative = true;
  const NnetTrainerOptions &nnet_config = opts_.nnet_config;
  bool use_xent_regularization = (opts_.chain_config.xent_regularize != 0.0);
  ComputationRequest request;
  // This function takes a NnetChainExample and produces a ComputationRequest.
  GetChainComputationRequest(*nnet_, chain_eg, need_model_derivative,
                             nnet_config.store_component_stats,
                             use_xent_regularization, need_model_derivative,
                             &request);
  // Compile the request.  The returned pointer refers to memory owned by the
  // CachingOptimizingCompiler NnetChainTrainer::compiler_; it is received
  // here as a std::shared_ptr<const NnetComputation>.
  std::shared_ptr<const NnetComputation> computation =
      compiler_.Compile(request);
  if (nnet_config.backstitch_training_scale > 0.0 &&
      num_minibatches_processed_ % nnet_config.backstitch_training_interval ==
      srand_seed_ % nnet_config.backstitch_training_interval) {
    // backstitch training is incompatible with momentum > 0
    KALDI_ASSERT(nnet_config.momentum == 0.0);
    FreezeNaturalGradient(true, delta_nnet_);
    bool is_backstitch_step1 = true;
    srand(srand_seed_ + num_minibatches_processed_);
    ResetGenerators(nnet_);
    TrainInternalBackstitch(chain_eg, *computation, is_backstitch_step1);
    FreezeNaturalGradient(false, delta_nnet_);  // un-freeze natural gradient
    is_backstitch_step1 = false;
    srand(srand_seed_ + num_minibatches_processed_);
    ResetGenerators(nnet_);
    TrainInternalBackstitch(chain_eg, *computation, is_backstitch_step1);
  } else {  // conventional training
    TrainInternal(chain_eg, *computation);
  }
  num_minibatches_processed_++;
}
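When backstitch training is enabled, the two TrainInternalBackstitch() calls apply a two-step update on the same minibatch: a small step against the usual update, then a larger step with the gradient recomputed at the perturbed point. Writing α for nnet_config.backstitch_training_scale and Δ(θ) for the ordinary parameter update computed at θ, the effect is roughly:

$$\theta' = \theta - \alpha\,\Delta(\theta) \qquad\text{(step 1, is\_backstitch\_step1 = true)}$$
$$\theta'' = \theta' + (1+\alpha)\,\Delta(\theta') \qquad\text{(step 2)}$$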
void NnetChainTrainer::TrainInternal(const NnetChainExample &eg,
                                     const NnetComputation &computation) {
  // The NnetComputer class is responsible for executing the computation
  // described by the "computation" object.  It is used in the following
  // order: constructor, AcceptInput() [or AcceptInputs()], Run(),
  // GetOutput(), AcceptOutputDeriv() [if applicable], Run() [if a backward
  // pass is needed], GetInputDeriv() [if applicable].
  NnetComputer computer(nnet_config.compute_config, computation,
                        nnet_, delta_nnet_);
  computer.AcceptInputs(*nnet_, eg.inputs);
  // forward pass
  computer.Run();
  // this call uses GetOutput() (and AcceptInput() for the derivatives)
  this->ProcessOutputs(false, eg, &computer);
  // backward pass; accumulates the parameter update into delta_nnet_
  computer.Run();
  // modify the parameter update delta_nnet_ according to the L2
  // regularization term
  ApplyL2Regularization(*nnet_,
                        GetNumNvalues(eg.inputs, false) *
                        nnet_config.l2_regularize_factor,
                        delta_nnet_);
  // update the network from delta_nnet_, with the change capped at
  // nnet_config.max_param_change
  bool success = UpdateNnetWithMaxChange(
      *delta_nnet_, nnet_config.max_param_change, 1.0,
      1.0 - nnet_config.momentum, nnet_,
      &num_max_change_per_component_applied_, &num_max_change_global_applied_);
  ...
}
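UpdateNnetWithMaxChange() adds the accumulated update to the network while enforcing the max-change constraints. Roughly, ignoring the per-component max-change limits and writing μ = nnet_config.momentum:

$$\theta \leftarrow \theta + \min\!\Big(1,\ \frac{\texttt{max\_param\_change}}{\lVert (1-\mu)\,\Delta\theta\rVert_2}\Big)\,(1-\mu)\,\Delta\theta$$

i.e. the applied change is scaled down whenever its overall 2-norm would exceed nnet_config.max_param_change.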
void NnetChainTrainer::ProcessOutputs(bool is_backstitch_step2,
                                      const NnetChainExample &eg,
                                      NnetComputer *computer) {
  // normally the eg will have just one output named 'output', but
  // we don't assume this.
  // In backstitch training, the output-name with the "_backstitch" suffix is
  // the one computed after the first, backward step of backstitch.
  const std::string suffix = (is_backstitch_step2 ? "_backstitch" : "");
  std::vector<NnetChainSupervision>::const_iterator iter = eg.outputs.begin(),
      end = eg.outputs.end();
  for (; iter != end; ++iter) {
    // check that the supervision's output name matches an output node of the
    // network
    const NnetChainSupervision &sup = *iter;
    int32 node_index = nnet_->GetNodeIndex(sup.name);
    if (node_index < 0 || !nnet_->IsOutputNode(node_index))
      KALDI_ERR << "Network has no output named " << sup.name;
    const CuMatrixBase<BaseFloat> &nnet_output = computer->GetOutput(sup.name);
    CuMatrix<BaseFloat> nnet_output_deriv(nnet_output.NumRows(),
                                          nnet_output.NumCols(),
                                          kUndefined);
    // whether cross-entropy regularization is used
    bool use_xent = (opts_.chain_config.xent_regularize != 0.0);
    // the cross-entropy objective is obtained from the output node named
    // "<output>-xent"
    std::string xent_name = sup.name + "-xent";  // typically "output-xent".
    CuMatrix<BaseFloat> xent_deriv;
    // tot_objf: the objective-function value, excluding the L2 and
    //           cross-entropy regularization terms
    // tot_l2_term: the L2 regularization term
    // tot_weight: the total weight of the frames in this minibatch
    BaseFloat tot_objf, tot_l2_term, tot_weight;
    // compute the objective and its derivative from the nnet output and the
    // supervision; if requested, also compute the data needed for the
    // frame-smoothing (cross-entropy) regularization of the
    // sequence-discriminative criterion
    ComputeChainObjfAndDeriv(opts_.chain_config, den_graph_, sup.supervision,
                             nnet_output, &tot_objf, &tot_l2_term, &tot_weight,
                             &nnet_output_deriv,
                             (use_xent ? &xent_deriv : NULL));
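    // In effect (see chain/chain-training.cc below): tot_objf is
    //   supervision.weight * (log p_num(x) - log p_den(x)),
    // and the derivative written into nnet_output_deriv is
    //   supervision.weight * (numerator posteriors - denominator posteriors).
    // When use_xent is true, xent_deriv receives just the numerator
    // posteriors, which later serve as soft targets for the cross-entropy
    // regularizer.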
    // update the objective-function statistics
    if (use_xent) {
      // get the output of the cross-entropy output node from the network
      const CuMatrixBase<BaseFloat> &xent_output = computer->GetOutput(
          xent_name);
      // at this point, xent_deriv holds the numerator posteriors of the MMI
      // criterion (the numerator "error signal"); the cross-entropy objective
      // is its element-wise dot product with the xent output:
      BaseFloat xent_objf = TraceMatMat(xent_output, xent_deriv, kTrans);
      objf_info_[xent_name + suffix].UpdateStats(xent_name + suffix,
                                                 opts_.nnet_config.print_interval,
                                                 num_minibatches_processed_,
                                                 tot_weight, xent_objf);
    }
    // apply the per-frame derivative weights
    if (opts_.apply_deriv_weights && sup.deriv_weights.Dim() != 0) {
      CuVector<BaseFloat> cu_deriv_weights(sup.deriv_weights);
      nnet_output_deriv.MulRowsVec(cu_deriv_weights);
      if (use_xent)
        // xent_deriv = diag(cu_deriv_weights) * xent_deriv, i.e. scale row i
        // of xent_deriv by cu_deriv_weights[i]
        xent_deriv.MulRowsVec(cu_deriv_weights);
    }
    // hand the derivative back to the computer (for the backward pass)
    computer->AcceptInput(sup.name, &nnet_output_deriv);
objf_info_[sup.name + suffix].UpdateStats(sup.name + suffix, opts_.nnet_config.print_interval, num_minibatches_processed_, tot_weight, tot_objf, tot_l2_term);
    if (use_xent) {
      // scale by the cross-entropy regularization constant
      xent_deriv.Scale(opts_.chain_config.xent_regularize);
      // hand the cross-entropy derivative to the computer
      computer->AcceptInput(xent_name, &xent_deriv);
    }
  }
}
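Putting these pieces together, the criterion being maximized for one output can be sketched as the LF-MMI objective plus its two regularizers. Writing γ^num for the numerator posteriors placed in xent_deriv, ŷ for the log-probabilities of the 'output-xent' node (in the standard setup), y for the main output, λ_xent = xent_regularize and λ_L2 = l2_regularize (the supervision weight is omitted for brevity):

$$\mathcal{F} \;=\; \log p_{\text{num}}(x) \;-\; \log p_{\text{den}}(x) \;+\; \lambda_{\text{xent}} \sum_{t,i} \gamma^{\text{num}}_{t,i}\,\hat{y}_{t,i} \;-\; \frac{\lambda_{L2}}{2} \sum_{t,i} y_{t,i}^{2}$$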
chain/chain-training.cc
// This function only computes the data needed for the cross-entropy
// regularization term; it does not apply the cross-entropy regularization
// to the derivative!
void ComputeChainObjfAndDeriv(const ChainTrainingOptions &opts,
                              const DenominatorGraph &den_graph,
                              const Supervision &supervision,
                              const CuMatrixBase<BaseFloat> &nnet_output,
                              BaseFloat *objf,
                              BaseFloat *l2_term,
                              BaseFloat *weight,
                              CuMatrixBase<BaseFloat> *nnet_output_deriv,
                              CuMatrix<BaseFloat> *xent_output_deriv) {
  if (!supervision.e2e_fsts.empty()) {
    ComputeChainObjfAndDerivE2e(opts, den_graph, supervision, nnet_output,
                                objf, l2_term, weight, nnet_output_deriv,
                                xent_output_deriv);
    return;
  }
  BaseFloat num_logprob_weighted, den_logprob_weighted;
  bool ok = true;
  if (nnet_output_deriv != NULL)
    nnet_output_deriv->SetZero();
  {
    // Doing the denominator first helps to reduce the maximum
    // memory use, as we can set 'xent_deriv' to nonempty after
    // we've freed the memory in this object.
    DenominatorComputation denominator(opts, den_graph,
                                       supervision.num_sequences,
                                       nnet_output);
    // denominator.Forward() returns the total log-probability of the
    // denominator graph, which we scale by the supervision weight
    den_logprob_weighted = supervision.weight * denominator.Forward();
    if (nnet_output_deriv)
      // the negative scale is because the denominator enters the objective
      // with a minus sign (objf = num_logprob - den_logprob); Backward() adds
      // the scaled denominator posteriors to nnet_output_deriv
      ok = denominator.Backward(-supervision.weight,
                                nnet_output_deriv);
  }
  if (xent_output_deriv != NULL) {
    // the reason for kStrideEqualNumCols is so that we can share the memory
    // block with the memory that was used for exp_nnet_output_transposed_ from
    // chain-denominator.cc, which has just been freed; it also uses the
    // kStrideEqualNumCols arg (its shape is the transpose of this matrix's
    // shape).
    xent_output_deriv->Resize(nnet_output.NumRows(), nnet_output.NumCols(),
                              kSetZero, kStrideEqualNumCols);
  }
  {
    // 'supervision' is the numerator graph corresponding to the complete
    // transcript of one utterance; it also encodes the allowed time range of
    // each phone.
    // The NumeratorComputation class performs the forward-backward
    // computation over the 'supervision' (numerator) FST.
    NumeratorComputation numerator(supervision, nnet_output);
    // note: supervision.weight is included as a factor in the derivative from
    // the numerator object, as well as the returned logprob.
    //
    // Unlike sequence training in Kaldi nnet1, nnet3 'chain' training
    // computes directly on the numerator graph.  Because the graph carries
    // the full information about the pdfs (NN outputs), states, phones and
    // words, the forward-backward computation over it yields posterior
    // probabilities.
    num_logprob_weighted = numerator.Forward();
    // Whether or not cross-entropy regularization is used, the
    // sequence-discriminative gradient added to nnet_output_deriv is the
    // same; the cross-entropy regularization term is not applied to the
    // gradient here.
    if (xent_output_deriv) {
      numerator.Backward(xent_output_deriv);
      if (nnet_output_deriv)
        nnet_output_deriv->AddMat(1.0, *xent_output_deriv);
    } else if (nnet_output_deriv) {
      numerator.Backward(nnet_output_deriv);
    }
  }
*objf = num_logprob_weighted - den_logprob_weighted;
  *weight = supervision.weight * supervision.num_sequences *
      supervision.frames_per_sequence;
  // if the objective is infinite/NaN, or the denominator computation failed
  if (!((*objf) - (*objf) == 0) || !ok) {
    // inf or NaN detected, or denominator computation returned false.
    if (nnet_output_deriv)
      nnet_output_deriv->SetZero();  // zero the derivative
    if (xent_output_deriv)
      xent_output_deriv->SetZero();  // zero the cross-entropy derivative
    BaseFloat default_objf = -10;
    KALDI_WARN << "Objective function is " << (*objf)
               << " and denominator computation (if done) returned "
               << std::boolalpha << ok
               << ", setting objective function to " << default_objf
               << " per frame.";
    // set the objective to the default value times the weight
    *objf = default_objf * *weight;
  }
  // This code helps us see how big the derivatives are, on average,
  // for different frames of the sequences.  As expected, they are
  // smaller towards the edges of the sequences (due to the penalization
  // of 'incorrect' pdf-ids).
  if (GetVerboseLevel() >= 1 && nnet_output_deriv != NULL &&
      RandInt(0, 10) == 0) {
    int32 tot_frames = nnet_output_deriv->NumRows(),
        frames_per_sequence = supervision.frames_per_sequence,
        num_sequences = supervision.num_sequences;
    CuVector<BaseFloat> row_products(tot_frames);
    row_products.AddDiagMat2(1.0, *nnet_output_deriv, kNoTrans, 0.0);
    Vector<BaseFloat> row_products_cpu(row_products);
    Vector<BaseFloat> row_products_per_frame(frames_per_sequence);
    for (int32 i = 0; i < tot_frames; i++)
      row_products_per_frame(i / num_sequences) += row_products_cpu(i);
    KALDI_LOG << "Derivs per frame are " << row_products_per_frame;
  }
  if (opts.l2_regularize == 0.0) {
    *l2_term = 0.0;
  } else {
    // compute the L2 penalty term and its derivative
    BaseFloat scale = supervision.weight * opts.l2_regularize;
    *l2_term = -0.5 * scale * TraceMatMat(nnet_output, nnet_output, kTrans);
    if (nnet_output_deriv)
      nnet_output_deriv->AddMat(-1.0 * scale, nnet_output);
  }
}
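In other words, with scale = supervision.weight * opts.l2_regularize, the two calls above compute

$$\texttt{l2\_term} \;=\; -\frac{\text{scale}}{2}\sum_{t,i} y_{t,i}^{2}, \qquad \frac{\partial\,\texttt{l2\_term}}{\partial y_{t,i}} \;=\; -\,\text{scale}\cdot y_{t,i}$$

since TraceMatMat(nnet_output, nnet_output, kTrans) is the squared Frobenius norm of the nnet output.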
chain/chain-numerator.cc
// Forward computation; returns the total log-likelihood * supervision_.weight.
BaseFloat NumeratorComputation::Forward() {
  ComputeLookupIndexes();
  nnet_logprobs_.Resize(nnet_output_indexes_.Dim(), kUndefined);
  nnet_output_.Lookup(nnet_output_indexes_, nnet_logprobs_.Data());
  const fst::StdVectorFst &fst = supervision_.fst;
  KALDI_ASSERT(fst.Start() == 0);
  int32 num_states = fst.NumStates();
  log_alpha_.Resize(num_states, kUndefined);
  log_alpha_.Set(-std::numeric_limits<double>::infinity());
  tot_log_prob_ = -std::numeric_limits<double>::infinity();
log_alpha_(0) = 0.0; // note, state zero is the start state, we checked above
  const BaseFloat *nnet_logprob_data = nnet_logprobs_.Data();
  std::vector<int32>::const_iterator fst_output_indexes_iter =
      fst_output_indexes_.begin();
double *log_alpha_data = log_alpha_.Data();
  for (int32 state = 0; state < num_states; state++) {
    double this_log_alpha = log_alpha_data[state];
    for (fst::ArcIterator<fst::StdVectorFst> aiter(fst, state); !aiter.Done();
         aiter.Next(), ++fst_output_indexes_iter) {
      const fst::StdArc &arc = aiter.Value();
      int32 nextstate = arc.nextstate;
      BaseFloat transition_logprob = -arc.weight.Value();
      int32 index = *fst_output_indexes_iter;
      BaseFloat pseudo_loglike = nnet_logprob_data[index];
      double &next_log_alpha = log_alpha_data[nextstate];
      next_log_alpha = LogAdd(next_log_alpha,
                              pseudo_loglike + transition_logprob +
                              this_log_alpha);
    }
    if (fst.Final(state) != fst::TropicalWeight::Zero()) {
      BaseFloat final_logprob = -fst.Final(state).Value();
      tot_log_prob_ = LogAdd(tot_log_prob_, this_log_alpha + final_logprob);
    }
  }
  KALDI_ASSERT(fst_output_indexes_iter == fst_output_indexes_.end());
  return tot_log_prob_ * supervision_.weight;
}
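The loop above is the standard forward (alpha) recursion in the log domain over the topologically sorted numerator FST. For an arc from state s to s' with pdf index i and transition log-probability w, and with final(s) the final log-probability of state s, it computes:

$$\log\alpha(s') \leftarrow \operatorname{LogAdd}\big(\log\alpha(s'),\ \log\alpha(s) + \log p_{\text{nnet}}(i) + w\big)$$
$$\texttt{tot\_log\_prob\_} \;=\; \underset{s\ \text{final}}{\operatorname{LogAdd}}\big(\log\alpha(s) + \text{final}(s)\big)$$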
// Backward computation: adds the derivative of
// (log-likelihood * supervision_.weight * deriv_weight) with respect to the
// nnet output to nnet_output_deriv.
void NumeratorComputation::Backward(
    CuMatrixBase<BaseFloat> *nnet_output_deriv) {
  // the numerator graph
  const fst::StdVectorFst &fst = supervision_.fst;
  // number of states in the numerator graph
  int32 num_states = fst.NumStates();
  log_beta_.Resize(num_states, kUndefined);
  // derivatives with respect to the nnet log-likelihoods
  nnet_logprob_derivs_.Resize(nnet_logprobs_.Dim());
  // we'll be counting backwards and moving the 'fst_output_indexes_iter'
  // pointer back.
  // 'fst_output_indexes_' contains one entry per arc of the supervision FST;
  // if the arcs of each state are visited in order, the entries are obtained
  // in order too.  Its contents are indexes into nnet_output_indexes_ and
  // nnet_logprobs_.
  const int32 *fst_output_indexes_iter =
      &(fst_output_indexes_[0]) + fst_output_indexes_.size();
  // nnet_logprobs_ holds the log-probs looked up from the nnet output, on the
  // CPU; it has the same size as nnet_output_indexes_.  In the backward
  // computation the corresponding storage is re-used for the derivatives.
  const BaseFloat *nnet_logprob_data = nnet_logprobs_.Data();
  // tot_log_prob_ is the total pseudo-log-likelihood from the forward pass
  double tot_log_prob = tot_log_prob_;
double *log_beta_data = log_beta_.Data();
  const double *log_alpha_data = log_alpha_.Data();
  // nnet_logprob_derivs_ holds the derivatives with respect to the nnet
  // log-likelihoods; they can be interpreted as occupation probabilities
  BaseFloat *nnet_logprob_deriv_data = nnet_logprob_derivs_.Data();
  // iterate over the states of the numerator graph, in reverse order
  for (int32 state = num_states - 1; state >= 0; state--) {
    // number of arcs leaving this state
    int32 this_num_arcs = fst.NumArcs(state);
    // on the backward pass we access the fst_output_indexes_ vector in a
    // zigzag pattern: step the pointer back over the entries for this
    // state's arcs
    fst_output_indexes_iter -= this_num_arcs;
    const int32 *this_fst_output_indexes_iter = fst_output_indexes_iter;
    double this_log_beta = -fst.Final(state).Value();
    double this_log_alpha = log_alpha_data[state];
    // iterate over all arcs leaving this state
    for (fst::ArcIterator<fst::StdVectorFst> aiter(fst, state); !aiter.Done();
         aiter.Next(), this_fst_output_indexes_iter++) {
      const fst::StdArc &arc = aiter.Value();
double next_log_beta = log_beta_data[arc.nextstate];
      BaseFloat transition_logprob = -arc.weight.Value();
      int32 index = *this_fst_output_indexes_iter;
      BaseFloat pseudo_loglike = nnet_logprob_data[index];
      // accumulate beta for this state
      this_log_beta = LogAdd(this_log_beta,
                             pseudo_loglike + transition_logprob +
                             next_log_beta);
      // posterior occupation probability of this arc in the numerator graph
      BaseFloat occupation_logprob = this_log_alpha + pseudo_loglike +
          transition_logprob + next_log_beta - tot_log_prob,
          occupation_prob = exp(occupation_logprob);
      nnet_logprob_deriv_data[index] += occupation_prob;
    }
    // check for -inf.
    KALDI_PARANOID_ASSERT(this_log_beta - this_log_beta == 0);
    log_beta_data[state] = this_log_beta;
  }
  KALDI_ASSERT(fst_output_indexes_iter == &(fst_output_indexes_[0]));
  int32 start_state = 0;  // the fact that the start state is numbered 0 is
                          // implied by other properties of the FST
                          // (epsilon-free-ness, topological sorting, and
                          // connectedness).
double tot_log_prob_backward = log_beta_(start_state);
if (!ApproxEqual(tot_log_prob_backward, tot_log_prob_)) KALDI_WARN << "Disagreement in forward/backward log-probs: " << tot_log_prob_backward << " vs. " << tot_log_prob_;
  // copy this data to GPU.
  CuVector<BaseFloat> nnet_logprob_deriv_cuda;
  nnet_logprob_deriv_cuda.Swap(&nnet_logprob_derivs_);
  // nnet_output_indexes_ is the list of (row, column) indexes that we needed
  // to look up in nnet_output_ for the forward-backward computation; here
  // (row, column) = (frame index, pdf index).  The order is arbitrary, but
  // each index into this vector appears in fst_output_indexes_, and it is
  // important that each pair appears only once (so that the derivatives are
  // summed correctly).  The call below does, for each i:
  //   nnet_output_deriv(nnet_output_indexes_[i].first,
  //                     nnet_output_indexes_[i].second) +=
  //       supervision_.weight * nnet_logprob_deriv_cuda.Data()[i];
  nnet_output_deriv->AddElements(supervision_.weight, nnet_output_indexes_,
                                 nnet_logprob_deriv_cuda.Data());
}
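The backward pass mirrors the forward recursion. For each arc from state s to s' carrying pdf index i and transition log-probability w, and starting from the final log-probability of s:

$$\log\beta(s) \leftarrow \operatorname{LogAdd}\big(\log\beta(s),\ \log p_{\text{nnet}}(i) + w + \log\beta(s')\big)$$

and the occupation probability accumulated into nnet_logprob_derivs_ is

$$\gamma = \exp\big(\log\alpha(s) + \log p_{\text{nnet}}(i) + w + \log\beta(s') - \texttt{tot\_log\_prob\_}\big),$$

which is the partial derivative of the total log-probability with respect to the corresponding nnet log-prob, i.e. the numerator posterior that ProcessOutputs() uses both for the chain derivative and (via xent_output_deriv) as the cross-entropy soft targets.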