【學習筆記】dsu on tree

我也不知道爲啥這要起這名,完徹底全沒看到並查集的影子啊……html

實際上原理就是一個樹上的啓發式合併。c++

特色是能夠在$O(nlogn)$的時間複雜度內完成對無修改的子樹的統計,複雜度優於莫隊算法。算法

侷限性也很明顯:1.不能支持修改  2.只能支持子樹統計,不能鏈上統計。(鏈上統計你不能直接樹剖嗎?)優化

那麼它是怎麼實現的呢?首先有一個例子:
樹上每一個節點都有一個顏色(那麼必定是藍色)spa

求每一個節點的子樹上有多少顏色爲k的節點。(每一個節點的k不必定相同).net

$O(n^2)$的算法很是好想,以每一個點爲起點dfs一下就沒了。code

固然也有不那麼暴力的作法,dfs序一下再主席樹或者莫隊隨便搞搞也行。htm

那麼咱們先看看暴力是怎麼作的。blog

每次統計x節點前,暴力將x的子樹的貢獻加入,統計結束後,再暴力刪除貢獻,消除影響。get

咱們發現這有不少無用的刪除操做,考慮優化?

那麼咱們怎麼用dsu上樹優雅的解決這個問題呢?咱們想到了樹鏈剖分(輕重鏈剖分)。

具體的作法是,咱們先統計一個點的輕兒子,再把它的影響消除。再統計重兒子,此時沒必要消除影響。

爲了完成統計,最後再統計一遍輕兒子。

能夠這麼考慮:只有dfs到輕邊時,纔會將輕邊的子樹中合併到上一級的重鏈,

樹鏈剖分將一棵樹分割成了不超過$logn$條重鏈。
每個節點最多向上合併$logn$次,單次修改複雜度$O(1)$。
因此總體複雜度是$O(nlogn)$的。

因此大概的模版是這樣的:

 1 void dfs2(int u,int f,int k){
 2     for(int i=head[u];i;i=G[i].next){
 3         int v=G[i].v;if(v==f||v==wson[u])continue;
 4         dfs2(v,u,0);
 5     }
 6     if(wson[u])dfs(wson[u],u,1),now=wson[u];
 7     calc(u,f,1);
 8     now=0;ans[u]=sum;
 9     if(k==0)calc(u,f,-1),sumv=0,maxv=0;
10 }

下面是兩道爛大街的例題:

1. Lomsat gelral(cf600E)

n個點的有根樹,以1爲根,每一個點有一種顏色。咱們稱一種顏色佔領了一個子樹當且僅當沒有其餘顏色在這個子樹中出現得比它多。求佔領每一個子樹的全部顏色之和。

就是剛纔的裸題啊。

 1 #include<bits/stdc++.h>
 2 #define N 700010
 3 using namespace std;
 4 struct Edge{int u,v,next;}G[2*N];
 5 typedef long long ll;
 6 int n,c[N],val[N],size[N],wson[N],fa[N];
 7 ll ans[N];
 8 int head[4*N],tot=0;
 9 void addedge(int u,int v){
10     G[++tot].u=u;G[tot].v=v;G[tot].next=head[u];head[u]=tot;
11     G[++tot].u=v;G[tot].v=u;G[tot].next=head[v];head[v]=tot;
12 }
13 void dfs1(int u,int f=0){
14     size[u]=1;
15     for(int i=head[u];i;i=G[i].next){
16         int v=G[i].v;if(v==f)continue;
17         if(v==f)continue;
18         dfs1(v,u);
19         size[u]+=size[v];
20         if(size[v]>size[wson[u]])wson[u]=v;
21     }
22 }
23 bool vis[N];int maxv=0;ll sum=0;
24 void change(int u,int f,int k){
25     c[val[u]]+=k;
26     if(k>0&&c[val[u]]>=maxv){
27         if(c[val[u]]>maxv)sum=0,maxv=c[val[u]];
28         sum+=val[u];
29     }
30     for(int i=head[u];i;i=G[i].next){
31         int v=G[i].v;if(v==f||vis[v])continue;
32         change(v,u,k);
33     }
34 }
35 void dfs2(int u,int f=0,bool used=0){
36     for(int i=head[u];i;i=G[i].next){
37         int v=G[i].v;if(v==f||v==wson[u])continue;
38         dfs2(v,u);
39     }
40     if(wson[u])dfs2(wson[u],u,1),vis[wson[u]]=1;
41     change(u,f,1);ans[u]=sum;
42     if(wson[u])vis[wson[u]]=0;
43     if(!used)change(u,f,-1),maxv=sum=0;
44 }
45 inline int read(){
46     int f=1,x=0;char ch;
47     do{ch=getchar();if(ch=='-')f=-1;}while(ch<'0'||ch>'9');
48     do{x=x*10+ch-'0';ch=getchar();}while(ch>='0'&&ch<='9');
49     return f*x;
50 }
51 int main(){
52     n=read();
53     for(int i=1;i<=n;i++)val[i]=read();
54     for(int i=1;i<n;i++){
55         int u=read(),v=read();
56         addedge(u,v);
57     }
58     dfs1(1);dfs2(1);
59     for(int i=1;i<=n;i++)printf("%I64d ",ans[i]);
60 }

