bzoj3473 字符串

雙倍經驗:bzoj3277 串

題意:給定n個字符串,對於每個字符串,求它有多少個非空子串(按出現位置計數,即不同的區間[l, r]分別計算)是其中至少k個字符串的子串。

解:有一種毒瘤的後綴數組解法……不過後綴自動機吊打後綴數組!耶~

廣義後綴自動機。

先把所有串插入,建出廣義後綴自動機。然後對於每個串,在自動機上跑一遍,試圖給該串的所有子串都打上標記(sum++)。

跑到的每個節點都代表該串的一個前綴,我們跳fail就能得到該前綴的所有後綴。用vis數組記錄最後標記某節點的串的編號,避免同一個串對同一節點重複標記。

所有sum >= k的節點x,貢獻爲len[x] - len[fail[x]],否則爲0。

之後對於每個串再在自動機上跑一遍,跑到的每個節點的貢獻爲它在fail樹上到根的路徑上所有節點的貢獻之和(可以在fail樹上自頂向下預處理出來),把這些值加起來就是答案。

  1 #include <cstdio>
  2 #include <algorithm>
  3 #include <cstring>
  4 #include <string>
  5 
  6 const int N = 500010;
  7 
  8 struct Edge {
  9     int nex, v;
 10 }edge[N << 1]; int top;
 11 
 12 int tr[N][26], len[N], cnt[N], sum[N], fail[N], vis[N];
 13 int last = 1, tot = 1, Time, k, e[N];
 14 std::string str[100010];
 15 char s[N];
 16 
 17 inline void add(int x, int y) {
 18     top++;
 19     edge[top].v = y;
 20     edge[top].nex = e[x];
 21     e[x] = top;
 22     return;
 23 }
 24 
 25 inline int split(int p, int f) {
 26     int Q = tr[p][f], nQ = ++tot;
 27     len[nQ] = len[p] + 1;
 28     fail[nQ] = fail[Q];
 29     fail[Q] = nQ;
 30     memcpy(tr[nQ], tr[Q], sizeof(tr[Q]));
 31     while(tr[p][f] == Q) {
 32         tr[p][f] = nQ;
 33         p = fail[p];
 34     }
 35     return nQ;
 36 }
 37 
 38 inline int insert(int p, char c) {
 39     int f = c - 'a';
 40     if(tr[p][f]) {
 41         int Q = tr[p][f];
 42         if(len[Q] == len[p] + 1) {
 43             return Q;
 44         }
 45         return split(p, f);
 46     }
 47     int np = ++tot;
 48     len[np] = len[p] + 1;
 49     while(p && !tr[p][f]) {
 50         tr[p][f] = np;
 51         p = fail[p];
 52     }
 53     if(!p) {
 54         fail[np] = 1;
 55     }
 56     else {
 57         int Q = tr[p][f];
 58         if(len[Q] == len[p] + 1) {
 59             fail[np] = Q;
 60         }
 61         else {
 62             fail[np] = split(p, f);
 63         }
 64     }
 65     return np;
 66 }
 67 
 68 void DFS(int x) {
 69     if(vis[x] == Time || !x) {
 70         return;
 71     }
 72     sum[x]++;
 73     vis[x] = Time;
 74     DFS(fail[x]);
 75     return;
 76 }
 77 
 78 void DFS_1(int x) {
 79     cnt[x] = cnt[fail[x]] + (len[x] - len[fail[x]]) * (sum[x] >= k);
 80     //printf("cnt %d = %d + %d * %d \n", x, cnt[fail[x]], len[x] - len[fail[x]], sum[x] >= k);
 81     for(int i = e[x]; i; i = edge[i].nex) {
 82         int y = edge[i].v;
 83         DFS_1(y);
 84     }
 85     return;
 86 }
 87 
 88 int main() {
 89     int n;
 90     scanf("%d%d", &n, &k);
 91     for(int i = 1; i <= n; i++) {
 92         scanf("%s", s);
 93         int t = strlen(s);
 94         last = 1;
 95         for(int j = 0; j < t; j++) {
 96             last = insert(last, s[j]);
 97         }
 98         str[i] = (std::string)(s);
 99     }
100 
101     for(int i = 1; i <= n; i++) {
102         int p = 1, t = str[i].size();
103         Time = i;
104         for(int j = 0; j < t; j++) {
105             int f = str[i][j] - 'a';
106             p = tr[p][f];
107             DFS(p);
108         }
109     }
110 
111     for(int i = 2; i <= tot; i++) {
112         add(fail[i], i);
113         //printf("add %d %d len = %d %d \n", fail[i], i, len[fail[i]], len[i]);
114     }
115     DFS_1(1);
116 
117     for(int i = 1; i <= n; i++) {
118         int p = 1, t = str[i].size(), ans = 0;
119         //printf("i = %d \n", i);
120         for(int j = 0; j < t; j++) {
121             int f = str[i][j] - 'a';
122             p = tr[p][f];
123             ans += cnt[p];
124             //printf(" > p = %d ans += %d \n", p, cnt[p]);
125         }
126         printf("%d ", ans);
127     }
128 
129     return 0;
130 }
以上爲AC代碼。