樹鏈剖分入門

簡介

樹鏈剖分就是將樹劃分爲多條鏈,將每條鏈映射到序列上,而後使用線段樹,平衡樹等數據結構來維護每條鏈的信息。c++

樹剖將樹鏈映射到序列上後用線段樹等數據結構來維護樹鏈信息,因此能夠像區間修改,區間查詢同樣進行樹上路徑的修改、查詢等操做數據結構

基本概念

重兒子:子樹結點數目最多的兒子(size最大的點);ui

重邊:父親結點和重兒子連成的邊;spa

重鏈:由多條重邊鏈接而成的路徑;3d

輕兒子: 除了重兒子,其他都爲輕兒子;code

輕邊:重邊以外的邊;blog

紅圈表示重兒子;ip

黑邊表示重邊;input

由黑邊連成的鏈即爲重鏈it

實現

第一遍dfs處理出重兒子,深度,父親等信息

第二遍dfs處理出結點所在重鏈的鏈頂,dfs序

上圖中的樹處理完畢後dfs序以下

能夠發現,因爲咱們優先dfs重兒子,因此重兒子結點的編號是連續的,因而一條重鏈就被映射成了一段連續的區間;

這樣樹上兩點間的路徑就被分割成了多個連續的區間,如

(12,8) 可分割成 12,2 - 6, 1 – 4, 8 四段

(11,13)可分紅 2-6-11,1 – 4 -9 -13 兩段

因而咱們就可使用線段樹來進行樹鏈修改與查詢了

時間複雜度

每次都將路徑分割成多個區間,區間操做能夠在O(logn)內解決, 但若是分割成的區間數不少怎麼辦?!

那咱們就來看下兩點間的路徑最多會分割成多少段。

這時候重兒子就發揮做用了,因爲每條鏈的鏈頂都是一個輕兒子,輕兒子的大小確定小於重兒子, 因此size[輕兒子]<=size[父親]/2

這樣從上往下每進入一條新鏈,結點的個數就會除2,因此通過的鏈數就是log級別的了。

樹剖每次將路徑分紅log個區間,而後區間操做通常都會用到線段樹之類的數據結構來維護,因此通常狀況下一次操做的時間複雜度爲(logn)^2

例題

樹的統計

Description

一棵樹上有n個節點,編號分別爲1到n,每一個節點都有一個權值w。咱們將如下面的形式來要求你對這棵樹完成一些操做:

I. CHANGE u t : 把結點u的權值改成t

II. QMAX u v: 詢問從點u到點v的路徑上的節點的最大權值 I

II. QSUM u v: 詢問從點u到點v的路徑上的節點的權值和

注意:從點u到點v的路徑上的節點包括u和v自己

Input

輸入的第一行爲一個整數n,表示節點的個數。接下來n – 1行,每行2個整數a和b,表示節點a和節點b之間有一條邊相連。接下來n行,每行一個整數,第i行的整數wi表示節點i的權值。接下來1行,爲一個整數q,表示操做的總數。接下來q行,每行一個操做,以「CHANGE u t」或者「QMAX u v」或者「QSUM u v」的形式給出。

對於100%的數據,保證1<=n<=30000,0<=q<=200000;中途操做中保證每一個節點的權值w在-30000到30000之間。

Output

對於每一個「QMAX」或者「QSUM」的操做,每行輸出一個整數表示要求輸出的結果。

Sample Input

4
1 2
2 3
4 1
4 2 1 3
12
QMAX 3 4
QMAX 3 3
QMAX 3 2
QMAX 2 3
QSUM 3 4
QSUM 2 1
CHANGE 1 5
QMAX 3 4
CHANGE 3 6
QMAX 3 4
QMAX 2 4
QSUM 3 4

Sample Output

4
1
2
2
10
6
5
6
5
16

題解

樹剖模板題,樹鏈剖分後用線段樹維護便可