固然這題也有不這麼作的作法,隨便從cf上粘了一個,你們自行意會……

 1 #include<bits/stdc++.h>
 2 #define N 100005
 3 using namespace std;
 4 vector<int>a[N];map<int,int>S[N];
 5 int F[N],id[N],c[N],n,i,x,y;
 6 long long G[N],ans[N];
 7 void work(int x,int y,int color){
 8     if (y>F[x]) F[x]=y,G[x]=0;
 9     if (y==F[x]) G[x]+=color;
10 }
11 void Union(int &x,int y){
12     if (S[x].size()<S[y].size()) swap(x,y);
13     for (map<int,int>::iterator it=S[y].begin();it!=S[y].end();it++)
14         work(x,S[x][it->first]+=it->second,it->first);
15 }
16 void DFS(int x,int fa){
17     id[x]=x;S[x][c[x]]=1;
18     F[x]=1;G[x]=c[x];
19     for (int i=0,y;i<a[x].size();i++)
20         if ((y=a[x][i])!=fa)
21             DFS(y,x),Union(id[x],id[y]);
22     ans[x]=G[id[x]];
23 }
24 int main(){
25     scanf("%d",&n);
26     for (i=1;i<=n;i++)
27         scanf("%d",&c[i]);
28     for (i=1;i<n;i++)
29         scanf("%d%d",&x,&y),
30         a[x].push_back(y),
31         a[y].push_back(x);
32     DFS(1,0);
33     for (i=1;i<=n;i++)
34         printf("%I64d ",ans[i]);
35 }

例2: Arpa's letter-marked tree and Mehrdad's Dokhtar-kosh paths(CF741D)

這題也很顯然,若是重排後能造成迴文串,那麼出現奇數次的字符應該少於2個(即最多1個)若是隻有a~v的話考慮把每一個字符當作一個二進制位,把一個點i到根的路徑異或值記爲s[i],那麼咱們就是要對於每一個x在子樹中找到a和b,使得s[a]^s[b]爲0或2的次冪,且dep[a]+dep[b]-dep[lca]*2最大。

 1 #include<bits/stdc++.h>
 2 #define N 500005
 3 using namespace std;
 4 int size[N],head[4*N],tot=0,wson[N],s[N],f[20*N],ans[N],d[N],a[N];
 5 char c[N];
 6 int maxv,n,inf;
 7 struct Edge{int u,v,next;}G[2*N];
 8 void addedge(int u,int v){
 9     G[++tot].u=u;G[tot].v=v;G[tot].next=head[u];head[u]=tot;
10     //G[++tot].u=v;G[tot].v=u;G[tot].next=head[v];head[v]=tot;
11 }
12 void dfs1(int u,int fa){
13     size[u]=1;d[u]=d[fa]+1;
14     if(u!=1)s[u]=s[fa]^(1<<a[u]);
15     for(int i=head[u];i;i=G[i].next){
16         int v=G[i].v;
17         dfs1(v,u);
18         size[u]+=size[v];if(size[v]>size[wson[u]])wson[u]=v;
19     }
20 }
21 void calc(int rt,int u){
22     int now=s[u];
23     maxv=max(maxv,f[now]+d[u]-2*d[rt]);
24     if((s[u]^s[rt])==0)maxv=max(maxv,d[u]-d[rt]);
25     for(int i=0;i<22;i++){
26         now=(1<<i)^s[u];
27         maxv=max(maxv,f[now]+d[u]-2*d[rt]);
28         if((s[u]^s[rt])==(1<<i))maxv=max(maxv,d[u]-d[rt]);
29     }
30     for(int i=head[u];i;i=G[i].next){
31         int v=G[i].v;calc(rt,v);
32     }
33 }
34 void change(int u,int k){
35     if(k)f[s[u]]=max(f[s[u]],d[u]);
36     else f[s[u]]=inf;
37     for(int i=head[u];i;i=G[i].next)change(G[i].v,k);
38 }
39 void dfs2(int u,int k){
40     for(int i=head[u];i;i=G[i].next){
41         int v=G[i].v;if(v==wson[u])continue;
42         dfs2(v,0);
43     }
44     if(wson[u])dfs2(wson[u],1);
45     maxv=0;int now=s[u];
46     maxv=max(maxv,f[now]-d[u]);
47     for(int i=0;i<22;i++){
48         now=(1<<i)^s[u];
49         maxv=max(maxv,f[now]-d[u]);
50     }
51     for(int i=head[u];i;i=G[i].next){
52         int v=G[i].v;if(v==wson[u])continue;
53         calc(u,v);change(v,1);
54     }
55     ans[u]=maxv;
56     if(!k){
57         for(int i=head[u];i;i=G[i].next)change(G[i].v,0);
58         f[s[u]]=inf;
59     }else f[s[u]]=max(f[s[u]],d[u]);
60 }
61 void erase(int u){
62     for(int i=head[u];i;i=G[i].next){
63         int v=G[i].v;erase(v);
64         ans[u]=max(ans[u],ans[v]);
65     }
66 }
67 int main(){
68     scanf("%d",&n);
69     for(int i=2;i<=n;i++){
70         int u;scanf("%d %c\n",&u,&c[i]);
71         addedge(u,i);a[i]=c[i]-'a';
72     }
73     dfs1(1,0);
74     memset(f,128,sizeof(f));inf=f[0];dfs2(1,0);
75     erase(1);
76     for (int i=1;i<=n;++i)printf("%d%c",ans[i]," \n"[i==n]);
77 }

大概是這樣。

參考:

http://blog.csdn.net/qq_35392050/article/details/64537364

http://www.cnblogs.com/zzqsblog/p/6146916.html

相關文章
相關標籤/搜索