kaldi chain模型的序列鑑別性訓練代碼分析

chainbin/nnet3-chain-train.cc

int main(int argc, char *argv[]) { 網絡

... app

Nnet nnet; ide

ReadKaldiObject(nnet_rxfilename, &nnet); 函數

bool ok; this

{ spa

fst::StdVectorFst den_fst; 3d

ReadFstKaldi(den_fst_rxfilename, &den_fst); 指針

 

//NnetChainTrainer讀取訓練參數opts、分母詞圖den_fst、神經網絡nnet code

NnetChainTrainer trainer(opts, den_fst, &nnet);

//SequentialNnetChainExampleReader以語句爲單位讀取樣本

SequentialNnetChainExampleReader example_reader(examples_rspecifier);

for (; !example_reader.Done(); example_reader.Next())

//以句爲單位進行訓練

trainer.Train(example_reader.Value());

ok = trainer.PrintTotalStats();

}n

...

WriteKaldiObject(nnet, nnet_wxfilename, binary_write);

...

}

nnet3/nnet-chain-training.cc

void NnetChainTrainer::Train(const NnetChainExample &chain_eg) {

bool need_model_derivative = true;

const NnetTrainerOptions &nnet_config = opts_.nnet_config;

bool use_xent_regularization = (opts_.chain_config.xent_regularize != 0.0);

ComputationRequest request;

//This function takes a NnetChainExample and produces a ComputationRequest.

GetChainComputationRequest(*nnet_, chain_eg, need_model_derivative,

nnet_config.store_component_stats,

use_xent_regularization, need_model_derivative,

&request);

//進行編譯,返回到結果的常量指針。

//返回的常量指針由CachingOptimizingCompiler NnetChainTrainer::compiler_全部

//若是編譯失敗,用std::shared_ptr<const NnetComputation>接收返回值

std::shared_ptr<const NnetComputation> computation = compiler_.Compile(request);

 

   

if (nnet_config.backstitch_training_scale > 0.0 && num_minibatches_processed_

% nnet_config.backstitch_training_interval ==

srand_seed_ % nnet_config.backstitch_training_interval) {

// backstitch training is incompatible with momentum > 0

KALDI_ASSERT(nnet_config.momentum == 0.0);

FreezeNaturalGradient(true, delta_nnet_);

bool is_backstitch_step1 = true;

srand(srand_seed_ + num_minibatches_processed_);

ResetGenerators(nnet_);

TrainInternalBackstitch(chain_eg, *computation, is_backstitch_step1);

FreezeNaturalGradient(false, delta_nnet_); // un-freeze natural gradient

is_backstitch_step1 = false;

srand(srand_seed_ + num_minibatches_processed_);

ResetGenerators(nnet_);

TrainInternalBackstitch(chain_eg, *computation, is_backstitch_step1);

} else { // conventional training

TrainInternal(chain_eg, *computation);

}

   

num_minibatches_processed_++;

}

   

   

void NnetChainTrainer::TrainInternal(const NnetChainExample &eg,

const NnetComputation &computation) {

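// The NnetComputer class is responsible for executing the computation

// described by the "computation" object. It is called in this order:

//   constructor

//   AcceptInput() [or AcceptInputs()]

//   Run()

//   GetOutput()

//   AcceptOutputDeriv() [if applicable]

//   Run() [if a backward computation is needed]

//   GetInputDeriv() [if applicable]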

NnetComputer computer(nnet_config.compute_config, computation,

nnet_, delta_nnet_);

computer.AcceptInputs(*nnet_, eg.inputs);

//前向傳播,計算

computer.Run();

//該函數調用了GetOutput()

this->ProcessOutputs(false, eg, &computer);

//反向傳播,計算權重更新量delta_nnet_

computer.Run();

//根據L2正則化項,修改權重更新量delta_nnet_

ApplyL2Regularization(*nnet_,

GetNumNvalues(eg.inputs, false) *

nnet_config.l2_regularize_factor,

delta_nnet_);

//根據權重更新量delta_nnet_,更新神經網絡,上限爲nnet_config.max_param_change

bool success =

UpdateNnetWithMaxChange(*delta_nnet_,

nnet_config.max_param_change,

1.0,

1.0 - nnet_config.momentum,

nnet_,

&num_max_change_per_component_applied_,

&num_max_change_global_applied_);

  

   

   

void NnetChainTrainer::ProcessOutputs(bool is_backstitch_step2,

const NnetChainExample &eg,

NnetComputer *computer) {

// normally the eg will have just one output named 'output', but

// we don't assume this.

// In backstitch training, the output-name with the "_backstitch" suffix is

// the one computed after the first, backward step of backstitch.

const std::string suffix = (is_backstitch_step2 ? "_backstitch" : "");

std::vector<NnetChainSupervision>::const_iterator iter = eg.outputs.begin(),

end = eg.outputs.end();

for (; iter != end; ++iter) {

//檢查每一個樣本的標籤是否與網絡相匹配

const NnetChainSupervision &sup = *iter;

int32 node_index = nnet_->GetNodeIndex(sup.name);

if (node_index < 0 ||

!nnet_->IsOutputNode(node_index))

KALDI_ERR << "Network has no output named " << sup.name;

   

const CuMatrixBase<BaseFloat> &nnet_output = computer->GetOutput(sup.name);

CuMatrix<BaseFloat> nnet_output_deriv(nnet_output.NumRows(),

nnet_output.NumCols(),

kUndefined);

//是否進行交叉熵正則化

bool use_xent = (opts_.chain_config.xent_regularize != 0.0);

//從名爲"output-xent"的component-node獲取交叉熵的目標函數值

std::string xent_name = sup.name + "-xent"; // typically "output-xent".

CuMatrix<BaseFloat> xent_deriv;

//tot_objf,目標函數值,未包含L2正則化項,未包含交叉熵正則化項

//tot_l2_termL2正則化項

//tot_weightL2正則化項權重

BaseFloat tot_objf, tot_l2_term, tot_weight;

//根據預測和標籤計算目標函數值及其梯度,計算交叉熵正則化項及其權重

   

//幀平滑-序列鑑別性準則

ComputeChainObjfAndDeriv(opts_.chain_config, den_graph_,

sup.supervision, nnet_output,

&tot_objf, &tot_l2_term, &tot_weight,

&nnet_output_deriv,

(use_xent ? &xent_deriv : NULL));

   

//更新梯度統計量

if (use_xent) {

// 從神經網絡中獲取交叉熵output-node的輸出

const CuMatrixBase<BaseFloat> &xent_output = computer->GetOutput(

xent_name);

/* 此時,xent_derivMMI準則函數的分子後驗/分子錯誤信號。

/*

BaseFloat xent_objf = TraceMatMat(xent_output, xent_deriv, kTrans);

objf_info_[xent_name + suffix].UpdateStats(xent_name + suffix,

opts_.nnet_config.print_interval,

num_minibatches_processed_,

tot_weight, xent_objf);

}

//乘以梯度權重

if (opts_.apply_deriv_weights && sup.deriv_weights.Dim() != 0) {

CuVector<BaseFloat> cu_deriv_weights(sup.deriv_weights);

nnet_output_deriv.MulRowsVec(cu_deriv_weights);

if (use_xent)

//xent_deriv=diag(cu_deriv_weights)*xent_deriv

//cu_deriv_weights[i]xent_deriv的第i行進行縮放

xent_deriv.MulRowsVec(cu_deriv_weights);

}

//計算器接收梯度

computer->AcceptInput(sup.name, &nnet_output_deriv);

 

objf_info_[sup.name + suffix].UpdateStats(sup.name + suffix,

opts_.nnet_config.print_interval,

num_minibatches_processed_,

tot_weight, tot_objf, tot_l2_term);

   

if (use_xent) {

//以交叉熵正則化因子進行縮放

xent_deriv.Scale(opts_.chain_config.xent_regularize);

//接收交叉熵正則化的梯度

computer->AcceptInput(xent_name, &xent_deriv);

}

}

}

chain/chain-training.cc

//該函數只計算交叉熵正則化項所需的數據,但並不在梯度中應用交叉熵正則化項!
void ComputeChainObjfAndDeriv(const ChainTrainingOptions &opts,

const DenominatorGraph &den_graph,

const Supervision &supervision,

const CuMatrixBase<BaseFloat> &nnet_output,

BaseFloat *objf,

BaseFloat *l2_term,

BaseFloat *weight,

CuMatrixBase<BaseFloat> *nnet_output_deriv,

CuMatrix<BaseFloat> *xent_output_deriv) {

   

if (!supervision.e2e_fsts.empty()) {

ComputeChainObjfAndDerivE2e(opts, den_graph, supervision,

nnet_output, objf, l2_term,

weight, nnet_output_deriv, xent_output_deriv);

return;

}

   

BaseFloat num_logprob_weighted, den_logprob_weighted;

bool ok = true;

if (nnet_output_deriv != NULL)

nnet_output_deriv->SetZero();

   

{ // Doing the denominator first helps to reduce the maximum

// memory use, as we can set 'xent_deriv' to nonempty after

// we've freed the memory in this object.

DenominatorComputation denominator(opts, den_graph,

supervision.num_sequences,

nnet_output);

/*

denominator.Forward()的結果爲分母詞圖的後驗機率

*/

den_logprob_weighted = supervision.weight * denominator.Forward();

if (nnet_output_deriv)

//其中負號來自於對分母取log

ok = denominator.Backward(-supervision.weight,

nnet_output_deriv);

}

   

if (xent_output_deriv != NULL) {

// the reason for kStrideEqualNumCols is so that we can share the memory

// block with the memory that was used for exp_nnet_output_transposed_ from

// chain-denominator.cc, which has just been freed; it also uses the

// kStrideEqualNumCols arg (its shape is the transpose of this matrix's

// shape).

xent_output_deriv->Resize(nnet_output.NumRows(), nnet_output.NumCols(),

kSetZero, kStrideEqualNumCols);

}

   

{

/*supervision是一句話完整標註對應的分子詞圖,其中包含每一個音素序列的時間範圍信息

其中

至關於nnet_output

*/

//NumeratorComputation類負責'supervision'(分子)FST的前向後向計算

NumeratorComputation numerator(supervision, nnet_output);

// note: supervision.weight is included as a factor in the derivative from

// the numerator object, as well as the returned logprob.

*/

分子詞圖的後驗機率

這與Kaldi nnet1

爲神經網絡後驗機率

不一樣,Kaldi nnet3直接對分子詞圖進行計算

因爲詞圖包含了

狀態分佈(NN)、狀態、音素、字的所有信息。

所以,對詞圖的前向後向計算後,獲得的是後驗機率

*/

num_logprob_weighted = numerator.Forward();

//此處,沒法是否進行交叉熵正則化,

//序列鑑別性訓練的梯度nnet_output_deriv都不變。

//此時,還並無在梯度中應用交叉熵正則化項!

if (xent_output_deriv)

{

numerator.Backward(xent_output_deriv);

if (nnet_output_deriv)

D維梯度向量

nnet_output_deriv->AddMat(1.0, *xent_output_deriv);

}

else if (nnet_output_deriv)

{

D維梯度向量

   

numerator.Backward(nnet_output_deriv);

}

   

   

}

/*

*/

   

*objf = num_logprob_weighted - den_logprob_weighted;

   

*weight = supervision.weight * supervision.num_sequences *

supervision.frames_per_sequence;

//若梯度爲無窮大/不可用 分母計算出錯

if (!((*objf) - (*objf) == 0) || !ok) {

// inf or NaN detected, or denominator computation returned false.

if (nnet_output_deriv)

//將梯度設爲零

nnet_output_deriv->SetZero();

if (xent_output_deriv)

//將交叉熵梯度設爲零

xent_output_deriv->SetZero();

BaseFloat default_objf = -10;

KALDI_WARN << "Objective function is " << (*objf)

<< " and denominator computation (if done) returned "

<< std::boolalpha << ok

<< ", setting objective function to " << default_objf

<< " per frame.";

//將權重設置爲加權默認權重

*objf = default_objf * *weight;

}

   

// This code helps us see how big the derivatives are, on average,

// for different frames of the sequences. As expected, they are

// smaller towards the edges of the sequences (due to the penalization

// of 'incorrect' pdf-ids.

if (GetVerboseLevel() >= 1 && nnet_output_deriv != NULL && RandInt(0, 10) == 0) {

int32 tot_frames = nnet_output_deriv->NumRows(),

frames_per_sequence = supervision.frames_per_sequence,

num_sequences = supervision.num_sequences;

CuVector<BaseFloat> row_products(tot_frames);

row_products.AddDiagMat2(1.0, *nnet_output_deriv, kNoTrans, 0.0);

Vector<BaseFloat> row_products_cpu(row_products);

Vector<BaseFloat> row_products_per_frame(frames_per_sequence);

for (int32 i = 0; i < tot_frames; i++)

row_products_per_frame(i / num_sequences) += row_products_cpu(i);

KALDI_LOG << "Derivs per frame are " << row_products_per_frame;

}

   

if (opts.l2_regularize == 0.0) {

*l2_term = 0.0;

} else {

// compute the l2 penalty term and its derivative

BaseFloat scale = supervision.weight * opts.l2_regularize;

//計算L2正則化項

*l2_term = -0.5 * scale * TraceMatMat(nnet_output, nnet_output, kTrans);

if (nnet_output_deriv)

//

nnet_output_deriv->AddMat(-1.0 * scale, nnet_output);

}

}

   

chain/chain-numerator.cc

// Forward pass over the numerator (supervision) FST.
// Fills log_alpha_ and tot_log_prob_, and returns
// tot_log_prob_ * supervision_.weight.
BaseFloat NumeratorComputation::Forward() {
  // Look up the network outputs that the FST arcs refer to.
  ComputeLookupIndexes();
  nnet_logprobs_.Resize(nnet_output_indexes_.Dim(), kUndefined);
  nnet_output_.Lookup(nnet_output_indexes_, nnet_logprobs_.Data());

  const fst::StdVectorFst &fst = supervision_.fst;
  KALDI_ASSERT(fst.Start() == 0);

  const int32 state_count = fst.NumStates();
  const double neg_infinity = -std::numeric_limits<double>::infinity();

  // Every alpha starts at log(0); only the start state gets log(1).
  log_alpha_.Resize(state_count, kUndefined);
  log_alpha_.Set(neg_infinity);
  tot_log_prob_ = neg_infinity;
  log_alpha_(0) = 0.0;  // state zero is the start state (asserted above)

  const BaseFloat *logprob_table = nnet_logprobs_.Data();
  double *alpha = log_alpha_.Data();

  // Advances in lockstep with the arcs: one entry is consumed per arc.
  std::vector<int32>::const_iterator arc_index_iter =
      fst_output_indexes_.begin();

  for (int32 s = 0; s < state_count; s++) {
    const double alpha_here = alpha[s];

    fst::ArcIterator<fst::StdVectorFst> arcs(fst, s);
    while (!arcs.Done()) {
      const fst::StdArc &arc = arcs.Value();
      const BaseFloat transition_logprob = -arc.weight.Value();
      const BaseFloat pseudo_loglike = logprob_table[*arc_index_iter];

      // Log-add this path's score into the destination state's alpha.
      double &dest_alpha = alpha[arc.nextstate];
      dest_alpha = LogAdd(dest_alpha, pseudo_loglike +
                          transition_logprob + alpha_here);

      arcs.Next();
      ++arc_index_iter;
    }

    // Final states contribute alpha + final log-prob to the total.
    if (fst.Final(s) != fst::TropicalWeight::Zero()) {
      const BaseFloat final_logprob = -fst.Final(s).Value();
      tot_log_prob_ = LogAdd(tot_log_prob_,
                             alpha_here + final_logprob);
    }
  }

  // Every per-arc index must have been consumed exactly once.
  KALDI_ASSERT(arc_index_iter ==
               fst_output_indexes_.end());
  return tot_log_prob_ * supervision_.weight;
}

   

   

// Backward pass over the numerator (supervision) FST.

// Computes the arc occupation probabilities from log_alpha_ (filled by

// Forward()) and the betas computed here, and adds them, scaled by

// supervision_.weight, to 'nnet_output_deriv'.
void NumeratorComputation::Backward(

CuMatrixBase<BaseFloat> *nnet_output_deriv) {

// the numerator (supervision) FST

const fst::StdVectorFst &fst = supervision_.fst;

// number of states in the numerator FST

int32 num_states = fst.NumStates();

log_beta_.Resize(num_states, kUndefined);

// derivative accumulator, one entry per looked-up nnet log-prob

// NOTE(review): the += below relies on this Resize() zeroing the vector -- confirm the default resize type.

nnet_logprob_derivs_.Resize(nnet_logprobs_.Dim());

   

// we'll be counting backwards and moving the 'fst_output_indexes_iter'

// pointer back.

// 'fst_output_indexes_' has one entry per arc of the supervision FST, in the order in which the arcs are met when visiting each state's arcs in sequence. Its entries are indexes into nnet_logprobs_ (and, per the article, nnet_output_indexes_).

const int32 *fst_output_indexes_iter = &(fst_output_indexes_[0]) +

fst_output_indexes_.size();

// log-probs looked up from the nnet output by Forward(); this vector has the same size as nnet_output_indexes_.

const BaseFloat *nnet_logprob_data = nnet_logprobs_.Data();

// tot_log_prob_ is the total log-prob that Forward() computed

double tot_log_prob = tot_log_prob_;

double *log_beta_data = log_beta_.Data();

const double *log_alpha_data = log_alpha_.Data();

// nnet_logprob_derivs_ holds the derivative w.r.t. the looked-up nnet log-probs; each entry accumulates arc occupation probabilities.

BaseFloat *nnet_logprob_deriv_data = nnet_logprob_derivs_.Data();

// visit the states of the numerator FST in reverse order

for (int32 state = num_states - 1; state >= 0; state--) {

// number of arcs of this state

int32 this_num_arcs = fst.NumArcs(state);

// on the backward pass we access the fst_output_indexes_ vector in a zigzag

// pattern.

// step the pointer back to the first index entry belonging to this state's arcs

fst_output_indexes_iter -= this_num_arcs;

const int32 *this_fst_output_indexes_iter = fst_output_indexes_iter;

double this_log_beta = -fst.Final(state).Value();

double this_log_alpha = log_alpha_data[state];

// visit every arc of this state, moving forward through its index entries

for (fst::ArcIterator<fst::StdVectorFst> aiter(fst, state); !aiter.Done();

aiter.Next(), this_fst_output_indexes_iter++) {

const fst::StdArc &arc = aiter.Value();

double next_log_beta = log_beta_data[arc.nextstate];

BaseFloat transition_logprob = -arc.weight.Value();

// index of this arc's entry in nnet_logprobs_ / nnet_logprob_derivs_

int32 index = *this_fst_output_indexes_iter;

BaseFloat pseudo_loglike = nnet_logprob_data[index];

/* accumulate (log-add) this arc's contribution into this state's beta */

this_log_beta = LogAdd(this_log_beta, pseudo_loglike +

transition_logprob + next_log_beta);

// posterior occupation of this arc in the numerator graph

BaseFloat occupation_logprob = this_log_alpha + pseudo_loglike +

transition_logprob + next_log_beta - tot_log_prob,

occupation_prob = exp(occupation_logprob);

nnet_logprob_deriv_data[index] += occupation_prob;

}

// check for -inf.

KALDI_PARANOID_ASSERT(this_log_beta - this_log_beta == 0);

log_beta_data[state] = this_log_beta;

}

// after all states have been visited the pointer must be back at the start

KALDI_ASSERT(fst_output_indexes_iter == &(fst_output_indexes_[0]));

   

int32 start_state = 0; // the fact that the start state is numbered 0 is

// implied by other properties of the FST

// (epsilon-free-ness and topological sorting, and

// connectedness).

double tot_log_prob_backward = log_beta_(start_state);

if (!ApproxEqual(tot_log_prob_backward, tot_log_prob_))

KALDI_WARN << "Disagreement in forward/backward log-probs: "

<< tot_log_prob_backward << " vs. " << tot_log_prob_;

   

// copy this data to GPU.

CuVector<BaseFloat> nnet_logprob_deriv_cuda;

nnet_logprob_deriv_cuda.Swap(&nnet_logprob_derivs_);

/* Article note: nnet_output_indexes_ is a list of (row, column) indexes that we need to look up in nnet_output_ for the forward-backward computation. The order is arbitrary, but indexes into this vector appear in fst_output_indexes_; and it is important that each pair appears only once (so that the derivatives are summed correctly).

The article glosses (row, column) as (number of pdfs, number of features) -- TODO confirm.

The article cites matrix-common.h:69 and describes the call below as:

nnet_output_deriv(nnet_output_indexes_[i].first, nnet_output_indexes_[i].second) +=

supervision_.weight * nnet_logprob_deriv_cuda.Data()[i];

*/

nnet_output_deriv->AddElements(supervision_.weight, nnet_output_indexes_,

nnet_logprob_deriv_cuda.Data());

}

相關文章
相關標籤/搜索