bzoj4566 找相同字符

題意:給定兩個字符串,從中各取一個子串使兩者相同,問有多少種取法。內容相同但位置不同的子串算作不同的取法。

解:創建廣義後綴自動機,對於每一個串,分別統計 cnt(該串在每個狀態上的出現次數),之後把每一個點的兩個 cnt 相乘,再乘以該點代表的不同子串數 len[i] - len[fail[i]],累加即爲答案。記得開 long long。

  1 #include <cstdio>
  2 #include <algorithm>
  3 #include <cstring>
  4 
  // 64-bit integer type used for the final answer.
  typedef long long LL;
  // Capacity of every per-state array and of both input buffers.
  const int N = 800010;

  // One entry of the singly linked adjacency lists built by add().
  struct Edge {
      int nex;  // index of the next edge in the same list; 0 ends the list
      int v;    // node this edge points to
  };
  Edge edge[N << 1];  // edge pool; slot 0 is never handed out by add()
  int top;            // index of the most recently allocated edge

  int tr[N][26];   // tr[state][letter]: transition target, 0 when absent
  int len[N];      // length value of each state (set to len[p] + 1 on creation)
  int fail[N];     // fail link of each state; main() uses it as the tree parent
  int cnt[N][2];   // per-state counters, one column per input string (see turn)
  int vis[N];      // not referenced by the code shown here
  int tot = 1;     // number of states allocated so far; state 1 is the root
  int last;        // not used by the code shown here (main keeps its own cursor)
  int turn;        // index (0 or 1) of the string currently being inserted
  int e[N];        // e[x]: first edge index of node x's list, 0 when empty

  char s[N];       // first input string
  char str[N];     // second input string
 15 
 16 inline void add(int x, int y) {
 17     top++;
 18     edge[top].v = y;
 19     edge[top].nex = e[x];
 20     e[x] = top;
 21     return;
 22 }
 23 
 24 inline int split(int p, int f) {
 25     int Q = tr[p][f], nQ = ++tot;
 26     len[nQ] = len[p] + 1;
 27     fail[nQ] = fail[Q];
 28     fail[Q] = nQ;
 29     memcpy(tr[nQ], tr[Q], sizeof(tr[Q]));
 30     while(tr[p][f] == Q) {
 31         tr[p][f] = nQ;
 32         p = fail[p];
 33     }
 34     return nQ;
 35 }
 36 
 37 inline int insert(int p, char c) {
 38     int f = c - 'a';
 39     if(tr[p][f]) {
 40         int Q = tr[p][f];
 41         if(len[Q] == len[p] + 1) {
 42             cnt[Q][turn] = 1;
 43             return Q;
 44         }
 45         int t = split(p, f);
 46         cnt[t][turn] = 1;
 47         return t;
 48     }
 49     int np = ++tot;
 50     len[np] = len[p] + 1;
 51     cnt[np][turn] = 1;
 52     while(p && !tr[p][f]) {
 53         tr[p][f] = np;
 54         p = fail[p];
 55     }
 56     if(!p) {
 57         fail[np] = 1;
 58     }
 59     else {
 60         int Q = tr[p][f];
 61         if(len[Q] == len[p] + 1) {
 62             fail[np] = Q;
 63         }
 64         else {
 65             fail[np] = split(p, f);
 66         }
 67     }
 68     return np;
 69 }
 70 
 71 void DFS(int x) {
 72     for(int i = e[x]; i; i = edge[i].nex) {
 73         int y = edge[i].v;
 74         DFS(y);
 75         cnt[x][0] += cnt[y][0];
 76         cnt[x][1] += cnt[y][1];
 77     }
 78     return;
 79 }
 80 
 81 int main() {
 82     scanf("%s%s", s, str);
 83     int n = strlen(s), last = 1;
 84     for(int i = 0; i < n; i++) {
 85         last = insert(last, s[i]);
 86     }
 87     n = strlen(str);
 88     last = turn = 1;
 89     for(int i = 0; i < n; i++) {
 90         last = insert(last, str[i]);
 91     }
 92     for(int i = 2; i <= tot; i++) {
 93         add(fail[i], i);
 94     }
 95     DFS(1);
 96     int p = 1;
 97     LL ans = 0;
 98     for(int i = 2; i <= tot; i++) {
 99         ans += 1ll * cnt[i][0] * cnt[i][1] * (len[i] - len[fail[i]]);
100     }
101 
102     printf("%lld", ans);
103     return 0;
104 }
AC代碼
相關文章
相關標籤/搜索