洛谷P3321 序列統計

氣死了,FFT了半天發現是NTT... 1004535809 這個東西是NTT模數,原根爲3。ide

題意:給定集合,元素的大小不超過M。用這些元素組成長爲n的序列,要求乘積模M爲k,求方案數。spa

n <= 1e9,M是質數code

解:有一個10分的暴力DP...blog

正解首先考慮加起來模M爲k。get

咱們能夠考慮構造多項式f(x),第i項表示i是否在集合內。string

那麼答案就是(f(x))n的第k項係數。多項式快速冪。it

這個東西的長度可能會很長,那麼咱們每次乘完以後就把多於M項的係數所有模到小於M。io

而後這道題求的是乘積不是和....原根!event

由於質數都有原根,而後兩個數相乘等於原根指數相加。作完了。class

注意原根的指數模的是φ(M)而不是M。還有千萬不要用FFT,由於會爆double

  1 #include <cstdio>
  2 #include <algorithm>
  3 #include <cstring>
  4 #include <cmath>
  5 
  6 typedef long long LL;
  7 const int N = 8010;
  8 const LL MO = 1004535809, g = 3;
  9 
 10 int r[N << 2], b[N], vis[N];
 11 LL f[N << 2], ans[N << 2], a[N << 2], c[N << 2];
 12 
 13 inline LL qpow(LL a, LL b) {
 14     LL ans = 1;
 15     while(b) {
 16         if(b & 1) {
 17             ans = ans * a % MO;
 18         }
 19         a = a * a % MO;
 20         b = b >> 1;
 21     }
 22     return ans;
 23 }
 24 
 25 inline void NTT(int n, LL *a, int f) {
 26     for(int i = 0; i < n; i++) {
 27         if(i < r[i]) {
 28             std::swap(a[i], a[r[i]]);
 29         }
 30     }
 31     for(int len = 1; len < n; len <<= 1) {
 32         LL Wn = qpow(g, (MO - 1) / (len << 1));
 33         if(f == -1) {
 34             Wn = qpow(Wn, MO - 2);
 35         }
 36         for(int i = 0; i < n; i += (len << 1)) {
 37             LL w = 1;
 38             for(int j = 0; j < len; j++) {
 39                 LL t = a[i + len + j] * w % MO;
 40                 a[i + len + j] = (a[i + j] - t + MO) % MO;
 41                 a[i + j] = (a[i + j] + t) % MO;
 42                 w = w * Wn % MO;
 43             }
 44         }
 45     }
 46     if(f == -1) {
 47         LL inv = qpow(n, MO - 2);
 48         for(int i = 0; i <= n; i++) {
 49             a[i] = a[i] * inv % MO;
 50         }
 51     }
 52     return;
 53 }
 54 
 55 inline int getG(int x) {
 56     for(int i = 1; i < x; i++) {
 57         memset(vis, 0, x * sizeof(int));
 58         int now = 1, fd = 0;
 59         for(int t = 1; t < x; t++) {
 60             now = now * i % x;
 61             if(vis[now]) {
 62                 fd = 1;
 63                 break;
 64             }
 65             vis[now] = t;
 66         }
 67         if(!fd) {
 68             return i;
 69         }
 70     }
 71     return -1;
 72 }
 73 
 74 int main() {
 75     int n, m, mod, k;
 76     scanf("%d%d%d%d", &n, &mod, &k, &m);
 77     for(int i = 1; i <= m; i++) {
 78         scanf("%d", &b[i]);
 79     }
 80     getG(mod);
 81     for(int i = 1; i <= m; i++) {
 82         b[i] = vis[b[i]];
 83         if(b[i]) {
 84             f[b[i]] = 1;
 85         }
 86     }
 87     k = vis[k];
 88 
 89     int len = 2, lm = 1;
 90     while(len <= ((mod - 1) << 1)) {
 91         len <<= 1;
 92         lm++;
 93     }
 94     for(int i = 1; i <= len; i++) {
 95         r[i] = (r[i >> 1] >> 1) | ((i & 1) << (lm - 1));
 96     }
 97 
 98     /*printf("g = %d \n", g);
 99     for(int i = 1; i <= m; i++) {
100         printf("%d ", b[i]);
101     }
102     puts("");
103     for(int i = 0; i < mod; i++) {
104         printf("%d ", f[i]);
105     }
106     puts("\n");*/
107 
108     ans[0] = 1;
109     while(n) {
110         if(n & 1) {
111             for(int i = 0; i <= len; i++) {
112                 a[i] = f[i];
113                 c[i] = ans[i];
114             }
115             NTT(len, a, 1);
116             NTT(len, c, 1);
117             for(int i = 0; i <= len; i++) {
118                 c[i] = a[i] * c[i] % MO;
119             }
120             NTT(len, c, -1);
121             for(int i = 0; i <= len; i++) {
122                 ans[i] = c[i];
123             }
124             for(int i = len; i >= mod; i--) {
125                 (ans[i - mod + 1] += ans[i]) %= MO;
126                 ans[i] = 0;
127             }
128         }
129         for(int i = 0; i <= len; i++) {
130             a[i] = f[i];
131         }
132         NTT(len, a, 1);
133         for(int i = 0; i <= len; i++) {
134             a[i] = a[i] * a[i] % MO;
135         }
136         NTT(len, a, -1);
137         for(int i = 0; i <= len; i++) {
138             f[i] = a[i];
139         }
140         for(int i = len; i >= mod; i--) {
141             (f[i - mod + 1] += f[i]) %= MO;
142             f[i] = 0;
143         }
144         n >>= 1;
145         /*printf("f   : ");
146         for(int i = 0; i < mod; i++) {
147             printf("%d ", f[i]);
148         }
149         printf("\nans : ");
150         for(int i = 0; i < mod; i++) {
151             printf("%d ", ans[i]);
152         }
153         puts("");*/
154     }
155 
156     printf("%lld", ans[k]);
157     return 0;
158 }
AC代碼
相關文章
相關標籤/搜索