dsu on treec++
點我跳轉spa
給你一棵以\(1\)爲根節點,包含\(n\)個節點的樹和一個參數 \(k\),求每一個節點的"\(rating\)"code
\(rating\) 值的計算方式是這樣的,對於\(u\)的子樹中的全部節點,若是\(x,y\)知足\(dis(x,y) = k\)ci
而且\(x,y\)的最近公共祖先是\(u\)且知足\(u != x , u != y\),那麼\(u\)的\(rating\)就會增長\(a_x + a_y\)get
由於 \(x , y\) 的最近公共祖先爲 \(u\) ,因此 \(x , y\) 必定在 \(u\) 子樹的不一樣分支it
\(dis(x,y) = k\) 等價於 \(dep[x] + dep[y] - 2 × dep[lca(x,y)] = k\)class
因而能夠先用 cnt[dep] 記錄深度爲 dep 的節點出現的次數,用 sum[dep] 記錄 dep 的節點的權值和test
那麼對於 \(rt\) 爲根的子節點 \(u\),與其相匹配的點的深度爲 \(d = k + 2 * dep[rt] - dep[u]\)統計
它對 \(rt\) 產生的貢獻就爲 \(cnt_d × a_u\) ,而深度爲 \(d\) 的點由於 \(u\) 的出現對 \(rt\) 的貢獻都會翻倍di
因此 \(u\) 節點的出現對 \(rt\) 的總貢獻爲 \(sum[d] + a[u] * cnt[d]\)
又由於與 \(u\) 節點產生貢獻的節點必須和 \(u\) 不在一個分支,即一個分支內的任意節點不能相互影響
因此須要先對一個分支統計完貢獻後,再添加它的信息
#include<bits/stdc++.h> #define rep(i,a,n) for (int i=a;i<=n;i++) #define int long long using namespace std; const int N = 3e5 + 10; struct Edge{ int nex , to; }edge[N << 1]; int head[N] , TOT; void add_edge(int u , int v) { edge[++ TOT].nex = head[u] ; edge[TOT].to = v; head[u] = TOT; } int n , k , sum[N] , cnt[N] , ans[N]; int dep[N] , sz[N] , HH , hson[N] , f[N][30] , a[N]; void dfs(int u , int far) { sz[u] = 1; dep[u] = dep[far] + 1; for(int i = head[u] ; i ; i = edge[i].nex) { int v = edge[i].to; if(v == far) continue ; dfs(v , u); sz[u] += sz[v]; if(sz[v] > sz[hson[u]]) hson[u] = v; } } void change(int u , int far , int val) { sum[dep[u]] += val * a[u]; cnt[dep[u]] += val; for(int i = head[u] ; i ; i = edge[i].nex) { int v = edge[i].to; if(v == far || v == HH) continue ; change(v , u , val); } } void calc(int u , int far , int rt) { int c = k + 2 * dep[rt] - dep[u]; if(c < 0) return ; ans[rt] += sum[c] + a[u] * cnt[c]; for(int i = head[u] ; i ; i = edge[i].nex) { int v = edge[i].to; if(v == far || v == HH) continue ; calc(v , u , rt); } } void dsu(int u , int far , int op) { for(int i = head[u] ; i ; i = edge[i].nex) { int v = edge[i].to; if(v == far || v == hson[u]) continue ; dsu(v , u , 0); } if(hson[u]) dsu(hson[u] , u , 1) , HH = hson[u]; for(int i = head[u] ; i ; i = edge[i].nex) { int v = edge[i].to; if(v == far || v == HH) continue ; calc(v , u , u) , change(v , u , 1); } HH = 0; sum[dep[u]] += a[u] , cnt[dep[u]] ++ ; if(!op) { change(u , far , -1); } } signed main() { cin >> n >> k; rep(i , 1 , n) cin >> a[i]; rep(i , 2 , n) { int u , v; cin >> u >> v; add_edge(u , v) , add_edge(v , u); } dfs(1 , 0); dsu(1 , 0 , 0); rep(i , 1 , n) cout << ans[i] << " \n"[i == n]; return 0; }