樹鏈剖分就是將樹劃分爲多條鏈,將每條鏈映射到序列上,而後使用線段樹,平衡樹等數據結構來維護每條鏈的信息。c++
樹剖將樹鏈映射到序列上後用線段樹等數據結構來維護樹鏈信息,因此能夠像區間修改,區間查詢同樣進行樹上路徑的修改、查詢等操做數據結構
重兒子:子樹結點數目最多的兒子(size最大的點);ui
重邊:父親結點和重兒子連成的邊;spa
重鏈:由多條重邊鏈接而成的路徑;3d
輕兒子: 除了重兒子,其他都爲輕兒子;code
輕邊:重邊以外的邊;blog
紅圈表示重兒子;ip
黑邊表示重邊;input
由黑邊連成的鏈即爲重鏈it
第一遍dfs處理出重兒子,深度,父親等信息
第二遍dfs處理出結點所在重鏈的鏈頂,dfs序
上圖中的樹處理完畢後dfs序以下
能夠發現,因爲咱們優先dfs重兒子,因此重兒子結點的編號是連續的,因而一條重鏈就被映射成了一段連續的區間;
這樣樹上兩點間的路徑就被分割成了多個連續的區間,如
(12,8) 可分割成 12,2 - 6, 1 – 4, 8 四段
(11,13)可分紅 2-6-11,1 – 4 -9 -13 兩段
因而咱們就可使用線段樹來進行樹鏈修改與查詢了
每次都將路徑分割成多個區間,區間操做能夠在O(logn)內解決, 但若是分割成的區間數不少怎麼辦?!
那咱們就來看下兩點間的路徑最多會分割成多少段。
這時候重兒子就發揮做用了,因爲每條鏈的鏈頂都是一個輕兒子,輕兒子的大小確定小於重兒子, 因此size[輕兒子]<=size[父親]/2
這樣從上往下每進入一條新鏈,結點的個數就會除2,因此通過的鏈數就是log級別的了。
樹剖每次將路徑分紅log個區間,而後區間操做通常都會用到線段樹之類的數據結構來維護,因此通常狀況下一次操做的時間複雜度爲(logn)^2
一棵樹上有n個節點,編號分別爲1到n,每一個節點都有一個權值w。咱們將如下面的形式來要求你對這棵樹完成一些操做:
I. CHANGE u t : 把結點u的權值改成t
II. QMAX u v: 詢問從點u到點v的路徑上的節點的最大權值 I
II. QSUM u v: 詢問從點u到點v的路徑上的節點的權值和
注意:從點u到點v的路徑上的節點包括u和v自己
輸入的第一行爲一個整數n,表示節點的個數。接下來n – 1行,每行2個整數a和b,表示節點a和節點b之間有一條邊相連。接下來n行,每行一個整數,第i行的整數wi表示節點i的權值。接下來1行,爲一個整數q,表示操做的總數。接下來q行,每行一個操做,以「CHANGE u t」或者「QMAX u v」或者「QSUM u v」的形式給出。
對於100%的數據,保證1<=n<=30000,0<=q<=200000;中途操做中保證每一個節點的權值w在-30000到30000之間。
對於每一個「QMAX」或者「QSUM」的操做,每行輸出一個整數表示要求輸出的結果。
4 1 2 2 3 4 1 4 2 1 3 12 QMAX 3 4 QMAX 3 3 QMAX 3 2 QMAX 2 3 QSUM 3 4 QSUM 2 1 CHANGE 1 5 QMAX 3 4 CHANGE 3 6 QMAX 3 4 QMAX 2 4 QSUM 3 4
4 1 2 2 10 6 5 6 5 16
樹剖模板題,樹鏈剖分後用線段樹維護便可
#include <bits/stdc++.h> #define lson (o << 1) #define rson (o << 1 | 1) using namespace std; const int N = 3e4 + 10; typedef long long ll; vector<int> G[N]; const ll inf = 1e9; int n; int val[N]; int fa[N]; int son[N]; int sze[N]; int dep[N]; void dfs1(int u, int f) { sze[u] = 1; fa[u] = f; son[u] = 0; dep[u] = dep[f] + 1; for (int i = 0; i < G[u].size(); i++) { int v = G[u][i]; if (v == f) continue; dfs1(v, u); sze[u] += sze[v]; if (sze[v] > sze[son[u]]) son[u] = v; } } int top[N]; int cnt; int pos[N]; int a[N]; void dfs2(int u, int f, int t) { top[u] = t; pos[u] = ++cnt; a[cnt] = val[u]; if (son[u]) dfs2(son[u], u, t); for (int i = 0; i < G[u].size(); i++) { int v = G[u][i]; if (v == f || v == son[u]) continue; dfs2(v, u, v); } } ll sumv[N << 2]; ll maxv[N << 2]; void pushup(int o) { sumv[o] = sumv[lson] + sumv[rson]; maxv[o] = max(maxv[lson], maxv[rson]); } void build(int o, int l, int r) { if (l == r) { sumv[o] = a[l]; maxv[o] = a[l]; return; } int mid = (l + r) >> 1; build(lson, l, mid); build(rson, mid + 1, r); pushup(o); } void update(int o, int l, int r, int pos, ll v) { if (l == r) { sumv[o] = v; maxv[o] = v; return; } int mid = (l + r) >> 1; if (pos <= mid) update(lson, l, mid, pos, v); else update(rson, mid + 1, r, pos, v); pushup(o); } ll querysum(int o, int l, int r, int ql, int qr) { if (ql <= l && r <= qr) { return sumv[o]; } ll ans = 0; int mid = (l + r) >> 1; if (ql <= mid) ans += querysum(lson, l, mid, ql, qr); if (qr > mid) ans += querysum(rson, mid + 1, r, ql, qr); return ans; } ll querymax(int o, int l, int r, int ql, int qr) { if (ql <= l && r <= qr) { return maxv[o]; } ll ans = -inf; int mid = (l + r) >> 1; if (ql <= mid) ans = max(ans, querymax(lson, l, mid, ql, qr)); if (qr > mid) ans = max(ans, querymax(rson, mid + 1, r, ql, qr)); return ans; } ll calcsum(int u, int v) { ll ans = 0; while (top[u] != top[v]) {//當不在同一條鏈上 if (dep[top[u]] < dep[top[v]]) swap(u, v);//每次深度較大的點向上走 ans += querysum(1, 1, n, pos[top[u]], pos[u]); u = fa[top[u]];//進入新的鏈 } if (dep[u] < dep[v]) swap(u, v);//進入同一條鏈再求一次 ans += querysum(1, 1, n, pos[v], pos[u]); return ans; } ll calcmax(int u, int v) { ll ans = -inf; while (top[u] != top[v]) { if (dep[top[u]] < dep[top[v]]) swap(u, v); ans = max(ans, querymax(1, 1, n, pos[top[u]], pos[u])); u = fa[top[u]]; } if (dep[u] < dep[v]) swap(u, v); ans = max(ans, querymax(1, 1, n, pos[v], pos[u])); return ans; } int main() { //freopen("in.txt", "r", stdin); //freopen("out.txt", "w", stdout); scanf("%d", &n); for (int i = 1; i < n; i++) { int u, v; scanf("%d%d", &u, &v); G[u].push_back(v); G[v].push_back(u); } for (int i = 1; i <= n; i++) scanf("%d", &val[i]); dep[0] = 0; dfs1(1, 0); cnt = 0; dfs2(1, 0, 1); build(1, 1, n); int m; scanf("%d", &m); char ch[10]; for (int i = 1; i <= m; i++) { scanf("%s", ch); int l, r, k; ll v; switch(ch[1]) { case 'M': scanf("%d%d", &l, &r); printf("%lld\n", calcmax(l, r)); break; case 'S': scanf("%d%d", &l, &r); printf("%lld\n", calcsum(l, r)); break; case 'H': scanf("%d%lld", &k, &v); update(1, 1, n, pos[k], v); break; } } return 0; }