嘟嘟嘟
我認爲這題是黑題的緣由是質數不是\(998244353\),因此得用三模NTT或是拆係數FFT。我抄了一個拆係數FFT的板子,但如今暫時仍是不很懂。
但這不影響解題思路。
首先\(n> K\)無解。(徹底搞不懂\(n\)那麼大幹啥)
咱們令\(dp[i][j]\)表示前\(i\)個數的或和中有\(j\)個\(1\)時的方案數。咱們先不考慮\(1\)的位置,這樣答案就是\(\sum _ {i = 1} ^ {K} dp[n][i] * \binom{K}{i}\)。轉移也很顯然,由於是\(or\)操做,因此原來是\(1\)的位置填不填\(1\)都行,而且每次至少要新增一個\(1\),即\(dp[i][j] = \sum _ {k = 0} ^ {j - 1}dp[i - 1][k] * \binom{j}{k} * 2 ^ k\)。把組合數拆開後能夠用FFT加速。那麼如今能夠在\(O(k ^ 2logk)\)的時間內解決了。
如今的瓶頸在於只能一位一位的dp,得循環\(n\)次。咱們想辦法優化:考慮兩個dp方程的合併,就能得出
\[ dp[x + y][i] = \sum _{j = 0} ^ {i}dp[x][j] * dp[y][i - j] *\binom{i}{j} *2 ^{y * j} \]
而後把它整理一下:
\[ \frac{dp[x + y][i]}{i!} = \sum _ {j = 0} ^ {i} \frac{dp[x][j] * (2 ^ y) ^ j}{j!} * \frac{dp[y][i - j]}{(i - j)!} \]
因而愉快的多項式快速冪就能夠啦!
代碼裏的過程量都是\(\frac{dp[i][j]}{j!}\),到最後再乘上\(j!\)便可。
// Counts, modulo 1e9+7, the value  sum_{i=1..m} dp[n][i] * C(m, i)  where
// dp[x][i] is the number of ways for x numbers to reach a prefix OR with i
// set bits (bit positions not yet chosen).  Two partial results merge as
//   dp[x+y][i]/i! = sum_j (dp[x][j] * (2^y)^j / j!) * (dp[y][i-j] / (i-j)!)
// so the program raises the "one number" polynomial to the n-th power by
// binary exponentiation.  Because the modulus is not NTT-friendly, every
// product is done with a coefficient-splitting FFT in long double.
// Input (stdin): n m.  Output (stdout): the answer followed by a newline.
// NOTE(review): assumes 1 <= n and m < maxn; neither is checked here.
#include<cstdio>
#include<iostream>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<cstdlib>
#include<cctype>
#include<vector>
#include<queue>
#include<assert.h>
#include<ctime>
using namespace std;
#define enter puts("")
#define space putchar(' ')
#define Mem(a, x) memset(a, x, sizeof(a))
#define In inline
#define forE(i, x, y) for(int i = head[x], y; (y = e[i].to) && ~i; i = e[i].nxt)
typedef long long ll;
typedef long double db;
const int INF = 0x3f3f3f3f;
const db eps = 1e-8;
const int maxn = 3e5 + 5;
const ll mod = 1e9 + 7;
// acos(-1) with an integer argument picks the double overload and would
// leave PI with only double precision; force the long double overload.
const db PI = std::acos(static_cast<db>(-1));

// Reads one (optionally negative) decimal integer from stdin.
// Returns 0 if the stream ends before any digit is seen.
In ll read()
{
    ll ans = 0;
    // int (not char) so EOF is distinguishable and isdigit() never sees a
    // negative non-EOF value.
    int ch = getchar(), las = ' ';
    while(ch != EOF && !isdigit(ch)) las = ch, ch = getchar();
    while(isdigit(ch)) ans = (ans << 1) + (ans << 3) + ch - '0', ch = getchar();
    if(las == '-') ans = -ans;
    return ans;
}

// Writes x in decimal to stdout (no trailing separator).
In void write(ll x)
{
    if(x < 0) x = -x, putchar('-');
    if(x >= 10) write(x / 10);
    putchar(x % 10 + '0');
}

