【數據結構】樹鏈剖分詳細講解

 「在一棵樹上進行路徑的修改、求極值、求和」乍一看只要線段樹就能輕鬆解決,實際上,僅憑線段樹是不能搞定它的。咱們須要用到一種貌似高級的複雜算法——樹鏈剖分。php

樹鏈剖分是把一棵樹分割成若干條鏈,以便於維護信息的一種方法,其中最經常使用的是重鏈剖分(Heavy Path Decomposition,重路徑分解),因此通常提到樹鏈剖分或樹剖都是指重鏈剖分。除此以外還有長鏈剖分和實鏈剖分等,本文暫不介紹。html

首先咱們須要明確概念:

  • 重兒子:父親節點的全部兒子中子樹結點數目最多(size最大)的結點;
  • 輕兒子:父親節點中除了重兒子之外的兒子;
  • 重邊:父親結點和重兒子連成的邊;
  • 輕邊:父親節點和輕兒子連成的邊;
  • 重鏈:由多條重邊鏈接而成的路徑;
  • 輕鏈:由多條輕邊鏈接而成的路徑;

咱們定義樹上一個節點的子節點中子樹最大的一個爲它的重子節點,其他的爲輕子節點。一個節點連向其重子節點的邊稱爲重邊,連向輕子節點的邊則爲輕邊。若是把根節點看做輕的,那麼從每一個輕節點出發,不斷向下走重邊,都對應了一條鏈,因而咱們把樹剖分紅了 \(l\) 條鏈,其中 \(l\) 是輕節點的數量。node

最近由於畫圖工具出了點問題,因此轉載了Pecco學長的示意圖(下面求LCA的方法的部份內容也來自Pecco學長)ios

剖分後的樹(重鏈)有以下性質:c++

  1. 對於節點數爲 \(n\) 的樹,從任意節點向上走到根節點,通過的輕邊數量不會超過 \(log\ n\)算法

    這是由於當咱們向下通過一條 輕邊 時,所在子樹的大小至少會除以二。因此說,對於樹上的任意一條路徑,把它拆分紅從 \(lca\) 分別向兩邊往下走,分別最多走 \(O(\log n)\) 次,樹上的每條路徑均可以被拆分紅不超過 \(O(\log n)\) 條重鏈。數組

  2. 樹上每一個節點都屬於且僅屬於一條重鏈數據結構

重鏈開頭的結點不必定是重子節點(由於重邊是對於每個結點都有定義的)。全部的重鏈將整棵樹 徹底剖分工具

儘管樹鏈部分看起來很難實現(的確有點繁瑣),但咱們能夠用兩個 DFS 來實現樹鏈(樹剖)。學習

相關僞碼(來自 OI wiki)

第一個 DFS 記錄每一個結點的父節點(father)、深度(deep)、子樹大小(size)、重子節點(hson)。

