洛谷P1117 優秀的拆分

題意:求一個字符串中有多少形如AABB的子串。ide

解:嗯...我首先極度SB的想了一個後綴自動機套線段樹啓發式合併的作法,想必會TLE。spa

而後跑去看題解,發現實在是妙趣橫生...code

顯然要對每一個位置求出向左有多少個AA,向右有多少個BB。
blog

個人想法是對於每一個前綴,兩兩求lca,若是lca的len大於他們的位置之差,顯然就有一組了。字符串

這時候把貢獻加到其中較長的前綴上。而後反着來一遍就好了。string

怎麼批量求lca和貢獻呢?it

考慮計算每一個點做爲lca時的貢獻,顯然線段樹維護子樹內有哪些前綴。合併的時候好像沒啥好的辦法...可是咱們有啓發式合併!io

每次取出小的線段樹中的全部元素,依次加入大的線段樹中。對於大的線段樹中比它小的一段區間內的元素,咱們要給它本身加上貢獻。對於比它大的一段區間中的元素,要給那些大的元素每一個+1貢獻。咱們就在每次須要插入元素的時候往下推。推到底的時候加貢獻便可。(應該支持吧...)event

比較菜沒寫代碼...感受實現起來毒瘤的緊。ast

而後說正解。

考慮枚舉AA串的長度。

對於一個長爲2len的AA串,若是咱們每隔len放一個點,那麼這樣的串將會且僅會覆蓋兩個連續的點。

對於每兩個連續的點,咱們求它們的最長公共前/後綴長度,分別設爲x,y。

若是x + y >= len的話就是存在這樣的AA串通過這兩點。而後就是個線段樹區間+1

最後遍歷線段樹統計答案便可。

求lcp不就是SAM的fail樹上lca嘛,我會倍增!

Tnlog2n成功T飛...