// Redirects stdin/stdout to local files unless `mrclr` is defined.
In void MYFILE()
{
#ifndef mrclr
    freopen("ha.in", "r", stdin);
    freopen("hia.out", "w", stdout);
#endif
}

ll n;
int m;
// fac[i] = i!, inv[i] = (i!)^-1, p2[i] = 2^i, all modulo mod.
ll fac[maxn], inv[maxn], p2[maxn];

// Modular addition of two residues already in [0, mod).
In ll inc(ll a, ll b) {return a + b < mod ? a + b : a + b - mod;}

// Binomial coefficient C(n, m) modulo mod; 0 when m > n.
In ll C(int n, int m)
{
    if(m > n) return 0;
    return fac[n] * inv[m] % mod * inv[n - m] % mod;
}

// a^b modulo mod by binary exponentiation.
In ll quickpow(ll a, ll b)
{
    ll ret = 1;
    for(; b; b >>= 1, a = a * a % mod)
        if(b & 1) ret = ret * a % mod;
    return ret;
}

// Fills the factorial, inverse-factorial and power-of-two tables.
In void init()
{
    fac[0] = inv[0] = p2[0] = 1;
    for(int i = 1; i < maxn; ++i) fac[i] = fac[i - 1] * i % mod;
    // One modular inverse (Fermat), then walk downwards.
    inv[maxn - 1] = quickpow(fac[maxn - 1], mod - 2);
    for(int i = maxn - 2; i; --i) inv[i] = inv[i + 1] * (i + 1) % mod;
    for(int i = 1; i < maxn; ++i) p2[i] = (p2[i - 1] * 2) % mod;
}

// len = FFT size (power of two), lim = log2(len), rev = bit-reversal table.
int len = 1, lim = 0, rev[maxn << 2];

// Minimal complex number over long double; unary ! is complex conjugation.
struct Comp
{
    db x, y;
    In Comp operator + (const Comp& oth)const
    {
        return Comp{x + oth.x, y + oth.y};
    }
    In Comp operator - (const Comp& oth)const
    {
        return Comp{x - oth.x, y - oth.y};
    }
    In Comp operator * (const Comp& oth)const
    {
        return Comp{x * oth.x - y * oth.y, x * oth.y + y * oth.x};
    }
    friend In void swap(Comp& A, Comp& B)
    {
        swap(A.x, B.x), swap(A.y, B.y);
    }
    friend In Comp operator ! (Comp a)
    {
        return Comp{a.x, -a.y};
    }
};
Comp A[maxn << 2], B[maxn << 2];
Comp dftA[maxn << 2], dftB[maxn << 2], dftC[maxn << 2], dftD[maxn << 2];

// In-place iterative radix-2 FFT of a[0..len).  flg = 1 is the forward
// transform, flg = -1 the inverse without the 1/len scaling (callers divide).
In void fft(Comp* a, int len, int flg)
{
    for(int i = 0; i < len; ++i)
        if(i < rev[i]) swap(a[i], a[rev[i]]);
    for(int i = 1; i < len; i <<= 1)
    {
        const Comp omg = Comp{std::cos(PI / i), std::sin(PI / i) * flg};
        for(int j = 0; j < len; j += (i << 1))
        {
            Comp o = Comp{1, 0};
            for(int k = 0; k < i; ++k, o = o * omg)
            {
                const Comp tp1 = a[j + k], tp2 = o * a[j + k + i];
                a[j + k] = tp1 + tp2, a[j + k + i] = tp1 - tp2;
            }
        }
    }
}

// Mask for the low 15 bits used when splitting coefficients.
const ll NUM = 32767;

