【算法】KMP算法

簡介

KMP算法由 Knuth-Morris-Pratt 三位科學家提出,可用於在一個 文本串 中尋找某 模式串 存在的位置。
本算法能夠有效下降在一個 文本串 中尋找某 模式串 過程的時間複雜度。(若是採起樸素的想法則複雜度是 \(O(MN)\)c++

這裏樸素的想法指的是枚舉 文本串 的起點,而後讓 模式串 從第一位開始一個個地檢查是否配對,若是不配對則繼續枚舉起點。算法

前置知識

真前綴
指字符串左部的任意子串(不包含自身),如 abcde 中的 a,ab,abc,abcd 都是真前綴但 abcde 不是。數組

真後綴
指字符串右部的任意子串(不包含自身),如 abcde 中的 e,de,cde,bcde 都是真後綴但 abcde 不是。函數

前綴函數
一個字符串中最長的、相等的真前綴與真後綴的長度, 如AABBAAA對應的前綴函數值是 \(2\)ui

原理

注意:在分析的時候,咱們規定字符串的下標從 \(1\) 開始。spa

開始:
咱們記掃描模式串的指針爲j,而掃描文本串的指針爲i,假設一開始i,j都在起點,而後讓它們一直下去直到徹底匹配或者失配,好比:指針

j
ABCD

i
ABCDEFG

而後code

j
ABCD

 i
ABCDEFG

最後在此完成了一次匹配,相似地若是ABCD改成ABCC則在此失配。對象

j
ABCD

   i
ABCDEFG

i,j運做模式如上。ci



KMP算法就是,當模式串和文本串失配的時候,j指針從真後綴的末尾跳到真前綴的末尾,而後從真前綴後一位開始繼續匹配。(從而起到減小配對次數,這即是KMP算法的核心原理)

結合例子解釋:

模式串: \(AABBAAA\)

文本串: \(AABBAABBAAA\)

j指針在最後一個A處失配。

j
AABBAAA
      i
AABBAABBAAA

由於此時 以j爲尾的前綴 所對應的前綴函數值是 \(2\) ,因此 j指針 跳到這裏:

j
AABBAAA
      i
AABBAABBAAA

而後從下一位開始繼續配對:

j
AABBAAA
      i
AABBAABBAAA

最後

j
AABBAAA
          i
AABBAABBAAA

能夠看出,KMP可以有效減小配對次數。

實現

咱們記模式串p文本串s

從上面的模擬中,咱們發現須要預處理出一個數組(記之爲next[]),它儲存模式串中前綴對應的前綴函數\(\pi()\),如對於字符串ABCABC

\(\pi(0)=0\) (由於什麼都沒有)
\(\pi(1)=0\)A甚至沒有真前綴真後綴
\(\pi(2)=0\)AB
\(\pi(3)=0\)ABC
\(\pi(4)=1\)ABCA
\(\pi(5)=2\)ABCAB
\(\pi(6)=3\)ABCABC

一樣地,咱們發現若是用暴力樸素的想法來統計複雜度是 O(N^2) 很差,因而採用相似於上面的方法,只不過模式串配對的對象是本身罷了。

能夠結合代碼理解,並注意舉例,嘗試在紙上模擬這個過程。

for(int i=2,j=0;i<=lenp;i++){
        while(j && p[j+1]!=p[i]) j=next_[j]; // 若是j指向元素的下一個元素會和當前配對位置失配,則j跳回去
        if(p[j+1]==p[i]) j++; //若是可以配對上,j++
        next_[i]=j; //記錄當前位置的前綴函數π
}

完整代碼:

#include<bits/stdc++.h>
using namespace std;

const int N=1e6+5;
char p[N],s[N];
int next_[N];

int main(){
    cin>>s+1>>p+1;

    int lenp=strlen(p+1),lens=strlen(s+1);
    // build next array
    for(int i=2,j=0;i<=lenp;i++){
        while(j && p[j+1]!=p[i]) j=next_[j]; // 若是j指向元素的下一個元素會和當前配對位置失配,則j跳回去
        if(p[j+1]==p[i]) j++; //若是可以配對上,j++
        next_[i]=j; //記錄當前位置的前綴函數π
    }

    for(int i=1,j=0;i<=lens;i++){
        while(j && p[j+1]!=s[i]) j=next_[j];
        if(p[j+1]==s[i]) j++;

        // if match
        if(j==lenp){
            j=next_[j];
            cout<<i-lenp+1<<endl;
        }
    }

    for(int i=1;i<=lenp;i++) cout<<next_[i]<<' ';
    cout<<endl;

    return 0;
}

複雜度

\(O(N+M)\)

相關文章
相關標籤/搜索