題意:給定只有ab的字符串,求其中不連續非空迴文子序列的個數。ide
解:用全部迴文子序列減去迴文子串。spa
容易想到枚舉中線。code
設f[i]表示以i爲中線,迴文字符的個數。那麼迴文子序列就是∑(2f[i] - 1)blog
怎麼求f[i]呢?卷積!字符串
咱們考慮把i變成中線 * 2,那麼f[i] = ( ∑(s[j] == s[i - j]) + 1) / 2get
令a = 1,b = 0,那麼卷積得出的就是全部a的迴文字符數。同理能夠求得全部b的迴文字符數。string
而後用迴文自動機求一遍迴文子串的數目,相減便可。it
1 #include <cstdio> 2 #include <cstring> 3 #include <algorithm> 4 #include <cmath> 5 6 typedef long long LL; 7 const int N = 100010; 8 const LL MO = 1e9 + 7; 9 const double pi = 3.1415926535897932384626; 10 11 struct cp { 12 double x, y; 13 cp(double X = 0, double Y = 0) { 14 x = X; 15 y = Y; 16 } 17 inline cp operator +(const cp &w) const { 18 return cp(x + w.x, y + w.y); 19 } 20 inline cp operator -(const cp &w) const { 21 return cp(x - w.x, y - w.y); 22 } 23 inline cp operator *(const cp &w) const { 24 return cp(x * w.x - y * w.y, x * w.y + y * w.x); 25 } 26 }a[N << 2]; 27 28 int r[N << 2], f[N << 1]; 29 char s[N]; 30 31 inline LL qpow(LL a, int b) { 32 LL ans = 1; 33 while(b) { 34 if(b & 1) { 35 ans = ans * a % MO; 36 } 37 a = a * a % MO; 38 b = b >> 1; 39 } 40 return ans; 41 } 42 43 inline void FFT(int n, cp *a, int f) { 44 for(int i = 0; i < n; i++) { 45 if(i < r[i]) { 46 std::swap(a[i], a[r[i]]); 47 } 48 } 49 50 for(int len = 1; len < n; len <<= 1) { 51 cp Wn(cos(pi / len), f * sin(pi / len)); 52 for(int i = 0; i < n; i += (len << 1)) { 53 cp w(1, 0); 54 for(int j = 0; j < len; j++) { 55 cp t = a[i + len + j] * w; 56 a[i + len + j] = a[i + j] - t; 57 a[i + j] = a[i + j] + t; 58 w = w * Wn; 59 } 60 } 61 } 62 63 if(f == -1) { 64 for(int i = 0; i <= n; i++) { 65 a[i].x /= n; 66 } 67 } 68 return; 69 } 70 71 namespace pam { 72 int tr[N][26], fail[N], cnt[N], len[N], last, tot; 73 inline void init() { 74 len[1] = -1; 75 fail[0] = fail[1] = 1; 76 tot = last = 1; 77 return; 78 } 79 inline int getfail(int d, int x) { 80 while(s[d - len[x] - 1] != s[d]) { 81 x = fail[x]; 82 } 83 return x; 84 } 85 inline void insert(int d) { 86 int f = s[d] - 'a'; 87 int p = getfail(d, last); 88 if(!tr[p][f]) { 89 ++tot; 90 len[tot] = len[p] + 2; 91 fail[tot] = tr[getfail(d, fail[p])][f]; 92 tr[p][f] = tot; 93 } 94 last = tr[p][f]; 95 cnt[last]++; 96 return; 97 } 98 inline LL count() { 99 LL ans = 0; 100 for(int i = tot; i >= 2; i--) { 101 ans = (ans + cnt[i]) % MO; 102 (cnt[fail[i]] += cnt[i]) %= MO; 103 } 104 return ans; 105 } 106 } 107 108 int main() { 109 scanf("%s", s); 110 int n = strlen(s) - 1; 111 for(int i = 0; i <= n; i++) { 112 a[i].x = (s[i] == 'a'); 113 } 114 int len = 2, lm = 1; 115 while(len <= n + n) { 116 len <<= 1; 117 lm++; 118 } 119 for(int i = 1; i <= len; i++) { 120 r[i] = (r[i >> 1] >> 1) | ((i & 1) << (lm - 1)); 121 } 122 123 FFT(len, a, 1); 124 for(int i = 0; i <= len; i++) { 125 a[i] = a[i] * a[i]; 126 } 127 FFT(len, a, -1); 128 for(int i = 0; i <= n + n; i++) { 129 f[i] = ((int)(a[i].x + 0.5) + 1) >> 1; 130 } 131 132 for(int i = 0; i <= n; i++) { 133 a[i].x = (s[i] == 'b'); 134 a[i].y = 0; 135 } 136 for(int i = n + 1; i <= len; i++) { 137 a[i] = cp(0, 0); 138 } 139 FFT(len, a, 1); 140 for(int i = 0; i <= len; i++) { 141 a[i] = a[i] * a[i]; 142 } 143 FFT(len, a, -1); 144 for(int i = 0; i <= n + n; i++) { 145 f[i] += ((int)(a[i].x + 0.5) + 1) >> 1; 146 } 147 148 LL ans = 0; 149 for(int i = 0; i <= n + n; i++) { 150 ans = (ans + qpow(2ll, f[i]) - 1) % MO; 151 } 152 pam::init(); 153 for(int i = 0; i <= n; i++) { 154 pam::insert(i); 155 } 156 ans = (ans - pam::count() + MO) % MO; 157 printf("%lld", ans); 158 return 0; 159 }