viterbi 維特比解碼過程,狀態轉移矩陣

viterbi過程
1.hmm相似。 狀態轉移,發射機率
2.逐次計算每一個序列節點的全部狀態下的機率值,最大機率值對應的index。
3.機率值的計算,上一個節點的機率值*轉移機率+當前機率值。
4.最後取出最大的一個值對應的indexes




難點:    理解viterbi的核心點,在於每一個時間步都保留每個可視狀態,每個可視狀態保留上一個時間步的最大隱狀態轉移,
     每個時間步t記錄上一個最大機率轉移過來的時間步t-1的信息,包括index/機率值累積。
         迭代完時間步,根據最後一個最大累積機率值,逐個往前找便可。 根據index對應的狀態逐個往前找。

應用:     狀態轉移求解最佳轉移路徑。  只要連續時間步,每一個時間步有狀態分佈,先後時間步之間有狀態轉移,就能夠使用viterbi進行最佳狀態轉移計算求解。
         狀態轉移矩陣的做用在於 在每一個狀態轉移機率計算時,和固有的狀態轉移矩陣進行加和,再計算。至關於額外的機率添加。
import numpy as np


def viterbi_decode(score, transition_params):
    """
    保留全部可視狀態下,對seqlen中的每一步的全部可視狀態狀況下的中間狀態求解機率最大值,如此
    :param score:
    :param transition_params:
    :return:
    """
    # score  [seqlen,taglen]  transition_params [taglen,taglen]
    trellis=np.zeros_like(score)
    trellis[0]=score[0]
    backpointers=np.zeros_like(score,dtype=np.int32)

    for t in range(1,len(score)):
        matrix_node=np.expand_dims(trellis[t-1],axis=1)+transition_params  #axis=0 表明發射機率初始狀態
        trellis[t]=score[t]+np.max(matrix_node,axis=0)
        backpointers[t]=np.argmax(matrix_node,axis=0)

    viterbi=[np.argmax(trellis[-1],axis=0)]
    for backpointer in reversed(backpointers[1:]):
        viterbi.append(backpointer[viterbi[-1]])
    viterbi_score = np.max(trellis[-1])
    viterbi.reverse()
    print(trellis)
    return viterbi,viterbi_score
def calculate():
    score = np.array([[1, 2, 3],
              [2, 1, 3],
              [1, 3, 2],
              [3, 2,1]])  # (batch_size, time_step, num_tabs)
    transition = np.array([ [2, 1, 3], [1, 3, 2], [3, 2, 1] ] )# (num_tabs, num_tabs)
    lengths = [len(score[0])] # (batch_size, time_step) # numpy print("[numpy]")
    # np_op = viterbi_decode( score=np.array(score[0]), transition_params=np.array(transition))
    # print(np_op[0])
    # print(np_op[1])
    print("=============") # tensorflow
    # score_t = tf.constant(score, dtype=tf.int64)
    # transition_t = transition, dtype=tf.int64
    tf_op = viterbi_decode( score, transition)
    print('--------------------')
    print(tf_op)

if __name__=='__main__':
    calculate()
// java 版本
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.List;

public class viterbi {

