P4178 Tree(點分治)

題面要求小於等於K的路徑數目,我麼很天然的想到點分治(不會的就戳我)c++

這道題的統計答案與模板題不同的地方是由等於K到小於等於Kspa

那麼咱們能夠把每個子節點到當前根(重心)的距離排序,而後用相似雙指針的方法來求小於等於K的邊的數量指針

可是若是隻是雙指針統計的話,那麼如下不合法的狀況顯然也會被算進答案:code

QWQ

而咱們須要的合法路徑是長成這樣的:blog

QAQ

因此咱們須要減去上述不合法的路徑,怎麼減呢?排序

不難發現,對於全部不合法的路徑,都是在當前跟的某一棵子樹上的(沒有跨越兩個子樹)get

因此咱們能夠對當前跟節點的每一條邊進行遍歷,利用容斥的思想減去不合法的路徑。it

具體操做爲:當遍歷重心節點的每個節點時,咱們能夠從新計算dis,而後把通過了從重心到新遍歷的點的邊兩次的路徑剪掉(就是上述不合法路徑),最後統計答案便可io

#include<bits/stdc++.h>
using namespace std;
#define il inline
#define re register
#define inf 123456789
il int read()
{
    re int x = 0, f = 1; re char c = getchar();
    while(c < '0' || c > '9') { if(c == '-') f = -1; c = getchar();}
    while(c >= '0' && c <= '9') x = x * 10 + c - 48, c = getchar();
    return x * f;
}
#define rep(i, s, t) for(re int i = s; i <= t; ++ i)
#define drep(i, s, t) for(re int i = t; i >= s; -- i)
#define Next(i, u) for(re int i = head[u]; i; i = e[i].next)
#define mem(k, p) memset(k, p, sizeof(k))
#define maxn 40005
struct edge{int v, w, next;}e[maxn << 1];
int n, m, head[maxn], cnt, k, ans;
int dp[maxn], vis[maxn], sum, dis[maxn], rt, size[maxn], rev[maxn], tot;
il void add(int u, int v, int w)
{
    e[++ cnt] = (edge){v, w, head[u]}, head[u] = cnt;
    e[++ cnt] = (edge){u, w, head[v]}, head[v] = cnt;
}
il void getrt(int u, int fr)
{
    size[u] = 1, dp[u] = 0;
    Next(i, u)
    {
        int v = e[i].v;
        if(v == fr || vis[v]) continue;
        getrt(v, u);
        size[u] += size[v], dp[u] = max(dp[u], size[v]);
    }
    dp[u] = max(dp[u], sum - size[u]);
    if(dp[u] < dp[rt]) rt = u;
}
il void getdis(int u, int fr)
{
    rev[++ tot] = dis[u];
    Next(i, u)
    {
        int v = e[i].v;
        if(v == fr || vis[v]) continue;
        dis[v] = dis[u] + e[i].w;
        getdis(v, u);
    }
}
il int doit(int u, int w)
{
    tot = 0, dis[u] = w, getdis(u, 0);
    sort(rev + 1, rev + tot + 1);
    int l = 1, r = tot, ans = 0;
    while(l <= r) (rev[l] + rev[r] <= k) ? (ans += r - l, ++ l) : (-- r);
    return ans;
}
il void solve(int u)
{
    vis[u] = 1, ans += doit(u, 0);
    Next(i, u)
    {
        int v = e[i].v;
        if(vis[v]) continue;
        ans -= doit(v, e[i].w);
        sum = size[v], dp[0] = n, rt = 0;
        getrt(v, u), solve(rt);
    }
}
int main()
{
    n = read();
    rep(i, 1, n - 1){int u = read(), v = read(), w = read(); add(u, v, w);}
    k = read(), dp[0] = sum = n, getrt(1, 0), solve(rt);
    printf("%d", ans);
    return 0;
}
相關文章
相關標籤/搜索