// c[0..n] = first n+1 coefficients of a * b modulo mod.
// Each coefficient v is split as (v & NUM) + (v >> 15) * 2^15 and packed
// into one complex number, so two forward and two inverse FFTs yield the
// four partial products.  c may alias a or b: inputs are fully read first.
In void FFT(ll* a, ll* b, ll* c, int n)
{
    for(int i = 0; i < len; ++i)
    {
        const ll lhs = i <= n ? a[i] : 0;
        const ll rhs = i <= n ? b[i] : 0;
        // Explicit casts: integer -> long double inside braces is narrowing.
        A[i] = Comp{static_cast<db>(lhs & NUM), static_cast<db>(lhs >> 15)};
        B[i] = Comp{static_cast<db>(rhs & NUM), static_cast<db>(rhs >> 15)};
    }
    fft(A, len, 1), fft(B, len, 1);
    const Comp half = Comp{0.5, 0}, negHalfI = Comp{0, -0.5}, unitI = Comp{0, 1};
    for(int i = 0; i < len; ++i)
    {
        // Combining X[i] with conj(X[-i]) separates the transforms of the
        // real (low) and imaginary (high) parts of the packed sequence.
        const int j = (len - i) & (len - 1);
        const Comp aLo = (A[i] + (!A[j])) * half;
        const Comp aHi = (A[i] - (!A[j])) * negHalfI;
        const Comp bLo = (B[i] + (!B[j])) * half;
        const Comp bHi = (B[i] - (!B[j])) * negHalfI;
        dftA[i] = aLo * bLo, dftB[i] = aLo * bHi;
        dftC[i] = aHi * bLo, dftD[i] = aHi * bHi;
    }
    for(int i = 0; i < len; ++i)
    {
        // Pack two real-valued products into each inverse transform.
        A[i] = dftA[i] + dftB[i] * unitI;
        B[i] = dftC[i] + dftD[i] * unitI;
    }
    fft(A, len, -1), fft(B, len, -1);
    for(int i = 0; i <= n; ++i)
    {
        // Round to nearest integer after applying the 1/len scaling.
        const ll lowLow = static_cast<ll>(A[i].x / len + 0.5) % mod;
        const ll lowHigh = static_cast<ll>(A[i].y / len + 0.5) % mod;
        const ll highLow = static_cast<ll>(B[i].x / len + 0.5) % mod;
        const ll highHigh = static_cast<ll>(B[i].y / len + 0.5) % mod;
        c[i] = inc(inc(lowLow, ((lowHigh + highLow) << 15) % mod), (highHigh << 30) % mod);
    }
}

// dp = current power-of-two block, f = accumulated result, ta = scratch,
// LEN = number of sequence elements the current dp block stands for.
ll dp[maxn], f[maxn], ta[maxn], LEN = 1;

// a <- merge(a, b): scales a[i] by (2^LEN)^i, then convolves with b,
// keeping coefficients 0..m.  b is taken to cover LEN elements.
In void mul(ll* a, ll* b)
{
    ll tp = 1;
    for(int i = 0; i <= m; ++i)
    {
        ta[i] = a[i] * tp % mod;
        tp = tp * p2[LEN] % mod;
    }
    FFT(ta, b, ta, m);
    for(int i = 0; i <= m; ++i) a[i] = ta[i] % mod;
}

// Builds the n-element polynomial by binary exponentiation and returns
// sum_{i=1..m} f[i] * i! * C(m, i) modulo mod.
In ll QuickPow(ll n)
{
    // Single-element block: dp[1][j] / j! = 1 / j! for j >= 1 (dp[0] stays 0).
    for(int i = 1; i <= m; ++i) dp[i] = inv[i];
    f[0] = 1;
    LEN = 1;
    for(; n; n >>= 1)
    {
        if(n & 1) mul(f, dp);
        // Square only while higher bits remain; the square after the last
        // bit would never be used.
        if(n > 1) mul(dp, dp), LEN <<= 1;
    }
    ll ret = 0;
    for(int i = 1; i <= m; ++i)
        ret = inc(ret, f[i] * fac[i] % mod * C(m, i) % mod);
    return ret;
}

int main()
{
//  MYFILE();
    init();
    n = read(), m = read();
    // More elements than bits: per the write-up, no valid sequence exists.
    if(n > m) {puts("0"); return 0;}
    while(len <= (m << 1)) len <<= 1, ++lim;
    // lim == 0 only when len == 1; avoid shifting by a negative amount.
    for(int i = 0; i < len; ++i)
        rev[i] = (rev[i >> 1] >> 1) | (lim ? (i & 1) << (lim - 1) : 0);
    write(QuickPow(n)), enter;
    return 0;
}
這題昨晚調了半天仍是WA,回宿舍後,多項式大師 \(S \color{red}{jk}\)看了一眼後說道,你把double改爲long double,就過了。而後還真就過了,神啊!