第二類斯特林數學習筆記

第二類$ Stirling$數是把包含n個元素的集合劃分爲正好k個非空子集的方法的數目。   
遞推公式爲$ S(n,k) = S(n-1,k-1) + kS(n-1,k).$ios


這類斯特林數有一個很好的性質:git

$ x^k=\sum\limits_{j=0}^kC_x^jS(k,j)j!$spa

其意義是$ k$個球放入$ x$個有標號盒子的方案數,枚舉空盒的數量,乘上階乘以及選出這些空盒的方案便可code


$ stirling$數能夠經過組合意義展開:blog

$ S(n,m)= \frac{1}{m!}*\sum\limits_{k=0}^m(-1)^k(m-k)^nC_m^k$get

咱們枚舉空盒$ k$的個數,就能夠容斥出剛好劃分紅$ m$個非空子集的方案數string

除$ m!$是由於劃分的集合是無標號集合而容斥的集合是帶標號集合it


這個式子能夠化爲卷積形式:io

$ S(n,m)= \frac{1}{m!}*\sum\limits_{k=0}^m(-1)^k(m-k)^nC_m^k$class

$ S(n,m)= \frac{1}{m!}*\sum\limits_{k=0}^m(-1)^k(m-k)^n\frac{m!}{k!(m-k)!}$

$ S(n,m)= \sum\limits_{k=0}^m(-1)^k(m-k)^n\frac{1}{k!(m-k)!}$

$ S(n,m)= \sum\limits_{k=0}^m\frac{(-1)^k}{k!}*\frac{(m-k)^n}{(m-k)!}$

令$ A(x)=\frac{(-1)^k}{k!},B(x)=\frac{m^n}{m!}$

則有$ S(n,m)=\sum\limits_{k=0}^mA(k)B(m-k)$

是一個經典的卷積模型,能夠$ FFT/NTT$在$ O(n log n)$的時間複雜度計算$ S(n,0..n)$


例題:

$ 1$.給定一棵$ n$個節點的樹,每一個點有$ yd[i]$的機率原地不動,不然以均等的機率往周圍的點移動。問每一個點到根的移動次數的$ k$次方的指望

來源:聯考模擬賽

數據範圍:$ n,k<=100000,nk<=1000000$


$ solution$

考慮$ DP$,用$ val[i]^j$表示點$ i$到根的路徑上的$ j$次方值的指望

每次枚舉每一個點回到本身的機率以及除了回到本身的狀況之外的常數貢獻(爲防止父親對本身產生影響先忽略父親影響部分的值)

最後從上往下加上父親對本身的貢獻

發現複雜度瓶頸是從$ val[i]^j$推出$ (val[i]+1)^j$

這部分只能二項式展開致使咱們不得不枚舉$ k$而後每次暴力二項式展開$ O(k)$轉移

總複雜度$ O(nk^2)$


考慮轉化成第二類斯特林數

咱們把$ val[i]^j$二項式展開獲得$ val[i]^j=\sum\limits_{k=0}^jC_{val[i]}^kS(j,k)k!$

咱們只須要維護每一個點到根的$ C_{val[x]}^j \ (0<=j<=k)$便可

因爲組合數有$ C_{val[x]+1}^j=C_{val[x]}^{j}+C_{val[x]}^{j-1}$

所以這能夠作到$ O(1)$轉移,時間複雜度$ O(nk+k^2)$其中$ k^2$是求斯特林數的複雜度

因爲咱們只須要求$ S(k,0...k)$,咱們能夠直接用上面的式子$ NTT$在$ O(n log n)$的時間內求出

總複雜度$ O(nk+k log k)$,能夠經過此題


 

$ my \ code$

