LOJ:https://loj.ac/problem/2537c++
洛谷:https://www.luogu.org/problemnew/show/P5298git
不按期詐屍優化
很久沒敲代碼了犯了好多sb錯誤spa
考慮一個暴力的\(dp\),首先這題只用到了權值的大小關係,因此咱們先離散化,設\(f_{x,i}\)表示\(x\)點權值爲\(i\)的機率。code
轉移很顯然:
\[ f_{x,i}=f_{ls,i}\left(\sum_{j=1}^{i-1}p_x\cdot f_{rs,j}+\sum_{j=i+1}^{m}(1-p_x)\cdot f_{rs,j}\right)+f_{rs,i}\left(\sum_{j=1}^{i-1}p_x\cdot f_{ls,j}+\sum_{j=i+1}^{m}(1-p_x)\cdot f_{ls,j}\right) \]
就是枚舉當前是選最大值仍是最小值,前綴和優化能夠作到\(O(n^2)\)。遞歸
而後咱們開權值線段樹維護這個東西,每次線段樹合併,合併的時候處理前綴和來優化。get
時間複雜度\(O(n\log ^2 n)\),空間複雜度\(O(n\log n)\)。it
#include<bits/stdc++.h> using namespace std; void read(int &x) { x=0;int f=1;char ch=getchar(); for(;!isdigit(ch);ch=getchar()) if(ch=='-') f=-f; for(;isdigit(ch);ch=getchar()) x=x*10+ch-'0';x*=f; } void print(int x) { if(x<0) putchar('-'),x=-x; if(!x) return ;print(x/10),putchar(x%10+48); } void write(int x) {if(!x) putchar('0');else print(x);putchar('\n');} #define lf double #define ll long long #define pii pair<int,int > #define vec vector<int > #define mid ((l+r)>>1) #define pb push_back #define mp make_pair #define fr first #define sc second #define FOR(i,l,r) for(int i=l,i##_r=r;i<=i##_r;i++) const int maxn = 4e5+10; const int inf = 1e9; const lf eps = 1e-8; const int mod = 998244353; int qpow(int a,int x) { int res=1; for(;x;x>>=1,a=1ll*a*a%mod) if(x&1) res=1ll*res*a%mod; return res; } int n,son[maxn][2],w[maxn],m,p,b[maxn]; int ls[maxn<<5],rs[maxn<<5],rt[maxn],s[maxn<<5],tag[maxn<<5],seg; void push(int x,int c) {s[x]=1ll*s[x]*c%mod,tag[x]=1ll*tag[x]*c%mod;} void pushdown(int x) { if(tag[x]!=1) push(ls[x],tag[x]),push(rs[x],tag[x]),tag[x]=1; } void insert(int &x,int l,int r,int c) { if(!x) x=++seg;s[x]=tag[x]=1; if(l==r) return ; if(c<=mid) insert(ls[x],l,mid,c); else insert(rs[x],mid+1,r,c); } int merge(int x,int y,int lsum=0,int rsum=0) { if(!x) return push(y,lsum),y; if(!y) return push(x,rsum),x; int t=++seg;tag[t]=1;pushdown(x),pushdown(y); int sl=s[ls[x]],sr=s[ls[y]]; // 注意這裏遞歸的時候會被改掉,我就被坑了很久... ls[t]=merge(ls[x],ls[y],(lsum+1ll*(1-p+mod)*s[rs[x]]%mod)%mod,(rsum+1ll*(1-p+mod)*s[rs[y]]%mod)%mod); rs[t]=merge(rs[x],rs[y],(lsum+1ll*p*sl%mod)%mod,(rsum+1ll*p*sr%mod)%mod); s[t]=(s[ls[t]]+s[rs[t]])%mod;return t; } int solve(int x) { if(!son[x][0]) return insert(rt[x],1,m,lower_bound(b+1,b+m+1,w[x])-b),rt[x]; int l=solve(son[x][0]); if(!son[x][1]) return l; int r=solve(son[x][1]);p=w[x]; return merge(l,r); } int calc(int x,int l,int r) { if(!x) return 0; if(l==r) return 1ll*l*b[l]%mod*s[x]%mod*s[x]%mod; pushdown(x); return (calc(ls[x],l,mid)+calc(rs[x],mid+1,r))%mod; } int main() { read(n);int I=qpow(10000,mod-2),x; FOR(i,1,n) read(x),son[x][0]?son[x][1]=i:son[x][0]=i; FOR(i,1,n) read(x),son[i][0]?w[i]=1ll*x*I%mod:b[++m]=w[i]=x; sort(b+1,b+m+1);write(calc(solve(1),1,m)); return 0; }