LOJ:https://loj.ac/problem/2541c++
很巧妙的思路。git
注意到運行的過程當中機率的分母在不停的變化,這樣會讓咱們很很差算,咱們考慮這樣轉化:假設全部人都活着,而後隨機選一我的,若是此人已死那就從新選一次。函數
假設當前活着的人集合爲\(T\),那麼射中第\(i\)我的的機率就是:
\[ \sum_{i=0}^{\infty}\left(\frac{s_{all}-s_T}{s_{all}}\right)^i\frac{w_i}{s_{all}}=\frac{w_i}{s_T} \]
其中\(s_p\)表示\(p\)集合的\(w\)總和,能夠發現這樣選的機率和原來是同樣的。優化
咱們考慮容斥,設\(f(T)\)表示至少\(T\)集合的人比\(1\)號後死,用一個很簡單的容斥能夠獲得:
\[ ans=\sum_{T}(-1)^{|T|}f(T) \]
那麼大力算能夠獲得\(f\):
\[ \begin{align}f(T)&=\sum_{i=0}^{\infty}\left(\frac{s_{all}-s_T-w_1}{s_{all}}\right)^i\cdot \frac{w_1}{s_{all}}\\&=\frac{w_1}{w_1+s_T}\end{align} \]
答案就是:
\[ ans=\sum_T(-1)^{|T|}\frac{w_1}{w_1+s_T} \]
注意到\(s\)至多隻有\(1e5\),咱們能夠揹包算出每一個\(s_T\)出現了多少次,揹包的時候順便把容斥係數帶上。spa
這樣作是\(O(ns)\)的,顯然\(T\)掉了。code
可是咱們能夠用生成函數優化這個東西,直接就是:
\[ \prod_{i=2}^{n}(1-x^{w_i}) \]
而後分治\(FFT\)優化就行了,複雜度\(O(n\log ^2 n)\)。get
#include<bits/stdc++.h> using namespace std; void read(int &x) { x=0;int f=1;char ch=getchar(); for(;!isdigit(ch);ch=getchar()) if(ch=='-') f=-f; for(;isdigit(ch);ch=getchar()) x=x*10+ch-'0';x*=f; } void print(int x) { if(x<0) putchar('-'),x=-x; if(!x) return ;print(x/10),putchar(x%10+48); } void write(int x) {if(!x) putchar('0');else print(x);putchar('\n');} #define lf double #define ll long long #define pii pair<int,int > #define vec vector<int > #define pb push_back #define mp make_pair #define fr first #define sc second #define FOR(i,l,r) for(int i=l,i##_r=r;i<=i##_r;i++) const int maxn = 4e5+10; const int inf = 1e9; const lf eps = 1e-8; const int mod = 998244353; int w[maxn],pos[maxn],N,bit,f[maxn],a[maxn],s[maxn],n,mxn; int add(int x,int y) {return x+y>=mod?x+y-mod:x+y;} int del(int x,int y) {return x-y<0?x-y+mod:x-y;} int mul(int x,int y) {return 1ll*x*y-1ll*x*y/mod*mod;} int qpow(int a,int x) { int res=1; for(;x;x>>=1,a=mul(a,a)) if(x&1) res=mul(res,a); return res; } void prepare(int t) { for(N=1,bit=0;N<=t;N<<=1,bit++);mxn=N;w[0]=1,w[1]=qpow(3,(mod-1)/mxn); for(int i=2;i<=N;i++) w[i]=mul(w[i-1],w[1]); } void ntt_get(int t) { for(N=1,bit=0;N<=t;N<<=1,bit++); for(int i=1;i<N;i++) pos[i]=pos[i>>1]>>1|((i&1)<<(bit-1)); } void ntt(int *r,int op) { for(int i=1;i<N;i++) if(pos[i]>i) swap(r[i],r[pos[i]]); for(int i=1,d=mxn>>1;i<N;i<<=1,d>>=1) for(int j=0;j<N;j+=i<<1) for(int k=0;k<i;k++) { int x=r[j+k],y=mul(r[i+j+k],w[k*d]); r[j+k]=add(x,y),r[i+j+k]=del(x,y); } if(op==-1) { reverse(r+1,r+N);int d=qpow(N,mod-2); for(int i=0;i<N;i++) r[i]=mul(r[i],d); } } int get(int lt,int rt) { int l=lt,r=rt,mid,ans=lt; while(l<=r) { mid=(l+r)>>1; if(s[rt]-s[mid]>=s[mid]-s[lt-1]) l=mid+1,ans=mid; else r=mid-1; }return ans; } void solve(int l,int r,int *t) { if(l>r) return ; if(l==r) {t[0]=1,t[a[l]]=mod-1;return ;} int d=1<<((int)ceil(log2(s[r]-s[l-1]))+1); int *sl=new int [d+10],*sr=new int [d+10],mid=get(l,r); for(int i=0;i<=d+5;i++) sl[i]=sr[i]=0; solve(l,mid,sl),solve(mid+1,r,sr); ntt_get(d>>1);ntt(sl,1),ntt(sr,1); for(int i=0;i<N;i++) t[i]=mul(sl[i],sr[i]); ntt(t,-1);delete sl;delete sr; } int main() { read(n);for(int i=1;i<=n;i++) read(a[i]),s[i]=s[i-1]+a[i]; prepare(s[n]<<1);solve(2,n,f);int ans=0; for(int i=0;i<=s[n];i++) ans=add(ans,mul(qpow(a[1]+i,mod-2),f[i])); write(mul(ans,a[1])); return 0; }