DFS序+線段樹 hihoCoder 1381 Little Y's Tree(樹的連通塊的直徑和)

 

題目連接c++

#1381 : Little Y's Tree

時間限制:24000ms
單點時限:4000ms
內存限制:512MB

描述

小Y有一棵n個節點的樹,每條邊都有正的邊權。ui

小J有q個詢問,每次小J會刪掉這個樹中的k條邊,這棵樹被分紅k+1個連通塊。小J想知道每一個連通塊中最遠點對距離的和。spa

這裏的詢問是互相獨立的,即每次都是在小Y的原樹上進行操做。code

輸入

第一行一個整數n,接下來n-1行每行三個整數u,v,w,其中第i行表示第i條邊邊權爲wi,鏈接了ui,vi兩點。blog

接下來一行一個整數q,表示有q組詢問。內存

對於每組詢問,第一行一個正整數k,接下來一行k個不一樣的1到n-1之間的整數,表示刪除的邊的編號。get

1<=n,q,Σk<=105, 1<=w<=109it

輸出

共q行,每行一個整數表示詢問的答案。class

 

題解:方法

  首先考慮給出兩個點集,如何求這兩個點集合並以後的直徑,方法是把兩個點集的直徑分別求出來,而後對於這4個點,求出兩兩之間距離的最大值。

  因而能夠按dfs序創建線段樹,而後求出每一個區間的直徑。

  而對於一個詢問,刪掉k條邊,每棵子樹都對應的dfs序中的若干區間,並且區間總個數不會超過2k,對於每一個區間能夠在線段樹中查詢。

  時間複雜度O(nlog^2n)。

 代碼:

#include <bits/stdc++.h>

using namespace std;

typedef long long ll;
const int N = 1e5 + 5;
const int D = 20;

struct Edge {
    int u, v, w;
};

int L[N], R[N], p[N], rt[N][D], dep[N];
ll d[N];
int dfs_clock;
vector<Edge> edges;
vector<int> id[N];
int n, m;

void init_edge() {
    edges.clear();
    for (int i=1; i<=n; ++i) id[i].clear();
    m = 0;
}

void add_edge(int u, int v, int w) {
    edges.push_back((Edge){u, v, w});
    m = edges.size();
    id[u].push_back(m-1);
}

void DFS(int u, int fa) {
    L[u] = ++dfs_clock;
    p[dfs_clock] = u;
    dep[u] = dep[fa] + 1;
    rt[u][0] = fa;
    for (int i: id[u]) {
        Edge &e = edges[i];
        if (e.v == fa) continue;
        d[e.v] = d[u] + e.w;
        DFS(e.v, u);
    }
    R[u] = dfs_clock;
}

void init_LCA() {
    for (int j=1; j<D; ++j) {
        for (int i=1; i<=n; ++i) {
            rt[i][j] = rt[rt[i][j-1]][j-1];
        }
    }
}

int LCA(int u, int v) {
    if (dep[u] < dep[v]) swap(u, v);
    for (int i=0; i<D; ++i) {
        if ((dep[u]-dep[v]) >> i & 1) u = rt[u][i];
    }
    if (u == v) return u;
    for (int i=D-1; i>=0; --i) {
        if (rt[u][i] != rt[v][i]) {
            u = rt[u][i];
            v = rt[v][i];
        }
    }
    return rt[u][0];
}

ll dis(int u, int v) {
    return d[u] + d[v] - 2 * d[LCA(u, v)];
}

struct Node {
    ll d;
    int a, b;
    Node(ll d=0, int a=0, int b=0) : d(d), a(a), b(b) {}
    bool operator < (const Node &rhs) const {
        return d < rhs.d;
    }
};

Node nd[N<<2];

Node better(Node x, Node y) {
    if (x.d == -1) return y;
    if (y.d == -1) return x;
    Node z1 = Node(dis(x.a, y.a), x.a, y.a);
    Node z2 = Node(dis(x.a, y.b), x.a, y.b);
    Node z3 = Node(dis(x.b, y.a), x.b, y.a);
    Node z4 = Node(dis(x.b, y.b), x.b, y.b);
    return max({x, y, z1, z2, z3, z4});
}

#define lch o << 1
#define rch o << 1 | 1

void build(int o, int l, int r) {
    if (l == r) {
        nd[o] = Node(0, p[l], p[l]);
        return ;
    }
    int mid = l + r >> 1;
    build(lch, l, mid);
    build(rch, mid+1, r);
    nd[o] = better(nd[lch], nd[rch]);
}

Node query(int o, int l, int r, int ql, int qr) {
    if (ql <= l && r <= qr) {
        return nd[o];
    }
    int mid = l + r >> 1;
    Node ret = Node(-1, 0, 0);
    if (ql <= mid) ret = better(ret, query(lch, l, mid, ql, qr));
    if (qr > mid) ret = better(ret, query(rch, mid+1, r, ql, qr));
    return ret;
}

bool cmp(int a, int b) {
    return L[a] < L[b];
}

int x[N], s[N];
vector<int> edge[N];
ll ans;

void prepare() {
    dfs_clock = 0;
    dep[0] = 0;
    d[1] = 0;
    
    DFS(1, 0);
    init_LCA();
    build(1, 1, n);
}

void DFS(int u) {
    Node res = Node(-1, x[u], x[u]);
    int ql = L[x[u]], qr = R[x[u]];
    for (int v: edge[u]) {
        DFS(v);
        res = better(res, query(1, 1, n, ql, L[x[v]]-1));
        ql = R[x[v]] + 1;
    }
    res = better(res, query(1, 1, n, ql, qr));
    ans += res.d;
}

int main() {
    scanf("%d", &n);
    init_edge();
    for (int i=1; i<n; ++i) {
        int u, v, w;
        scanf("%d%d%d", &u, &v, &w);
        add_edge(u, v, w);
        add_edge(v, u, w);
    }
    
    prepare();

    int q, k;
    scanf("%d", &q);
    while (q--) {
        scanf("%d", &k);
        int idx;
        for (int i=1; i<=k; ++i) {
            scanf("%d", &idx);
            idx--;
            Edge &e = edges[idx*2];
            if (dep[e.u] > dep[e.v]) x[i] = e.u;
            else x[i] = e.v;
        }
        sort(x+1, x+1+k, cmp);

        x[0] = 1;
        int nn = 1;
        s[nn] = 0;
        for (int i=0; i<=k; ++i) edge[i].clear();
        //s[nn]:0~k, x[s[n]]:1 or x[1~k]
        for (int i=1; i<=k; ++i) {
            while (!(L[x[s[nn]]] <= L[x[i]] && R[x[i]] <= R[x[s[nn]]])) nn--;
            edge[s[nn]].push_back(i);
            s[++nn] = i;
        }

        ans = 0;
        DFS(0);
        printf("%lld\n", ans);
    }
    return 0;
}
相關文章
相關標籤/搜索