而後就O(1)lca過了...果真O(1)lca仍是有用的。

  1 #include <cstdio>
  2 #include <cstring>
  3 #include <algorithm>
  4 
  5 typedef long long LL;
  6 const int N = 80010;
  7 
  8 char str[N];
  9 int pos[N], pos2[N], pw[N * 2], n;
 10 
 11 struct SAM {
 12 
 13     struct Edge {
 14         int nex, v;
 15     }edge[N]; int top;
 16     
 17     int tr[N][26], fail[N], len[N], tot, last;
 18     int ST[N * 2][20], pos[N * 2], num, e[N], d[N];
 19     
 20     SAM() {
 21         tot = last = 1;
 22     }
 23 
 24     inline void add(int x, int y) {
 25         top++;
 26         edge[top].v = y;
 27         edge[top].nex = e[x];
 28         e[x] = top;
 29         return;
 30     }
 31 
 32     inline void insert(char c) {
 33         int f = c - 'a';
 34         int p = last, np = ++tot;
 35         last = np;
 36         len[np] = len[p] + 1;
 37         while(p && !tr[p][f]) {
 38             tr[p][f] = np;
 39             p = fail[p];
 40         }
 41         if(!p) {
 42             fail[np] = 1;
 43         }
 44         else {
 45             int Q = tr[p][f];
 46             if(len[Q] == len[p] + 1) {
 47                 fail[np] = Q;
 48             }
 49             else {
 50                 int nQ = ++tot;
 51                 len[nQ] = len[p] + 1;
 52                 fail[nQ] = fail[Q];
 53                 fail[Q] = fail[np] = nQ;
 54                 memcpy(tr[nQ], tr[Q], sizeof(tr[Q]));
 55                 while(tr[p][f] == Q) {
 56                     tr[p][f] = nQ;
 57                     p = fail[p];
 58                 }
 59             }
 60         }
 61     }
 62     
 63     void DFS(int x) {
 64         pos[x] = ++num;
 65         ST[num][0] = x;
 66         for(int i = e[x]; i; i = edge[i].nex) {
 67             int y = edge[i].v;
 68             d[y] = d[x] + 1;
 69             DFS(y);
 70             ST[++num][0] = x;
 71         }
 72         return;
 73     }
 74 
 75     inline void prework() {
 76         for(int i = 2; i <= tot; i++) {
 77             add(fail[i], i);
 78         }
 79         d[1] = 1;
 80         DFS(1);
 81         for(int j = 1; j <= pw[num]; j++) {
 82             for(int i = 1; i + (1 << j) - 1 <= num; i++) {
 83                 if(d[ST[i][j - 1]] <= d[ST[i + (1 << (j - 1))][j - 1]]) {
 84                     ST[i][j] = ST[i][j - 1];
 85                 }
 86                 else {
 87                     ST[i][j] = ST[i + (1 << (j  - 1))][j - 1];
 88                 }
 89             }
 90         }
 91         return;
 92     }
 93 
 94     inline int lca(int x, int y) {
 95         x = pos[x];
 96         y = pos[y];
 97         if(x > y) {
 98             std::swap(x, y);
 99         }
100         int t = pw[y - x + 1];
101         if(d[ST[x][t]] <= d[ST[y - (1 << t) + 1][t]]) {
102             return ST[x][t];
103         }
104         return ST[y - (1 << t) + 1][t];
105     }
106 
107     inline void clear() {
108         for(int i = 1; i <= tot; i++) {
109             d[i] = e[i] = 0;
110             for(int f = 0; f < 26; f++) {
111                 tr[i][f] = 0;
112             }
113         }
114         tot = last = 1;
115         top = num = 0;
116         return;
117     }
118 
119     inline int lcp(int x, int y) {
120         return std::min(std::min(len[x], len[y]), len[lca(x, y)]);
121     }
122 
123 }sam, sam2;
124 
125 struct SegmentTree {
126     int tag[N * 2];
127     int f[N];
128     inline void pushdown(int o) {
129         if(!tag[o]) {
130             return;
131         }
132         tag[o << 1] += tag[o];
133         tag[o << 1 | 1] += tag[o];
134         tag[o] = 0;
135         return;
136     }
137 
138     void add(int L, int R, int l, int r, int o) {
139         if(L <= l && r <= R) {
140             tag[o]++;
141             return;
142         }
143         int mid = (l + r) >> 1;
144         pushdown(o);
145         if(L <= mid) {
146             add(L, R, l, mid, o << 1);
147         }
148         if(mid < R) {
149             add(L, R, mid + 1, r, o << 1 | 1);
150         }
151         return;
152     }
153 
154     void solve(int l, int r, int o) {
155         if(l == r) {
156             f[r] = tag[o];
157             return;
158         }
159         pushdown(o);
160         int mid = (l + r) >> 1;
161         solve(l, mid, o << 1);
162         solve(mid + 1, r, o << 1 | 1);
163         return;
164     }
165     void clear(int l, int r, int o) {
166         tag[o] = 0;
167         if(l == r) {
168             return;
169         }
170         int mid = (l + r) >> 1;
171         clear(l, mid, o << 1);
172         clear(mid + 1, r, o << 1 | 1);
173         return;
174     }
175 }seg, seg2;
176 
177 inline void solve() {
178     scanf("%s", str);
179     LL ans = 0;
180     n = strlen(str);
181     for(int i = 0; i < n; i++) {
182         sam.insert(str[i]);
183         sam2.insert(str[n - i - 1]);
184         pos[i] = sam.last;
185         pos2[n - i - 1] = sam2.last;
186     }
187     sam.prework();
188     sam2.prework();
189     //
190     for(int len = 1; (len << 1) < n - 1; len++) {
191         //printf("len = %d \n", len);
192         for(int i = len; i < n; i += len) {
193             // i i-len
194             //printf(" > %d %d \n", i - len, i);
195             int x = std::min(len, sam.lcp(pos[i], pos[i - len]));
196             int y = std::min(len, sam2.lcp(pos2[i], pos2[i - len]));
197             // x + y - len
198             //printf(" > x = %d y = %d \n", x, y);
199             if(x + y > len) {
200                 seg.add(i - len - x + 2, i - len * 2 + y + 1, 1, n, 1);
201                 //printf(" > > > 1 add %d %d \n", i - len - x + 2, i - len * 2 + y + 1);
202                 seg2.add(i + len - x + 1, i + y, 1, n, 1);
203                 //printf(" > > > 2 add %d %d \n", i + len - x + 1, i + y);
204             }
205         }
206     }
207     seg.solve(1, n, 1);
208     seg2.solve(1, n, 1);
209     for(int i = 2; i < n - 1; i++) {
210         ans += 1ll * seg2.f[i] * seg.f[i + 1];
211         //printf("ans += %d * %d \n", seg2.f[i], seg.f[i + 1]);
212     }
213     printf("%lld\n", ans);
214     return;
215 }
216 
217 int main() {
218 
219     for(int i = 2; i < N * 2; i++) {
220         pw[i] = pw[i >> 1] + 1;
221     }
222     int T;
223     scanf("%d", &T);
224     while(T--) {
225         solve();
226         if(T) {
227             sam.clear();
228             sam2.clear();
229             seg.clear(1, n, 1);
230             seg2.clear(1, n, 1);
231         }
232     }
233     return 0;
234 }
AC代碼
相關文章
相關標籤/搜索