\[\begin{array}{l} \text{TREE-BUILD }(u,dep) \\ \begin{array}{ll} 1 & u.hson\gets 0 \\ 2 & u.hson.size\gets 0 \\ 3 & u.deep\gets dep \\ 4 & u.size\gets 1 \\ 5 & \textbf{for }\text{each }u\text{'s son }v \\ 6 & \qquad u.size\gets u.size + \text{TREE-BUILD }(v,dep+1) \\ 7 & \qquad v.father\gets u \\ 8 & \qquad \textbf{if }v.size> u.hson.size \\ 9 & \qquad \qquad u.hson\gets v \\ 10 & \textbf{return } u.size \end{array} \end{array} \]

第二個 DFS 記錄所在鏈的鏈頂(top,應初始化爲結點自己)、重邊優先遍歷時的 DFS 序(dfn)、DFS 序對應的節點編號(rank)。

\[\begin{array}{l} \text{TREE-DECOMPOSITION }(u,top) \\ \begin{array}{ll} 1 & u.top\gets top \\ 2 & tot\gets tot+1\\ 3 & u.dfn\gets tot \\ 4 & rank(tot)\gets u \\ 5 & \textbf{if }u.hson\text{ is not }0 \\ 6 & \qquad \text{TREE-DECOMPOSITION }(u.hson,top) \\ 7 & \qquad \textbf{for }\text{each }u\text{'s son }v \\ 8 & \qquad \qquad \textbf{if }v\text{ is not }u.hson \\ 9 & \qquad \qquad \qquad \text{TREE-DECOMPOSITION }(v,v) \end{array} \end{array} \]

如下爲代碼實現。

咱們先給出一些定義:

  • \(fa(x)\) 表示節點 \(x\) 在樹上的父親(也就是父節點)。
  • \(dep(x)\) 表示節點 \(x\) 在樹上的深度。
  • \(siz(x)\) 表示節點 \(x\) 的子樹的節點個數。
  • \(son(x)\) 表示節點 \(x\)重兒子
  • \(top(x)\) 表示節點 \(x\) 所在 重鏈 的頂部節點(深度最小)。
  • \(dfn(x)\) 表示節點 \(x\)DFS 序 ,也是其在線段樹中的編號。
  • \(rnk(x)\) 表示 DFS 序所對應的節點編號,有 \(rnk(dfn(x))=x\)

咱們進行兩遍 DFS 預處理出這些值,其中第一次 DFS 求出 \(fa(x)\) , \(dep(x)\) , \(siz(x)\) , \(son(x)\) ,第二次 DFS 求出 \(top(x)\) , \(dfn(x)\) , \(rnk(x)\)

// 固然樹鏈寫法不止一種,這個是我學習Oi wiki上知識點記錄的模板代碼
void dfs1(int o) {
    son[o] = -1, siz[o] = 1;
    for (int j = h[o]; j; j = nxt[j])
        if (!dep[p[j]]) {
            dep[p[j]] = dep[o] + 1;
            fa[p[j]] = o;
            dfs1(p[j]);
            siz[o] += siz[p[j]];
            if (son[o] == -1 || siz[p[j]] > siz[son[o]])
                son[o] = p[j];
        }
}
void dfs2(int o, int t) {
    top[o] = t;
    dfn[o] = ++cnt;
    rnk[cnt] = o;
    if (son[o] == -1)
        return;
    dfs2(son[o], t);  // 優先對重兒子進行 DFS,能夠保證同一條重鏈上的點 DFS 序連續
    for (int j = h[o]; j; j = nxt[j])
        if (p[j] != son[o] && p[j] != fa[o])
            dfs2(p[j], p[j]);
}
// 寫法2:來自Peocco學長,代碼僅做學習使用
void dfs1(int p, int d = 1){
    int Siz = 1,ma = 0;
    dep[p] = d;
    for(auto q : edges[p]){ // for循環寫法和auto是C++11標準,競賽可用
        dfs1(q,d + 1);
        fa[q] = p;
        Siz += sz[q];
        if(sz[q] > ma)
            hson[p] = q, ma = sz[q];// hson = 重兒子
    }
    sz[p] = Siz; 
}
// 須要先把根節點的top初始化爲自身
void dfs2(int p){
    for(auto q : edges[p]){
        if(!top[q]){
            if(q == hson[p])
                top[q] = top[p];
           	else 
                top[q] = q;
            dfs2(q);
        }
    }
}

以上這樣便完成了剖分。

學習到這裏想一想開頭的那句話:

 「在一棵樹上進行路徑的修改、求極值、求和」乍一看只要線段樹就能輕鬆解決,實際上,僅憑線段樹是不能搞定它的。咱們須要用到一種貌似高級的複雜算法——樹鏈剖分。

若是不能一下想不到線段樹解決不了的問題的話不如看看這道題 ↓

Hdu 3966 Aragorn's Story

題目連接:http://acm.hdu.edu.cn/showproblem.php?pid=3966

題意:給一棵樹,並給定各個點權的值,而後有3種操做:
  I C1 C2 K: 把C1與C2的路徑上的全部點權值加上K
  D C1 C2 K:把C1與C2的路徑上的全部點權值減去K
  Q C:查詢節點編號爲C的權值

  分析:典型的樹鏈剖分題目,先進行剖分,而後用線段樹去維護便可

// Author : RioTian
// Time : 20/11/30
#include <bits/stdc++.h>
using namespace std;

#define lson l, m, rt << 1
#define rson m + 1, r, rt << 1 | 1

typedef long long ll;
typedef int lld;

stack<int> ss;
const int maxn = 2e5 + 10;
const int inf = ~0u >> 2;  // 1073741823
int M[maxn << 2];
int add[maxn << 2];

struct node {
    int s, t, w, next;
} edges[maxn << 1];

int E, n;
int Size[maxn], fa[maxn], heavy[maxn], head[maxn], vis[maxn];
int dep[maxn], rev[maxn], num[maxn], cost[maxn], w[maxn];
int Seg_size;

int find(int x) {
    return fa[x] == x ? x : fa[x] = find(fa[x]);
}

void add_edge(int s, int t, int w) {
    edges[E].w = w;
    edges[E].s = s;
    edges[E].t = t;
    edges[E].next = head[s];
    head[s] = E++;
}

void dfs(int u, int f) {  //起點,父節點
    int mx = -1, e = -1;
    Size[u] = 1;
    for (int i = head[u]; i != -1; i = edges[i].next) {
        int v = edges[i].t;
        if (v == f)
            continue;
        edges[i].w = edges[i ^ 1].w = w[v];
        dep[v] = dep[u] + 1;
        rev[v] = i ^ 1;
        dfs(v, u);
        Size[u] += Size[v];
        if (Size[v] > mx)
            mx = Size[v], e = i;
    }
    heavy[u] = e;
    if (e != -1)
        fa[edges[e].t] = u;
}

inline void pushup(int rt) {
    M[rt] = M[rt << 1] + M[rt << 1 | 1];
}

void pushdown(int rt, int m) {
    if (add[rt]) {
        add[rt << 1] += add[rt];
        add[rt << 1 | 1] += add[rt];
        M[rt << 1] += add[rt] * (m - (m >> 1));
        M[rt << 1 | 1] += add[rt] * (m >> 1);
        add[rt] = 0;
    }
}

void built(int l, int r, int rt) {
    M[rt] = add[rt] = 0;
    if (l == r)
        return;
    int m = (r + l) >> 1;
    built(lson), built(rson);
}

void update(int L, int R, int val, int l, int r, int rt) {
    if (L <= l && r <= R) {
        M[rt] += val;
        add[rt] += val;
        return;
    }
    pushdown(rt, r - l + 1);
    int m = (l + r) >> 1;
    if (L <= m)
        update(L, R, val, lson);
    if (R > m)
        update(L, R, val, rson);
    pushup(rt);
}

lld query(int L, int R, int l, int r, int rt) {
    if (L <= l && r <= R)
        return M[rt];
    pushdown(rt, r - l + 1);
    int m = (l + r) >> 1;
    lld ret = 0;
    if (L <= m)
        ret += query(L, R, lson);
    if (R > m)
        ret += query(L, R, rson);
    return ret;
}

void prepare() {
    int i;
    built(1, n, 1);
    memset(num, -1, sizeof(num));
    dep[0] = 0;
    Seg_size = 0;
    for (i = 0; i < n; i++)
        fa[i] = i;
    dfs(0, 0);
    for (i = 0; i < n; i++) {
        if (heavy[i] == -1) {
            int pos = i;
            while (pos && edges[heavy[edges[rev[pos]].t]].t == pos) {
                int t = rev[pos];
                num[t] = num[t ^ 1] = ++Seg_size;
                // printf("pos=%d  val=%d t=%d\n", Seg_size, edge[t].w, t);
                update(Seg_size, Seg_size, edges[t].w, 1, n, 1);
                pos = edges[t].t;
            }
        }
    }
}

int lca(int u, int v) {
    while (1) {
        int a = find(u), b = find(v);
        if (a == b)
            return dep[u] < dep[v] ? u : v;  // a,b在同一條重鏈上
        else if (dep[a] >= dep[b])
            u = edges[rev[a]].t;
        else
            v = edges[rev[b]].t;
    }
}

void CH(int u, int lca, int val) {
    while (u != lca) {
        int r = rev[u];  // printf("r=%d\n",r);
        if (num[r] == -1)
            edges[r].w += val, u = edges[r].t;
        else {
            int p = fa[u];
            if (dep[p] < dep[lca])
                p = lca;
            int l = num[r];
            r = num[heavy[p]];
            update(l, r, val, 1, n, 1);
            u = p;
        }
    }
}

void change(int u, int v, int val) {
    int p = lca(u, v);
    // printf("p=%d\n",p);
    CH(u, p, val);
    CH(v, p, val);
    if (p) {
        int r = rev[p];
        if (num[r] == -1) {
            edges[r ^ 1].w += val;  //在此處發現了我代碼的重大bug
            edges[r].w += val;
        } else
            update(num[r], num[r], val, 1, n, 1);
    }  //根節點,特判
    else
        w[p] += val;
}

lld solve(int u) {
    if (!u)
        return w[u];  //根節點,特判
    else {
        int r = rev[u];
        if (num[r] == -1)
            return edges[r].w;
        else
            return query(num[r], num[r], 1, n, 1);
    }
}

int main() {
    // freopen("in.txt", "r", stdin);
    ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
    int t, i, a, b, c, m, ca = 1, p;
    while (cin >> n >> m >> p) {
        memset(head, -1, sizeof(head));
        E = 0;
        for (int i = 0; i < n; ++i)
            cin >> w[i];
        for (int i = 0; i < m; ++i) {
            cin >> a >> b;
            a--, b--;
            add_edge(a, b, 0), add_edge(b, a, 0);
        }
        prepare();  // 預處理
        string op;
        while (p--) {
            cin >> op;
            if (op[0] == 'I') {  //區間添加
                cin >> a >> b >> c;
                a--, b--;
                change(a, b, c);
            } else if (op[0] == 'D') {  //區間減小
                cin >> a >> b >> c;
                a--, b--;
                change(a, b, -c);
            } else {  //查詢
                cin >> a;
                a--;
                cout << solve(a) << endl;
            }
        }
    }
    return 0;
}

因爲數據很大,建議使用快讀,而不是像我同樣用 cin(差了近500ms了)

摺疊代碼是千千dalao的解法:

Code
//千千dalao解法
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
const int maxn = 50010;
struct Edge {
    int to;
    int next;
} edge[maxn << 1];
int head[maxn], tot;  //鏈式前向星存儲
int top[maxn];        // v所在重鏈的頂端節點
int fa[maxn];         //父親節點
int deep[maxn];       //節點深度
int num[maxn];        //以v爲根的子樹節點數
int p[maxn];          // v與其父親節點的連邊在線段樹中的位置
int fp[maxn];         //與p[]數組相反
int son[maxn];        //重兒子
int pos;
int w[maxn];
int ad[maxn << 2];  //樹狀數組
int n;              //節點數目
void init() {
    memset(head, -1, sizeof(head));
    memset(son, -1, sizeof(son));
    tot = 0;
    pos = 1;  //由於使用樹狀數組,因此咱們pos初始值從1開始
}
void addedge(int u, int v) {
    edge[tot].to = v;
    edge[tot].next = head[u];
    head[u] = tot++;
}
//第一遍dfs,求出 fa,deep,num,son (u爲當前節點,pre爲其父節點,d爲深度)
void dfs1(int u, int pre, int d) {
    deep[u] = d;
    fa[u] = pre;
    num[u] = 1;
    //遍歷u的鄰接點
    for (int i = head[u]; i != -1; i = edge[i].next) {
        int v = edge[i].to;
        if (v != pre) {
            dfs1(v, u, d + 1);
            num[u] += num[v];
            if (son[u] == -1 || num[v] > num[son[u]])  //尋找重兒子
                son[u] = v;
        }
    }
}
//第二遍dfs,求出 top,p
void dfs2(int u, int sp) {
    top[u] = sp;
    p[u] = pos++;
    fp[p[u]] = u;
    if (son[u] != -1)  //若是當前點存在重兒子,繼續延伸造成重鏈
        dfs2(son[u], sp);
    else
        return;
    for (int i = head[u]; i != -1; i = edge[i].next) {
        int v = edge[i].to;
        if (v != son[u] && v != fa[u])  //遍歷全部輕兒子新建重鏈
            dfs2(v, v);
    }
}
int lowbit(int x) {
    return x & -x;
}
//查詢
int query(int i) {
    int s = 0;
    while (i > 0) {
        s += ad[i];
        i -= lowbit(i);
    }
    return s;
}
//增長
void add(int i, int val) {
    while (i <= n) {
        ad[i] += val;
        i += lowbit(i);
    }
}
void update(int u, int v, int val) {
    int f1 = top[u], f2 = top[v];
    while (f1 != f2) {
        if (deep[f1] < deep[f2]) {
            swap(f1, f2);
            swap(u, v);
        }
        //由於區間減法成立,因此咱們把對某個區間[f1,u]
        //的更新拆分爲 [0,f1] 和 [0,u] 的操做
        add(p[f1], val);
        add(p[u] + 1, -val);
        u = fa[f1];
        f1 = top[u];
    }
    if (deep[u] > deep[v])
        swap(u, v);
    add(p[u], val);
    add(p[v] + 1, -val);
}
int main() {
    ios::sync_with_stdio(false);
    int m, ps;
    while (cin >> n >> m >> ps) {
        int a, b, c;
        for (int i = 1; i <= n; i++)
            cin >> w[i];
        init();
        for (int i = 0; i < m; i++) {
            cin >> a >> b;
            addedge(a, b);
            addedge(b, a);
        }
        dfs1(1, 0, 0);
        dfs2(1, 1);
        memset(ad, 0, sizeof(ad));
        for (int i = 1; i <= n; i++) {
            add(p[i], w[i]);
            add(p[i] + 1, -w[i]);
        }
        for (int i = 0; i < ps; i++) {
            char op;
            cin >> op;
            if (op == 'Q') {
                cin >> a;
                cout << query(p[a]) << endl;
            } else {
                cin >> a >> b >> c;
                if (op == 'D')
                    c = -c;
                update(a, b, c);
            }
        }
    }
    return 0;
}

利用樹鏈求LCA

這個部分參考了Peocco學長,十分感謝

在這道經典題中,求了LCA,但爲何樹剖就能夠求LCA呢?

樹剖能夠單次 \(O(log\ n)\)! 地求LCA,且常數較小。假如咱們要求兩個節點的LCA,若是它們在同一條鏈上,那直接輸出深度較小的那個節點就能夠了。

不然,LCA要麼在鏈頭深度較小的那條鏈上,要麼就是兩個鏈頭的父節點的LCA,但毫不可能在鏈頭深度較大的那條鏈上[1]。因此咱們能夠直接把鏈頭深度較大的節點用其鏈頭的父節點代替,而後繼續求它與另外一者的LCA。

因爲在鏈上咱們能夠 \(O(1)\) 地跳轉,每條鏈間由輕邊鏈接,而通過輕邊的次數又不超過 [公式] ,因此咱們實現了 \(O(log\ n)\) 的LCA查詢。

int lca(int a, int b) {
    while (top[a] != top[b]) {
        if (dep[top[a]] > dep[top[b]])
            a = fa[top[a]];
        else
            b = fa[top[b]];
    }
    return (dep[a] > dep[b] ? b : a);
}

結合數據結構

在進行了樹鏈剖分後,咱們即可以配合線段樹等數據結構維護樹上的信息,這須要咱們改一下第二次 DFS 的代碼,咱們用dfsn數組記錄每一個點的dfs序,用madfsn數組記錄每棵子樹的最大dfs序:(這裏有點像連通圖的知識了)

// 須要先把根節點的top初始化爲自身
int cnt;
void dfs2(int p) {
    madfsn[p] = dfsn[p] = ++cnt;
    if (hson[p] != 0) {
        top[hson[p]] = top[p];
        dfs2(hson[p]);
        madfsn[p] = max(madfsn[p], madfsn[hson[p]]);
    }
    for (auto q : edges[p])
        if (!top[q]) {
            top[q] = q;
            dfs2(q);
            madfsn[p] = max(madfsn[p], madfsn[q]);
        }
}

注意到,每棵子樹的dfs序都是連續的,且根節點dfs序最小;並且,若是咱們優先遍歷重子節點,那麼同一條鏈上的節點的dfs序也是連續的,且鏈頭節點dfs序最小

連通樹(霧)

因此就能夠用線段樹等數據結構維護區間信息(以點權的和爲例),例如路徑修改(相似於求LCA的過程):

void update_path(int x, int y, int z) {
    while (top[x] != top[y]) {
        if (dep[top[x]] > dep[top[y]]) {
            update(dfsn[top[x]], dfsn[x], z);
            x = fa[top[x]];
        } else {
            update(dfsn[top[y]], dfsn[y], z);
            y = fa[top[y]];
        }
    }
    if (dep[x] > dep[y])
        update(dfsn[y], dfsn[x], z);
    else
        update(dfsn[x], dfsn[y], z);
}

路徑查詢:

int query_path(int x, int y) {
    int ans = 0;
    while (top[x] != top[y]) {
        if (dep[top[x]] > dep[top[y]]) {
            ans += query(dfsn[top[x]], dfsn[x]);
            x = fa[top[x]];
        } else {
            ans += query(dfsn[top[y]], dfsn[y]);
            y = fa[top[y]];
        }
    }
    if (dep[x] > dep[y])
        ans += query(dfsn[y], dfsn[x]);
    else
        ans += query(dfsn[x], dfsn[y]);
    return ans;
}

子樹修改(更新):

void update_subtree(int x, int z){
    update(dfsn[x], madfsn[x], z);
}

子樹查詢:

int query_subtree(int x){
    return query(dfsn[x], madfsn[x]);
}

須要注意,建線段樹的時候不是按節點編號建,而是按dfs序建,相似這樣:

for (int i = 1; i <= n; ++i)
    B[i] = read();
// ...
for (int i = 1; i <= n; ++i)
    A[dfsn[i]] = B[i];
build();

固然,不只能夠用線段樹維護,有些題也可使用珂朵莉樹等數據結構(要求數據不卡珂朵莉樹,如這道)。此外,若是須要維護的是邊權而不是點權,把每條邊的邊權下放到深度較深的那個節點處便可,可是查詢、修改的時候要注意略過最後一個點。

寫在最後:

OI wiki上有一些推薦作的列題,但每一個都須要比較多的時間+耐心去完成,因此這裏推薦幾個必作的題:

SPOJ QTREE – Query on a tree (樹鏈剖分):千千dalao的題解報告

HDU 3966 Aragorn’s Story (樹鏈剖分):建議先看一遍個人解法再獨立完成。

參考

洛穀日報:https://zhuanlan.zhihu.com/p/41082337

OI wiki:https://oi-wiki.org/graph/hld/

Pecco學長:https://www.zhihu.com/people/one-seventh

千千:https://www.dreamwings.cn/hdu3966/4798.html


  1. 設top[a]的深度≤top[b]的深度,且c=lca(a,b)在b所在的鏈上;那麼c是a和b的祖先且c的深度≥top[b]的深度,那麼c的深度≥top[a]的深度。c是a的祖先,top[a]也是a的祖先,c的深度大於等於top[a],那c必然在鏈接top[a]和a的這條鏈上,與前提矛盾 ↩︎

相關文章
相關標籤/搜索