#include <bits/stdc++.h>
#define lson (o << 1)
#define rson (o << 1 | 1)
using namespace std;
const int N = 3e4 + 10;
typedef long long ll;
vector<int> G[N];
const ll inf = 1e9;
int n;
int val[N];
int fa[N];
int son[N];
int sze[N];
int dep[N];
void dfs1(int u, int f) {
    sze[u] = 1;
    fa[u] = f;
    son[u] = 0;
    dep[u] = dep[f] + 1;
    for (int i = 0; i < G[u].size(); i++) {
        int v = G[u][i];
        if (v == f) continue;
        dfs1(v, u);
        sze[u] += sze[v];
        if (sze[v] > sze[son[u]]) son[u] = v;
    }
}
int top[N];
int cnt;
int pos[N];
int a[N];
void dfs2(int u, int f, int t) {
    top[u] = t;
    pos[u] = ++cnt;
    a[cnt] = val[u];
    if (son[u]) dfs2(son[u], u, t);
    for (int i = 0; i < G[u].size(); i++) {
        int v = G[u][i];
        if (v == f || v == son[u]) continue;
        dfs2(v, u, v);
    }
}
ll sumv[N << 2];
ll maxv[N << 2];
void pushup(int o) {
    sumv[o] = sumv[lson] + sumv[rson];
    maxv[o] = max(maxv[lson], maxv[rson]);
}
void build(int o, int l, int r) {
    if (l == r) {
        sumv[o] = a[l];
        maxv[o] = a[l];
        return;
    }
    int mid = (l + r) >> 1;
    build(lson, l, mid); build(rson, mid + 1, r);
    pushup(o);
}
void update(int o, int l, int r, int pos, ll v) {
    if (l == r) {
        sumv[o] = v;
        maxv[o] = v;
        return;
    }
    int mid = (l + r) >> 1;
    if (pos <= mid) update(lson, l, mid, pos, v);
    else update(rson, mid + 1, r, pos, v);
    pushup(o);
}
ll querysum(int o, int l, int r, int ql, int qr) {
    if (ql <= l && r <= qr) {
        return sumv[o];
    }
    ll ans = 0; int mid = (l + r) >> 1;
    if (ql <= mid) ans += querysum(lson, l, mid, ql, qr);
    if (qr > mid) ans += querysum(rson, mid + 1, r, ql, qr);
    return ans;
}
ll querymax(int o, int l, int r, int ql, int qr) {
    if (ql <= l && r <= qr) {
        return maxv[o];
    }
    ll ans = -inf; int mid = (l + r) >> 1;
    if (ql <= mid) ans = max(ans, querymax(lson, l, mid, ql, qr));
    if (qr > mid) ans = max(ans, querymax(rson, mid + 1, r, ql, qr));
    return ans;
}
ll calcsum(int u, int v) {
    ll ans = 0;
    while (top[u] != top[v]) {//當不在同一條鏈上
        if (dep[top[u]] < dep[top[v]]) swap(u, v);//每次深度較大的點向上走
        ans += querysum(1, 1, n, pos[top[u]], pos[u]);
        u = fa[top[u]];//進入新的鏈
    }
    if (dep[u] < dep[v]) swap(u, v);//進入同一條鏈再求一次
    ans += querysum(1, 1, n, pos[v], pos[u]);
    return ans;
}
ll calcmax(int u, int v) {
    ll ans = -inf;
    while (top[u] != top[v]) {
        if (dep[top[u]] < dep[top[v]]) swap(u, v);
        ans = max(ans, querymax(1, 1, n, pos[top[u]], pos[u]));
        u = fa[top[u]];
    }
    if (dep[u] < dep[v]) swap(u, v);
    ans = max(ans, querymax(1, 1, n, pos[v], pos[u]));
    return ans;
}
int main() {
    //freopen("in.txt", "r", stdin);
    //freopen("out.txt", "w", stdout);
    scanf("%d", &n);
    for (int i = 1; i < n; i++) {
        int u, v;
        scanf("%d%d", &u, &v);
        G[u].push_back(v);
        G[v].push_back(u);
    }
    for (int i = 1; i <= n; i++) scanf("%d", &val[i]);
    dep[0] = 0;
    dfs1(1, 0);
    cnt = 0;
    dfs2(1, 0, 1);
    build(1, 1, n);
    int m;
    scanf("%d", &m);
    char ch[10];
    for (int i = 1; i <= m; i++) {
        scanf("%s", ch);
        int l, r, k;
        ll v;
        switch(ch[1]) {
            case 'M': scanf("%d%d", &l, &r); printf("%lld\n", calcmax(l, r)); break;
            case 'S': scanf("%d%d", &l, &r); printf("%lld\n", calcsum(l, r)); break;
            case 'H': scanf("%d%lld", &k, &v); update(1, 1, n, pos[k], v); break;
        }
    }
    return 0;
}
相關文章
相關標籤/搜索