BZOJ 4818ide
感受不難。ui
首先轉化一下題目,「至少有一個質數」$=$「所有方案」$ - $「一個質數也沒有」。this
注意到$m \leq 2e7$,$[1, m]$內的質數能夠直接篩出來。spa
設$f_{i, j}$表示當前長度序列爲$i$,當前和模$p$的值是$j$的方案數,直接無腦枚舉$m$轉移複雜度是$O(nmp)$的,可是發現每一次轉移形式都是相同的。code
$$f_{i, x} = \sum f_{i - 1, y}(y + z \equiv x(\mod p))$$blog
其實在模$p$的意義下大於等於$p$的數能夠直接歸類到這個數模$p$這一檔裏面,也就是說,咱們能夠記一個$cnt_x$表示模$p$意義下相同的數有$x$個。string
$$f_{i, (x + y) \mod p} = \sum f_{i - 1, x} \times cnt_y$$it
發現這個式子的形式很像矩陣快速冪的樣子,而後就把轉移寫成矩陣的形式快速冪一下就行了。io
轉移矩陣的第$(i, j)$個格子是$\sum_{(i + k) \equiv j(\mod p)}cnt_k$class
時間複雜度$O(m + p^3logn)$。
咕,感受時間剛恰好。
然而再次觀察一下這個式子發現是一個卷積的形式,所以能夠直接$NTT$,時間複雜度能夠降到$O(m + plogplogn)$,可是在這題中$p$過小了$ + $模數很差,直接暴力卷積的時間表現應該比$NTT$要優秀。
Code:
#include <cstdio> #include <cstring> using namespace std; typedef long long ll; const int N = 2e7 + 5; const int M = 100; const ll P = 20170408LL; int n, m, K, pCnt = 0, pri[N], cnt[M]; bool np[N]; template <typename T> inline void inc(T &x, T y) { x += y; if (x >= P) x -= P; } template <typename T> inline void sub(T &x, T y) { x -= y; if (x < 0) x += P; } struct Matrix { int tn, tm; ll s[M][M]; inline void init() { tn = tm = 0; memset(s, 0, sizeof(s)); } friend Matrix operator * (const Matrix x, const Matrix y) { Matrix res; res.init(); res.tn = x.tn, res.tm = y.tm; for (int k = 0; k < x.tm; k++) for (int i = 0; i < x.tn; i++) for (int j = 0; j < y.tm; j++) inc(res.s[i][j], x.s[i][k] * y.s[k][j] % P); return res; } inline Matrix fpow(int y) { Matrix x = *this, res; res.init(); res.tn = x.tn, res.tm = x.tm; for (int i = 0; i < x.tn; i++) res.s[i][i] = 1; for (; y ; y >>= 1) { if (y & 1) res = res * x; x = x * x; } return res; } inline void print() { for (int i = 0; i < tn; i++) for (int j = 0; j < tm; j++) printf("%lld%c", s[i][j], " \n"[j == tm - 1]); printf("\n"); } } trans, ans; inline void sieve() { np[1] = 1; for (int i = 2; i <= m; i++) { if (!np[i]) pri[++pCnt] = i; for (int j = 1; j <= pCnt && pri[j] * i <= m; j++) { np[i * pri[j]] = 1; if (i % pri[j] == 0) break; } } } inline ll solve1() { memset(cnt, 0, sizeof(cnt)); for (int i = 1; i <= m; i++) ++cnt[i % K]; trans.init(); trans.tn = trans.tm = K; for (int i = 0; i < K; i++) for (int j = 0; j < K; j++) inc(trans.s[i][(i + j) % K], 1LL * cnt[j]); // trans.print(); trans = trans.fpow(n); // trans.print(); ans.init(); ans.s[0][0] = 1; ans.tn = 1, ans.tm = K; ans = ans * trans; return ans.s[0][0]; } inline ll solve2() { sieve(); memset(cnt, 0, sizeof(cnt)); for (int i = 1; i <= m; i++) if (np[i]) ++cnt[i % K]; /* for (int i = 0; i < K; i++) printf("%d%c", cnt[i], " \n"[i == K - 1]); */ trans.init(); trans.tn = trans.tm = K; for (int i = 0; i < K; i++) for (int j = 0; j < K; j++) inc(trans.s[i][(i + j) % K], 1LL * cnt[j]); // trans.print(); trans = trans.fpow(n); // trans.print(); ans.init(); ans.s[0][0] = 1; ans.tn = 1, ans.tm = K; ans = ans * trans; return ans.s[0][0]; } int main() { scanf("%d%d%d", &n, &m, &K); // printf("%lld\n", solve1()); // printf("%lld\n", solve2()); printf("%lld\n", (solve1() - solve2() + P) % P); return 0; }