「優美的暴力」——樹上啓發式合併

今天介紹一個神仙算法:Dsu On Tree[ 樹上啓發式合併 ]c++

這個算法用於離線處理詢問子樹信息,並且很好寫。算法

可是在你沒有理解它以前,這是個很鬼畜的算法。數組

理解後你才能真心感到它的美妙之處。緩存

關鍵是它是有着媲美線段樹合併的時間複雜度的「暴力」算法。函數

這裏說一件事,我學這個東西時找了不少篇博客,它們無一例外地給出了這樣一個流程:學習

1. 先統計一個節點全部的輕兒子 而後刪除它的答案
2. 再統計這個節點的重兒子 保留他的答案
3. 再算一遍全部輕兒子 加到答案中上傳
優化

我當時就看的很懵逼,算一遍全部輕兒子,刪掉,再算一遍,這不閒的?spa

直接統計它的重兒子再算輕兒子不就行了?很疑惑,問了身邊不少人也都以爲迷惑。code

人類迷惑行爲大賞.jpgblog

後面我搞懂了,爲了避免讓其餘學習dsu on tree的人也以爲迷惑,我就寫了這一篇博客。

在這裏很是感謝洛谷兩個dalao的幫助,如今理解了這個東西。

咱們每次進入一棵子樹計算答案時,都要把計算上一棵子樹的數據清除。

爲何,若是咱們帶着上次計算後的結果去計算新子樹,答案確定是不對的。

可是,咱們最後一棵子樹不須要清除,由於咱們不用再進入新子樹了(沒了)。

那咱們再回到上面說的,爲何一開始要算一遍輕兒子?

從最純粹的暴力開始,咱們有兩個函數dfs1和dfs2,dfs1函數做爲主體函數,dfs2做爲輔助函數。

先dfs1到每個點,dfs1它的後代,計算後代的信息,再dfs2它的後代,計算本身的答案。

也就是說,開始算輕兒子是要把它後代的信息計算出來,而不是理解爲以前提到那些博客裏面的算出來答案後「刪除答案」。刪除答案是爲了避免讓計算的數據衝突。

按照dalao說的,保留重兒子的信息能夠優化複雜度。從原先暴力的O(n^2)優化到O(nlogn)。

因此爲了避免用清除重兒子的信息,先dfs1輕兒子,再dfs1重兒子,最後dfs2輕兒子更新本身的答案。

若是一開始沒有dfs1輕兒子,咱們就沒有獲得後代的信息,所謂「刪除」答案,實際上是從統計信息的數組把dfs1輕兒子時存進去的用於計算的「緩存」清理了。

下面給出代碼:

#include<bits/stdc++.h>
#define N 100010
using namespace std;
inline int read(){
    int data=0,w=1;char ch=0;
    while(ch!='-' && (ch<'0'||ch>'9'))ch=getchar();
    if(ch=='-')w=-1,ch=getchar();
    while(ch>='0' && ch<='9')data=data*10+ch-'0',ch=getchar();
    return data*w;
}
struct Edge{
    int nxt,to;
    #define nxt(x) e[x].nxt
    #define to(x) e[x].to
}e[N<<1];
int head[N],tot;
inline void addedge(int f,int t){
    nxt(++tot)=head[f];to(tot)=t;head[f]=tot;
}
int cnt[N],siz[N],son[N],c[N],max_val,n,child;
long long sum,ans[N];
void add(int x,int f,int val){
    cnt[c[x]]+=val;
    if(cnt[c[x]]>max_val)max_val=cnt[c[x]],sum=c[x];
    else if(cnt[c[x]]==max_val)sum+=1LL*c[x];
    for(int i=head[x];i;i=nxt(i)){
        int y=to(i);
        if(y==f||y==child)continue;
        add(y,x,val);
    }
}
void dfs1(int x,int f){//重鏈剖分
    siz[x]=1;int maxson=-1;
    for(int i=head[x];i;i=nxt(i)){
        int y=to(i);
        if(y==f)continue;
        dfs1(y,x);
        siz[x]+=siz[y];
        if(siz[y]>maxson){
            maxson=siz[y];son[x]=y;
        }
    }
}
void dfs2(int x,int f,int opt){//opt爲0表示統計後的答案要刪掉,opt爲1則不用刪
    for(int i=head[x];i;i=nxt(i)){
        int y=to(i);
        if(y==f)continue;
        if(y!=son[x])dfs2(y,x,0);
    }if(son[x])dfs2(son[x],x,1),child=son[x];
    add(x,f,1);child=0;
    ans[x]=sum;
    if(!opt)add(x,f,-1),sum=0,max_val=0;
}
int main(){
    n=read();
    for(int i=1;i<=n;i++)c[i]=read();
    for(int i=1;i<n;i++){
        int x=read(),y=read();
        addedge(x,y);addedge(y,x);
    }
    dfs1(1,0);dfs2(1,0,0);
    for(int i=1;i<=n;i++)
        printf("%lld ",ans[i]);
    return 0;
}
相關文章
相關標籤/搜索