    public static int[]  viterbi_decode(double[][]score,double[][]trans ) {
        //score(16,31) trans(31,31)

        int path[] = new int[score.length];
        double trellis[][] = new double[score.length][score[0].length];
        int backpointers[][] = new int [score.length][score[0].length];
        trellis[0] = score[0];

        for(int t = 1; t<score.length;t++) {
            //  一維數組,31個元素 [-1000,-1000,-1000,.......]
            double h[] = trellis[t - 1];
            //i shape(31 ,1) 31行,1列  [ [-1000][-10000][-1000] ]
            //i = np.expand_dims(trellis[t - 1], 1)
//
//            double expand_dims[][] = new double[trans.length][trans[0].length];  //??
//            for(int j = 0;j<expand_dims[0].length;j++) {
//                expand_dims[j] = h;  //todo
//            }

            //zyy begin
            double expand_h[][]=new double[trans.length][trans[0].length];
            for(int i=0;i<trans.length;i++){
                for(int j=0;j<trans.length;j++) {
                    expand_h[i][j]=h[i];
                }
            }
            double expand_dims[][] = new double[trans.length][trans[0].length];  //??
            for(int j = 0;j<expand_dims[0].length;j++) {
                expand_dims[j] =expand_h[j] ;  //todo
            }
            //zyy_end


            double v[][] = new double[trans.length][trans[0].length];
            for(int i = 0; i < v.length; i++ ) {
                for(int j = 0; j< v[0].length ;j++) {
                    v[i][j] = expand_dims[i][j] + trans[i][j];
                }
            }

            //取每列最大的值 獲得score.length個每列最大值,一維數組
            double max_v[] = new double[trans[0].length];
            int max_v_linepoint[] = new int[trans[0].length];
            for (int j = 0; j < v[0].length; j++) {
                double max_column = v[0][j];
                int line_point = 0;
                for (int i = 0; i < v.length; i++) {
                    if(v[i][j] > max_column) {
                        max_column = v[i][j];
                        line_point = i;
                    }
                }
                max_v[j] = max_column;
                max_v_linepoint[j] = line_point;
            }

            for(int i = 0 ;i < score[0].length; i++ ) {
                trellis[t][i] = score[t][i] + max_v[i];
                backpointers[t][i] = max_v_linepoint[i];
            }

        }

        int viterbi[] = new int[score.length];
//            List<Integer> viterbi = new ArrayList<>();
        double max_trellis = trellis[score.length-1][0];
        for(int j = 0; j< trellis[score.length-1].length ;j++) {
            if(trellis[score.length-1][j] > max_trellis) {
                max_trellis = trellis[score.length-1][j];
//                    viterbi.add(j);
                viterbi[0] = j;
            }
        }

        for(int i=1;i< 1+(backpointers.length)/2;i++){
            int temp[] = backpointers[i];
            backpointers[i] = backpointers[backpointers.length-i];
            backpointers[backpointers.length-i]=temp;
        }

        for(int i = 1; i < backpointers.length; i++ ) {
//                viterbi.add( backpointers[i][viterbi.get(viterbi.size() - 1)]);
            viterbi[i] = backpointers[i][viterbi[i-1]];
        }


        for(int i = 0;i < (viterbi.length)/2; i++){    //把數組的值賦給一個臨時變量
            int temp = viterbi[i];
            viterbi[i] = viterbi[viterbi.length-i-1];
            viterbi[viterbi.length-i-1] = temp;
        }


        return viterbi;
    }

    public static void main(String[] args){
        List<List<Integer>> score=new ArrayList<>();
        ArrayList<Integer> row1=new ArrayList<>();
        row1.add(1);
        row1.add(2);
        row1.add(3);

        ArrayList<Integer> row2=new ArrayList<>();
        row2.add(2);
        row2.add(1);
        row2.add(3);

        ArrayList<Integer> row3=new ArrayList<>();
        row3.add(1);
        row3.add(3);
        row3.add(2);

        ArrayList<Integer> row4=new ArrayList<>();
        row4.add(3);
        row4.add(2);
        row4.add(1);

        score.add(row1);
        score.add(row2);
        score.add(row3);
        score.add(row4);

        List<List<Integer>> trans=new ArrayList<>();
        ArrayList<Integer> row11=new ArrayList<>();
        row11.add(2);
        row11.add(1);
        row11.add(3);

        ArrayList<Integer> row12=new ArrayList<>();
        row12.add(1);
        row12.add(3);
        row12.add(2);

        ArrayList<Integer> row13=new ArrayList<>();
        row13.add(3);
        row13.add(2);
        row13.add(1);

        trans.add(row11);
        trans.add(row12);
        trans.add(row13);



//        double[][] score_double=(double[][]) score.toArray();
//        double[][] trans_double=(double[][]) trans.toArray();
        System.out.println(score);
        System.out.println(trans);

        double[][] score_double=new double[score.size()][score.get(0).size()];
        for(int i=0;i<score.size();i++){
//                score_double[i]=score.get(i);
            for(int j=0;j<score.get(0).size();j++){
                score_double[i][j]=score.get(i).get(j);
            }
        }



        double[][] trans_double=new double[trans.size()][trans.get(0).size()];
        for(int i=0;i<trans.size();i++){
//                score_double[i]=score.get(i);
            for(int j=0;j<trans.get(0).size();j++){
                trans_double[i][j]=trans.get(i).get(j);
            }
        }




        int[] result=viterbi_decode(score_double,trans_double);
        System.out.println("===========****===============");
        System.out.println(result.toString());
    }

}
相關文章
相關標籤/搜索