今天介紹一個神仙算法:Dsu On Tree[ 樹上啓發式合併 ]c++
這個算法用於離線處理詢問子樹信息,並且很好寫。算法
可是在你沒有理解它以前,這是個很鬼畜的算法。數組
理解後你才能真心感到它的美妙之處。緩存
關鍵是它是有着媲美線段樹合併的時間複雜度的「暴力」算法。函數
這裏說一件事,我學這個東西時找了不少篇博客,它們無一例外地給出了這樣一個流程:學習
1. 先統計一個節點全部的輕兒子 而後刪除它的答案
2. 再統計這個節點的重兒子 保留他的答案
3. 再算一遍全部輕兒子 加到答案中上傳優化
我當時就看的很懵逼,算一遍全部輕兒子,刪掉,再算一遍,這不閒的?spa
直接統計它的重兒子再算輕兒子不就行了?很疑惑,問了身邊不少人也都以爲迷惑。code
人類迷惑行爲大賞.jpgblog
後面我搞懂了,爲了避免讓其餘學習dsu on tree的人也以爲迷惑,我就寫了這一篇博客。
在這裏很是感謝洛谷兩個dalao的幫助,如今理解了這個東西。
咱們每次進入一棵子樹計算答案時,都要把計算上一棵子樹的數據清除。
爲何,若是咱們帶着上次計算後的結果去計算新子樹,答案確定是不對的。
可是,咱們最後一棵子樹不須要清除,由於咱們不用再進入新子樹了(沒了)。
那咱們再回到上面說的,爲何一開始要算一遍輕兒子?
從最純粹的暴力開始,咱們有兩個函數dfs1和dfs2,dfs1函數做爲主體函數,dfs2做爲輔助函數。
先dfs1到每個點,dfs1它的後代,計算後代的信息,再dfs2它的後代,計算本身的答案。
也就是說,開始算輕兒子是要把它後代的信息計算出來,而不是理解爲以前提到那些博客裏面的算出來答案後「刪除答案」。刪除答案是爲了避免讓計算的數據衝突。
按照dalao說的,保留重兒子的信息能夠優化複雜度。從原先暴力的O(n^2)優化到O(nlogn)。
因此爲了避免用清除重兒子的信息,先dfs1輕兒子,再dfs1重兒子,最後dfs2輕兒子更新本身的答案。
若是一開始沒有dfs1輕兒子,咱們就沒有獲得後代的信息,所謂「刪除」答案,實際上是從統計信息的數組把dfs1輕兒子時存進去的用於計算的「緩存」清理了。
下面給出代碼:
#include<bits/stdc++.h> #define N 100010 using namespace std; inline int read(){ int data=0,w=1;char ch=0; while(ch!='-' && (ch<'0'||ch>'9'))ch=getchar(); if(ch=='-')w=-1,ch=getchar(); while(ch>='0' && ch<='9')data=data*10+ch-'0',ch=getchar(); return data*w; } struct Edge{ int nxt,to; #define nxt(x) e[x].nxt #define to(x) e[x].to }e[N<<1]; int head[N],tot; inline void addedge(int f,int t){ nxt(++tot)=head[f];to(tot)=t;head[f]=tot; } int cnt[N],siz[N],son[N],c[N],max_val,n,child; long long sum,ans[N]; void add(int x,int f,int val){ cnt[c[x]]+=val; if(cnt[c[x]]>max_val)max_val=cnt[c[x]],sum=c[x]; else if(cnt[c[x]]==max_val)sum+=1LL*c[x]; for(int i=head[x];i;i=nxt(i)){ int y=to(i); if(y==f||y==child)continue; add(y,x,val); } } void dfs1(int x,int f){//重鏈剖分 siz[x]=1;int maxson=-1; for(int i=head[x];i;i=nxt(i)){ int y=to(i); if(y==f)continue; dfs1(y,x); siz[x]+=siz[y]; if(siz[y]>maxson){ maxson=siz[y];son[x]=y; } } } void dfs2(int x,int f,int opt){//opt爲0表示統計後的答案要刪掉,opt爲1則不用刪 for(int i=head[x];i;i=nxt(i)){ int y=to(i); if(y==f)continue; if(y!=son[x])dfs2(y,x,0); }if(son[x])dfs2(son[x],x,1),child=son[x]; add(x,f,1);child=0; ans[x]=sum; if(!opt)add(x,f,-1),sum=0,max_val=0; } int main(){ n=read(); for(int i=1;i<=n;i++)c[i]=read(); for(int i=1;i<n;i++){ int x=read(),y=read(); addedge(x,y);addedge(y,x); } dfs1(1,0);dfs2(1,0,0); for(int i=1;i<=n;i++) printf("%lld ",ans[i]); return 0; }