hdu5343 後綴自動機+dp

給定兩個串,分別截取字串X和Y,鏈接組成X+Y,求不一樣的X+Y的方案數。ios

對於X+Y,若是重複的部分其實就是從同一個X+Y的某個地方斷開弄成不一樣的X和Y,那麼只要使得X和X+Y匹配得最長就好了。ide

所以,對兩個字符串分別創建後綴自動機A和B,在A中找字串X,當X的末尾不能接某個字符c時,在B中找以c爲開頭的全部字串。spa

注意字串的是n^2個,因此無論怎樣都不能以暴力遍歷自動機的方式來統計,而因爲SAM是DAG,因此其實是在兩個DAG上進行dp。code

 

#include<iostream>
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<algorithm>
#define REP(i,a,b) for(int i=a;i<=b;i++)
#define MS0(a) memset(a,0,sizeof(a))

using namespace std;

typedef unsigned long long ll;
const int maxn=1000100;
const int INF=1e9+10;

char s[maxn],t[maxn];
ll dp1[maxn],dp2[maxn];

struct SAM
{
    int ch[maxn][26];
    int pre[maxn],step[maxn];
    int last,tot;
    void init()
    {
        last=tot=0;
        memset(ch[0],-1,sizeof(ch[0]));
        pre[0]=-1;
        step[0]=0;
    }
    void add(int c)
    {
        c-='a';
        int p=last,np=++tot;
        step[np]=step[p]+1;
        memset(ch[np],-1,sizeof(ch[np]));
        while(~p&&ch[p][c]==-1) ch[p][c]=np,p=pre[p];
        if(p==-1) pre[np]=0;
        else{
            int q=ch[p][c];
            if(step[q]!=step[p]+1){
                int nq=++tot;
                step[nq]=step[p]+1;
                memcpy(ch[nq],ch[q],sizeof(ch[q]));
                pre[nq]=pre[q];
                pre[q]=pre[np]=nq;
                while(~p&&ch[p][c]==q) ch[p][c]=nq,p=pre[p];
            }
            else pre[np]=q;
        }
        last=np;
    }
};SAM A,B;

ll dfs2(int u)
{
    if(u==-1) return 0;
    ll &res=dp2[u];
    if(~res) return res;
    res=1;
    REP(c,0,25) res+=dfs2(B.ch[u][c]);
    return res;
}

ll dfs1(int u)
{
    ll &res=dp1[u];
    if(~res) return res;
    res=1;
    REP(c,0,25){
        if(~A.ch[u][c]) res+=dfs1(A.ch[u][c]);
        else res+=dfs2(B.ch[0][c]);
    }
    return res;
}

void solve()
{
    A.init();B.init();
    int ls=strlen(s),lt=strlen(t);
    REP(i,0,ls-1) A.add(s[i]);
    REP(i,0,lt-1) B.add(t[i]);
    memset(dp1,-1,sizeof(dp1));
    memset(dp2,-1,sizeof(dp2));
    printf("%I64u\n",dfs1(0));
}

int main()
{
    freopen("in.txt","r",stdin);
    int T;cin>>T;
    while(T--){
        scanf("%s%s",s,t);
        solve();
    }
    return 0;
}
View Code
相關文章
相關標籤/搜索