有一個 \(n\times m\) 的矩陣 \(A\),每一個元素都是 \([0,1]\) 內的等機率隨機實數,記 \(s_i=\sum_{j=1}^mA_{i,j}\),求 \(\lfloor\min s_i\rfloor^k\) 的指望。dom
對 \(998244353\) 取模。spa
\(n\leq {10}^9,m\leq 5\times {10}^5,k\leq {10}^9\)code
咱們只用求 \(\lfloor s_i\rfloor\) 爲 \(0\) 到 \(m-1\) 中每一個值的機率就行了。get
記 \(b_i=\sum_{j=1}^iA_{1,j}-\lfloor\sum_{j=1}^iA_{1,j}\rfloor,c_i=\lfloor\sum_{j=1}^iA_{1,j}\rfloor\),那麼 \(b_i\) 也在 \([0,1]\) 間等機率隨機。咱們能夠直接忽略 \(b_i\) 相同的狀況。這樣就能夠把 \(b\) 當作一個排列。string
能夠發現,\(c_i>c_{i-1}\) 當且僅當 \(b_i<b_{i-1}\)。it
那麼只用對於每一個 \(i\) 計算有多少種 \(c_j>c_{j-1}\) 的個數爲 \(i\) 的狀況就行了。記這個東西爲 \(A_{m,i}\)。io
怎麼算呢?function
那麼 \(\frac{1}{n!}\sum_{i=0}^mA_{n,i}\) 爲 \(x_1+x_2+\ldots+x_n\leq m+1(0\leq x_i\leq 1)\) 的機率class
記 \(h_n(x)\) 爲 \(x_1+x_2+\ldots+x_n\leq x(x_i\geq 0)\) 的機率。im
那麼有
\[ h_1(x)=x\\ h_i(x)=\int_0^xh_{i-1}(x-z)~dz=\int_0^xh_{i-1}(z)~dz=\frac{x^i}{i!} \]
枚舉有多少個 \(x_i>1\) 進行容斥,那麼就有:
\[ \begin{align} \frac{1}{n!}\sum_{i=0}^mA_{n,i}&=\sum_{i=0}^{m+1}{(-1)}^i\binom{n}{i}h_n(m+1-i)\\ \frac{1}{n!}A_{n,m}&=\sum_{i=0}^{m+1}{(-1)}^i\binom{n}{i}h_n(m+1-i)-\sum_{i=0}^{m}{(-1)}^i\binom{n}{i}h_n(m-i)\\ &=\sum_{i=0}^{m+1}{(-1)}^i\binom{n}{i}h_n(m+1-i)+\sum_{i=0}^{m+1}{(-1)}^i\binom{n}{i-1}h_n(m+1-i)\\ &=\sum_{i=0}^{m+1}{(-1)}^i\binom{n+1}{i}h_n(m+1-i)\\ &=\frac{1}{n!}\sum_{i=0}^{m+1}{(-1)}^i\binom{n+1}{i}{(m+1-i)}^n\\ A_{n,m}&=\sum_{i=0}^{m+1}{(-1)}^i\binom{n+1}{i}{(m+1-i)}^n \end{align} \]
這樣就能夠在 \(O(m\log m)\) 內計算出 \(A_{m,0}\ldots A_{m,m}\) 了。
時間複雜度:\(O(m\log m)\)
#include<cstdio> #include<cstring> #include<algorithm> #include<cstdlib> #include<ctime> #include<functional> #include<cmath> #include<vector> #include<assert.h> //using namespace std; using std::min; using std::max; using std::swap; using std::sort; using std::reverse; using std::random_shuffle; using std::lower_bound; using std::upper_bound; using std::unique; using std::vector; typedef long long ll; typedef unsigned long long ull; typedef double db; typedef std::pair<int,int> pii; typedef std::pair<ll,ll> pll; void open(const char *s){ #ifndef ONLINE_JUDGE char str[100];sprintf(str,"%s.in",s);freopen(str,"r",stdin);sprintf(str,"%s.out",s);freopen(str,"w",stdout); #endif } void open2(const char *s){ #ifdef DEBUG char str[100];sprintf(str,"%s.in",s);freopen(str,"r",stdin);sprintf(str,"%s.out",s);freopen(str,"w",stdout); #endif } int rd(){int s=0,c,b=0;while(((c=getchar())<'0'||c>'9')&&c!='-');if(c=='-'){c=getchar();b=1;}do{s=s*10+c-'0';}while((c=getchar())>='0'&&c<='9');return b?-s:s;} void put(int x){if(!x){putchar('0');return;}static int c[20];int t=0;while(x){c[++t]=x%10;x/=10;}while(t)putchar(c[t--]+'0');} int upmin(int &a,int b){if(b<a){a=b;return 1;}return 0;} int upmax(int &a,int b){if(b>a){a=b;return 1;}return 0;} const int N=1200000; const ll p=998244353; ll fp(ll a,ll b) { ll s=1; for(;b;b>>=1,a=a*a%p) if(b&1) s=s*a%p; return s; } namespace ntt { const int W=1048576; int rev[N]; ll w[N]; void init() { ll s=fp(3,(p-1)/W); w[0]=1; for(int i=1;i<W/2;i++) w[i]=w[i-1]*s%p; } void ntt(ll *a,int n,int t) { for(int i=1;i<n;i++) { rev[i]=(rev[i>>1]>>1)|(i&1?n>>1:0); if(rev[i]>i) swap(a[i],a[rev[i]]); } for(int i=2;i<=n;i<<=1) for(int j=0;j<n;j+=i) for(int k=0;k<i/2;k++) { ll u=a[j+k]; ll v=a[j+k+i/2]*w[W/i*k]; a[j+k]=(u+v)%p; a[j+k+i/2]=(u-v)%p; } if(t==-1) { reverse(a+1,a+n); ll inv=fp(n,p-2); for(int i=0;i<n;i++) a[i]=a[i]*inv%p; } } void mul(ll *a,ll *b,ll *c,int n,int m,int l) { static ll a1[N],a2[N]; int k=1; while(k<=n+m) k<<=1; for(int i=0;i<k;i++) a1[i]=a2[i]=0; for(int i=0;i<=n;i++) a1[i]=a[i]; for(int i=0;i<=m;i++) a2[i]=b[i]; ntt::ntt(a1,k,1); ntt::ntt(a2,k,1); for(int i=0;i<k;i++) a1[i]=a1[i]*a2[i]%p; ntt::ntt(a1,k,-1); for(int i=0;i<=l;i++) c[i]=a1[i]; } } ll inv[N],fac[N],ifac[N]; int n,m,k; ll f[N]; ll a[N],b[N],c[N]; ll binom(int x,int y) { return fac[x]*ifac[y]%p*ifac[x-y]%p; } int main() { open("b"); ntt::init(); inv[1]=fac[0]=fac[1]=ifac[0]=ifac[1]=1; for(int i=2;i<=500010;i++) { inv[i]=-p/i*inv[p%i]%p; fac[i]=fac[i-1]*i%p; ifac[i]=ifac[i-1]*inv[i]%p; } scanf("%d%d%d",&n,&m,&k); for(int i=0;i<=m+1;i++) { a[i]=(i&1?-1:1)*ifac[i]%p*ifac[m+1-i]%p; b[i]=fp(i,m); } ntt::mul(a,b,c,m+1,m+1,m+1); for(int i=0;i<m;i++) f[i]=c[i+1]*fac[m+1]%p; for(int i=0;i<m;i++) f[i]=f[i]*ifac[m]%p; for(int i=m-1;i>=0;i--) f[i]=(f[i]+f[i+1])%p; for(int i=0;i<m;i++) f[i]=fp(f[i],n)%p; for(int i=0;i<m;i++) f[i]=(f[i]-f[i+1])%p; ll ans=0; for(int i=0;i<m;i++) ans=(ans+fp(i,k)*f[i])%p; ans=(ans%p+p)%p; printf("%lld\n",ans); return 0; }