題意:求一個字符串中有多少形如AABB的子串。ide
解:嗯...我首先極度SB的想了一個後綴自動機套線段樹啓發式合併的作法,想必會TLE。spa
而後跑去看題解,發現實在是妙趣橫生...code
顯然要對每一個位置求出向左有多少個AA,向右有多少個BB。
blog
個人想法是對於每一個前綴,兩兩求lca,若是lca的len大於他們的位置之差,顯然就有一組了。字符串
這時候把貢獻加到其中較長的前綴上。而後反着來一遍就好了。string
怎麼批量求lca和貢獻呢?it
考慮計算每一個點做爲lca時的貢獻,顯然線段樹維護子樹內有哪些前綴。合併的時候好像沒啥好的辦法...可是咱們有啓發式合併!io
每次取出小的線段樹中的全部元素,依次加入大的線段樹中。對於大的線段樹中比它小的一段區間內的元素,咱們要給它本身加上貢獻。對於比它大的一段區間中的元素,要給那些大的元素每一個+1貢獻。咱們就在每次須要插入元素的時候往下推。推到底的時候加貢獻便可。(應該支持吧...)event
比較菜沒寫代碼...感受實現起來毒瘤的緊。ast
而後說正解。
考慮枚舉AA串的長度。
對於一個長爲2len的AA串,若是咱們每隔len放一個點,那麼這樣的串將會且僅會覆蓋兩個連續的點。
對於每兩個連續的點,咱們求它們的最長公共前/後綴長度,分別設爲x,y。
若是x + y >= len的話就是存在這樣的AA串通過這兩點。而後就是個線段樹區間+1
最後遍歷線段樹統計答案便可。
求lcp不就是SAM的fail樹上lca嘛,我會倍增!
Tnlog2n成功T飛...
而後就O(1)lca過了...果真O(1)lca仍是有用的。
1 #include <cstdio> 2 #include <cstring> 3 #include <algorithm> 4 5 typedef long long LL; 6 const int N = 80010; 7 8 char str[N]; 9 int pos[N], pos2[N], pw[N * 2], n; 10 11 struct SAM { 12 13 struct Edge { 14 int nex, v; 15 }edge[N]; int top; 16 17 int tr[N][26], fail[N], len[N], tot, last; 18 int ST[N * 2][20], pos[N * 2], num, e[N], d[N]; 19 20 SAM() { 21 tot = last = 1; 22 } 23 24 inline void add(int x, int y) { 25 top++; 26 edge[top].v = y; 27 edge[top].nex = e[x]; 28 e[x] = top; 29 return; 30 } 31 32 inline void insert(char c) { 33 int f = c - 'a'; 34 int p = last, np = ++tot; 35 last = np; 36 len[np] = len[p] + 1; 37 while(p && !tr[p][f]) { 38 tr[p][f] = np; 39 p = fail[p]; 40 } 41 if(!p) { 42 fail[np] = 1; 43 } 44 else { 45 int Q = tr[p][f]; 46 if(len[Q] == len[p] + 1) { 47 fail[np] = Q; 48 } 49 else { 50 int nQ = ++tot; 51 len[nQ] = len[p] + 1; 52 fail[nQ] = fail[Q]; 53 fail[Q] = fail[np] = nQ; 54 memcpy(tr[nQ], tr[Q], sizeof(tr[Q])); 55 while(tr[p][f] == Q) { 56 tr[p][f] = nQ; 57 p = fail[p]; 58 } 59 } 60 } 61 } 62 63 void DFS(int x) { 64 pos[x] = ++num; 65 ST[num][0] = x; 66 for(int i = e[x]; i; i = edge[i].nex) { 67 int y = edge[i].v; 68 d[y] = d[x] + 1; 69 DFS(y); 70 ST[++num][0] = x; 71 } 72 return; 73 } 74 75 inline void prework() { 76 for(int i = 2; i <= tot; i++) { 77 add(fail[i], i); 78 } 79 d[1] = 1; 80 DFS(1); 81 for(int j = 1; j <= pw[num]; j++) { 82 for(int i = 1; i + (1 << j) - 1 <= num; i++) { 83 if(d[ST[i][j - 1]] <= d[ST[i + (1 << (j - 1))][j - 1]]) { 84 ST[i][j] = ST[i][j - 1]; 85 } 86 else { 87 ST[i][j] = ST[i + (1 << (j - 1))][j - 1]; 88 } 89 } 90 } 91 return; 92 } 93 94 inline int lca(int x, int y) { 95 x = pos[x]; 96 y = pos[y]; 97 if(x > y) { 98 std::swap(x, y); 99 } 100 int t = pw[y - x + 1]; 101 if(d[ST[x][t]] <= d[ST[y - (1 << t) + 1][t]]) { 102 return ST[x][t]; 103 } 104 return ST[y - (1 << t) + 1][t]; 105 } 106 107 inline void clear() { 108 for(int i = 1; i <= tot; i++) { 109 d[i] = e[i] = 0; 110 for(int f = 0; f < 26; f++) { 111 tr[i][f] = 0; 112 } 113 } 114 tot = last = 1; 115 top = num = 0; 116 return; 117 } 118 119 inline int lcp(int x, int y) { 120 return std::min(std::min(len[x], len[y]), len[lca(x, y)]); 121 } 122 123 }sam, sam2; 124 125 struct SegmentTree { 126 int tag[N * 2]; 127 int f[N]; 128 inline void pushdown(int o) { 129 if(!tag[o]) { 130 return; 131 } 132 tag[o << 1] += tag[o]; 133 tag[o << 1 | 1] += tag[o]; 134 tag[o] = 0; 135 return; 136 } 137 138 void add(int L, int R, int l, int r, int o) { 139 if(L <= l && r <= R) { 140 tag[o]++; 141 return; 142 } 143 int mid = (l + r) >> 1; 144 pushdown(o); 145 if(L <= mid) { 146 add(L, R, l, mid, o << 1); 147 } 148 if(mid < R) { 149 add(L, R, mid + 1, r, o << 1 | 1); 150 } 151 return; 152 } 153 154 void solve(int l, int r, int o) { 155 if(l == r) { 156 f[r] = tag[o]; 157 return; 158 } 159 pushdown(o); 160 int mid = (l + r) >> 1; 161 solve(l, mid, o << 1); 162 solve(mid + 1, r, o << 1 | 1); 163 return; 164 } 165 void clear(int l, int r, int o) { 166 tag[o] = 0; 167 if(l == r) { 168 return; 169 } 170 int mid = (l + r) >> 1; 171 clear(l, mid, o << 1); 172 clear(mid + 1, r, o << 1 | 1); 173 return; 174 } 175 }seg, seg2; 176 177 inline void solve() { 178 scanf("%s", str); 179 LL ans = 0; 180 n = strlen(str); 181 for(int i = 0; i < n; i++) { 182 sam.insert(str[i]); 183 sam2.insert(str[n - i - 1]); 184 pos[i] = sam.last; 185 pos2[n - i - 1] = sam2.last; 186 } 187 sam.prework(); 188 sam2.prework(); 189 // 190 for(int len = 1; (len << 1) < n - 1; len++) { 191 //printf("len = %d \n", len); 192 for(int i = len; i < n; i += len) { 193 // i i-len 194 //printf(" > %d %d \n", i - len, i); 195 int x = std::min(len, sam.lcp(pos[i], pos[i - len])); 196 int y = std::min(len, sam2.lcp(pos2[i], pos2[i - len])); 197 // x + y - len 198 //printf(" > x = %d y = %d \n", x, y); 199 if(x + y > len) { 200 seg.add(i - len - x + 2, i - len * 2 + y + 1, 1, n, 1); 201 //printf(" > > > 1 add %d %d \n", i - len - x + 2, i - len * 2 + y + 1); 202 seg2.add(i + len - x + 1, i + y, 1, n, 1); 203 //printf(" > > > 2 add %d %d \n", i + len - x + 1, i + y); 204 } 205 } 206 } 207 seg.solve(1, n, 1); 208 seg2.solve(1, n, 1); 209 for(int i = 2; i < n - 1; i++) { 210 ans += 1ll * seg2.f[i] * seg.f[i + 1]; 211 //printf("ans += %d * %d \n", seg2.f[i], seg.f[i + 1]); 212 } 213 printf("%lld\n", ans); 214 return; 215 } 216 217 int main() { 218 219 for(int i = 2; i < N * 2; i++) { 220 pw[i] = pw[i >> 1] + 1; 221 } 222 int T; 223 scanf("%d", &T); 224 while(T--) { 225 solve(); 226 if(T) { 227 sam.clear(); 228 sam2.clear(); 229 seg.clear(1, n, 1); 230 seg2.clear(1, n, 1); 231 } 232 } 233 return 0; 234 }