#include<ctime>
#include<cmath>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<queue>
#define p 998244353
#define M 200010
#define rt register int
#define ll long long
using namespace std;
inline ll read(){
    ll x = 0; char zf = 1; char ch = getchar();
    while (ch != '-' && !isdigit(ch)) ch = getchar();
    if (ch == '-') zf = -1, ch = getchar();
    while (isdigit(ch)) x = x * 10 + ch - '0', ch = getchar(); return x * zf;
}
void write(ll y){if(y<0)putchar('-'),y=-y;if(y>9)write(y/10);putchar(y%10+48);}
void writeln(const ll y){write(y);putchar('\n');}
int i,j,k,m,n,x,y,z,cnt;
int F[M],L[M],N[M],a[M],fa[M],d[M];
ll yd[M],inv[1000010];int val[100010][2];
void add(int x,int y){
    a[++k]=y;
    if(!F[x])F[x]=k;
    else N[L[x]]=k;
    L[x]=k;
}
void dfs(int x,int pre){
    fa[x]=pre;
    for(rt i=F[x];i;i=N[i])if(a[i]!=pre)dfs(a[i],x);
}
ll ksm(ll x,ll y){
    if(!y)return 1;ll ew=1;
    while(y>1){
        if(y&1)y--,ew=x*ew%p;
        y>>=1,x=x*x%p;
    }return x*ew%p;
}
inline int calc(int x,int y){
    return val[x][y&1^1];
}
ll stop[M];
ll without(int x,int y){
    ll ans=(yd[x]*calc(x,y)+(1ll-yd[x])*inv[d[x]]%p*calc(fa[x],y))%p;
    for(rt i=F[x];i;i=N[i])if(a[i]!=fa[x])(ans+=(1ll-yd[x])*inv[d[x]]%p*(without(a[i],y)+calc(a[i],y)))%=p;
    return val[x][y&1]=ans*stop[x]%p;
}
void solve(int x,int y){
    if(x!=1)(val[x][y&1]+=val[fa[x]][y&1])%=p;
    for(rt i=F[x];i;i=N[i])if(a[i]!=fa[x])solve(a[i],y);
}
ll jc[100010],S[100010],ans[100010];
struct poly{
    int n,m,lim;
    ll a[2200010],b[2200010];int R[2200010];
    void init(int nn,int mm){
        n=m=mm;a[0]=1;b[0]=0;
        for(rt i=1;i<=mm;i++){
            a[i]=ksm(jc[i],p-2);
            if(i&1)a[i]=p-a[i];
            b[i]=ksm(i,nn)*ksm(jc[i],p-2)%p;
        }
        lim=1;while(lim<=n+m)lim<<=1;
        for(rt i=1;i<lim;i++)R[i]=(R[i>>1]>>1)|((i&1)*(lim>>1));
    }
    ll ksm(ll x,ll y){
        if(!y)return 1;ll ew=1;
        while(y>1){
            if(y&1)y--,ew=x*ew%p;
            y>>=1,x=x*x%p;
        }
        return x*ew%p;
    }
    void NTT(ll *A,int fla){
        for(rt i=0;i<lim;i++)if(i>R[i])swap(A[i],A[R[i]]);
        for(rt i=1;i<lim;i<<=1){
            ll w=ksm(3,998244352/2/i);
            if(fla==-1)w=ksm(w,p-2);
            for(rt j=0;j<lim;j+=i<<1){
                ll K=1;
                for(rt k=0;k<i;k++,K=K*w%p){
                    const ll x=A[j+k],y=K*A[i+j+k]%p;
                    A[j+k]=(x+y)%p;A[i+j+k]=(x-y)%p;
                }
            }
        }
    }
    void main(int nn,int mm){
        init(nn,mm);
        NTT(a,1);NTT(b,1);
        for(rt i=0;i<lim;i++)a[i]=a[i]*b[i]%p;
        NTT(a,-1);
        for(rt i=0;i<=n;i++)S[i]=(a[i]*ksm(lim,p-2)%p+p)%p;
    }
}NTT;
int main(){ 
    n=read();m=read();inv[0]=inv[1]=jc[0]=jc[1]=1;
    if(m==0){
        for(rt i=2;i<=n;i++)writeln(1);
        return 0;
    }   
    for(rt i=2;i<=n;i++)inv[i]=inv[p%i]*(p-p/i)%p;
    for(rt i=2;i<=m;i++)jc[i]=jc[i-1]*i%p;
    NTT.main(m,m);
 
    for(rt i=1;i<n;i++){
        x=read();y=read();
        d[x]++;d[y]++;
        add(x,y);add(y,x);
    }
    for(rt i=1;i<=n;i++)val[i][0]=1;
 
    dfs(1,1);
    for(rt i=2;i<=n;i++)yd[i]=read()*ksm(1000000,p-2)%p;
    for(rt x=2;x<=n;x++){
        stop[x]=yd[x];
        for(rt i=F[x];i;i=N[i])if(a[i]!=fa[x])(stop[x]+=(1ll-yd[x])*inv[d[x]]%p)%=p;
        (stop[x]+=p)%=p;        
    }
    for(rt i=2;i<=n;i++)stop[i]=ksm(p+1-stop[i],p-2);
    for(rt i=0;i<=m;i++){
        if(i){
            for(rt j=1;j<=n;j++)val[j][i&1]=0;
            for(rt j=F[1];j;j=N[j])without(a[j],i);
            solve(1,i);
        }
        for(rt j=1;j<=n;j++)(ans[j]+=val[j][i&1]*S[i]%p*jc[i])%=p;
    }
     
    for(rt i=2;i<=n;i++)writeln((ans[i]+p)%p);
    return 0;
}
相關文章
相關標籤/搜索