[HEOI2016/TJOI2016]求和

Discriptionc++

在2016年,佳媛姐姐剛剛學習了第二類斯特林數,很是開心。函數

如今他想計算這樣一個函數的值:
S(i, j)表示第二類斯特林數,遞推公式爲:
S(i, j) = j ∗ S(i − 1, j) + S(i − 1, j − 1), 1 <= j <= i − 1。
邊界條件爲:S(i, i) = 1(0 <= i), S(i, 0) = 0(1 <= i)
你能幫幫他嗎?
Input

輸入只有一個正整數學習

Output

 輸出f(n)。spa

因爲結果會很大,輸出f(n)對998244353(7 × 17 × 223 + 1)取模的結果便可。code

1 ≤ n ≤ 100000blog

Sample Input
3

Sample Outputip

87it

 

 

   咱們知道第二類斯特林數和排列數(降低冪)組合在一塊兒能夠表示n^k,又由於排列等於組合乘上一個階乘,因而咱們就能夠開開心心的二項式反演,獲得一個某一行(其實也能夠不少行,鑑於這個式子的特殊性質,咱們能夠把不一樣行的同一列合併)某一列的斯特林數的表達式。io

    具體的說,S(k,n) = Σ (i^k / i!) * ((-1)^(n-i) / (n-i)!)     [具體推導就不寫了,就是一個二項式反演]。class

這個式子的特殊性質太多了,首先它是一個卷積的形式,因此咱們能夠直接用NTT 在 N log N 的時間求出某一行的全部第二類斯特林數分別是多少;

而且只有 i^k 項和行數有關,因此同一列很好合並,因而這個題就作完了2333。

 

 

#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int maxn=330005;
const int ha=998244353;
const int root=3,inv=ha/3+1;
int a[maxn],b[maxn],jc[maxn];
int r[maxn],N,M,n,INV,l;

inline int add(int x,int y){
	x+=y;
	return x>=ha?x-ha:x;
}

inline int ksm(int x,int y){
	int an=1;
	for(;y;y>>=1,x=x*(ll)x%ha) if(y&1) an=an*(ll)x%ha;
	return an;
}

inline void NTT(int *c,const int f){
	for(int i=0;i<N;i++) if(i<r[i]) swap(c[i],c[r[i]]);
	
	for(int i=1;i<N;i<<=1){
		int omega=ksm(f==1?root:inv,(ha-1)/(i<<1));
		for(int p=i<<1,j=0;j<N;j+=p){
			int now=1;
			for(int k=0;k<i;k++,now=now*(ll)omega%ha){
				int x=c[j+k],y=c[j+k+i]*(ll)now%ha;
				c[j+k]=add(x,y);
				c[j+k+i]=add(x,ha-y);
			}
		}
	}
	
	if(f==-1) for(int i=0;i<N;i++) c[i]=c[i]*(ll)INV%ha;
}

inline void init(){
	jc[0]=1;
	for(int i=1;i<=n;i++) jc[i]=jc[i-1]*(ll)i%ha;
	for(int i=0;i<=n;i++){
		if(!i) a[i]=1;
		else if(i==1) a[i]=n+1;
		else a[i]=add(ksm(i,n+1),ha-1)*(ll)ksm(add(i,ha-1)*(ll)jc[i]%ha,ha-2)%ha;
		if(i&1) b[i]=ha-ksm(jc[i],ha-2);
		else b[i]=ksm(jc[i],ha-2);
	}
	
	M=n<<1;
	for(N=1;N<=M;N<<=1) l++;
	for(int i=0;i<N;i++) r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
}

inline void solve(){
	NTT(a,1),NTT(b,1);
	for(int i=0;i<N;i++) a[i]=a[i]*(ll)b[i]%ha;
	INV=ksm(N,ha-2),NTT(a,-1);
}

inline void output(){
	int ans=0,base=1;
	for(int i=0;i<=n;i++,base=add(base,base)) ans=add(ans,a[i]*(ll)base%ha*(ll)jc[i]%ha);
	printf("%d\n",ans);
}

int main(){
	scanf("%d",&n);
	init();
	solve();
	output();
	return 0;
}
相關文章
相關標籤/搜索