題庫連結
求
\[f(n)=\sum_{i=0}^n\sum_{j=0}^i S(i,j)\times 2^j \times (j!)\]
其中 \(S(i, j)\) 表示第二類斯特林數，答案對 \(998244353\) 取模。
\(1\leq n\leq 100000\)
因爲當 \(i<j\) 時 \(S(i,j)=0\)，我們可以把 \(j\) 的上界擴大到 \(n\)，將式子改寫成
\[f(n)=\sum_{i=0}^n\sum_{j=0}^n S(i,j)\times 2^j \times (j!)\]
交換求和順序，得到
\[f(n)=\sum_{j=0}^n 2^j \times (j!)\times\sum_{i=0}^n S(i,j)\]
把 \(S(i, j)\) 的通項公式 \(S(i,j)=\sum_{k=0}^j\frac{(-1)^k}{k!}\frac{(j-k)^i}{(j-k)!}\)（約定 \(0^0=1\)）代入：
\[\begin{aligned}f(n)&=\sum_{j=0}^n 2^j \times (j!)\times\sum_{i=0}^n \sum_{k=0}^j\frac{(-1)^k}{k!}\frac{(j-k)^i}{(j-k)!}\\&=\sum_{j=0}^n 2^j \times (j!)\times\sum_{k=0}^j\frac{(-1)^k}{k!}\frac{\sum\limits_{i=0}^n(j-k)^i}{(j-k)!}\end{aligned}\]
記
\[\begin{aligned}A(x)&=\sum_{i=0}^\infty \frac{(-1)^i}{i!}x^i\\B(x)&=\sum_{i=0}^\infty\frac{\sum\limits_{k=0}^ni^k}{i!}x^i\end{aligned}\]
那麼
\[f(n)=\sum_{j=0}^n 2^j\times(j!)\times(A\otimes B)(j)\]
其中 \((A\otimes B)(j)\) 表示多項式乘積 \(A(x)B(x)\) 的 \(x^j\) 項係數。用 \(\text{NTT}\) 計算這一次卷積即可，時間複雜度爲 \(O(n\log n)\)。
#include <bits/stdc++.h>
using namespace std;

// Computes
//   f(n) = sum_{i=0}^{n} sum_{j=0}^{i} S(i,j) * 2^j * j!   (mod 998244353)
// where S(i,j) is a Stirling number of the second kind. Substituting the
// explicit formula for S(i,j) turns the inner sum into coefficients of A(x)B(x):
//   A(x) = sum_k (-1)^k / k! * x^k
//   B(x) = sum_t (sum_{i=0}^{n} t^i) / t! * x^t
// so f(n) = sum_j 2^j * j! * [x^j] A(x)B(x), needing a single NTT convolution.

const int yzh = 998244353;  // NTT modulus 119 * 2^23 + 1; 3 is used as primitive root

int n;          // problem input
int len, L;     // current transform length (a power of two) and log2(len)
vector<int> R;  // bit-reversal permutation table for `len`

// Returns a^b mod yzh, normalized to [0, yzh), for b >= 0. `a` may be negative.
int quick_pow(int a, int b) {
    long long base = (a % yzh + yzh) % yzh, ans = 1;
    for (; b > 0; b >>= 1, base = base * base % yzh)
        if (b & 1) ans = ans * base % yzh;
    return (int)ans;
}

// In-place iterative NTT of A[0..len), using the globals `len` (power of two)
// and `R` (its bit-reversal table). o == 1: forward transform; o == -1: inverse
// transform WITHOUT the final 1/len scaling (the caller applies it).
// Inputs may be any int; outputs are normalized to [0, yzh).
void NTT(int *A, int o) {
    for (int i = 0; i < len; i++)
        A[i] = (A[i] % yzh + yzh) % yzh;  // the butterflies below assume [0, yzh)
    for (int i = 0; i < len; i++)
        if (i < R[i]) swap(A[i], A[R[i]]);
    for (int half = 1; half < len; half <<= 1) {
        // Primitive (2*half)-th root of unity, inverted for the inverse transform.
        int gn = quick_pow(3, (yzh - 1) / (half << 1));
        if (o == -1) gn = quick_pow(gn, yzh - 2);
        for (int j = 0; j < len; j += half << 1) {
            long long g = 1;
            for (int k = 0; k < half; k++, g = g * gn % yzh) {
                int x = A[j + k];
                int y = (int)(g * A[j + k + half] % yzh);
                // x, y in [0, yzh): the sum/difference fit in int and need at
                // most one correction to stay in range.
                A[j + k] = x + y >= yzh ? x + y - yzh : x + y;
                A[j + k + half] = x - y < 0 ? x - y + yzh : x - y;
            }
        }
    }
}

// Returns f(m) mod yzh, for 0 <= m < yzh. Overwrites the globals len, L, R.
int solve(int m) {
    // ifac[i] = 1 / i! mod yzh, built from the linear-time inverse recurrence
    // 1/i = -(yzh / i) * 1/(yzh % i)  (mod yzh).
    vector<int> iv(m + 1, 1), ifac(m + 1, 1);
    for (int i = 2; i <= m; i++)
        iv[i] = (int)((long long)(yzh - yzh / i) * iv[yzh % i] % yzh);
    for (int i = 1; i <= m; i++)
        ifac[i] = (int)((long long)ifac[i - 1] * iv[i] % yzh);

    // A(x)B(x) has degree up to 2m. We need len > 2m so that the cyclic
    // convolution does not wrap into the coefficients 0..m that are read.
    len = 1, L = 0;
    while (len <= 2 * m) len <<= 1, ++L;
    R.assign(len, 0);
    // Starting at i = 1 avoids evaluating a shift by L-1 = -1 when len == 1.
    for (int i = 1; i < len; i++)
        R[i] = (R[i >> 1] >> 1) | ((i & 1) << (L - 1));

    vector<int> a(len, 0), b(len, 0);
    for (int i = 0; i <= m; i++)
        a[i] = (i & 1) ? yzh - ifac[i] : ifac[i];  // (-1)^i / i!
    b[0] = 1;                                      // sum_{k=0}^{m} 0^k = 1 (0^0 = 1)
    if (m >= 1) b[1] = (m + 1) % yzh;              // sum_{k=0}^{m} 1^k = m + 1
    for (int i = 2; i <= m; i++) {
        // Geometric series: sum_{k=0}^{m} i^k = (i^{m+1} - 1) / (i - 1).
        long long geo = (long long)(quick_pow(i, m + 1) - 1 + yzh) % yzh
                        * quick_pow(i - 1, yzh - 2) % yzh;
        b[i] = (int)(geo * ifac[i] % yzh);
    }

    NTT(a.data(), 1);
    NTT(b.data(), 1);
    for (int i = 0; i < len; i++)
        a[i] = (int)((long long)a[i] * b[i] % yzh);
    NTT(a.data(), -1);
    const long long inv_len = quick_pow(len, yzh - 2);

    long long ans = 0, weight = 1;  // weight == 2^j * j! at iteration j
    for (int j = 0; j <= m; j++) {
        long long coef = a[j] * inv_len % yzh;  // [x^j] A(x)B(x)
        ans = (ans + coef * weight) % yzh;
        weight = weight * 2 % yzh * (j + 1) % yzh;
    }
    return (int)ans;
}

// Reads n from stdin and prints f(n) mod 998244353.
void work() {
    if (scanf("%d", &n) != 1 || n < 0) {
        fprintf(stderr, "expected a non-negative integer n\n");
        return;
    }
    printf("%d\n", solve(n));
}

int main() { work(); return 0; }