bzoj 2152(點分治)

傳送門php

題意:

給你一棵樹,每條邊有權值,如今要求你求出有多少對點對知足他們之間的權值可以被\(3\)整除。c++

分析:

由於咱們須要維護的是樹中鏈的信息,所以咱們不妨使用點分治進行統計。由於咱們只須要判斷是否可以被\(3\)整除,所以咱們只須要統計出點對之間的權值和\(\%3\)以後的信息,即咱們只須要統計出\(\%3=0\)\(\%3=1\)\(\%3=2\)的信息便可,咱們計做\(cnt[i]\)spa

對於不在一棵子樹上的兩點,顯然權值\(\%3=1\)的點和權值\(\%3=2\)的能夠分別被貢獻答案,答案爲\(2*cnt[1]*cnt[2]\),同理兩個權值\(\%3=0\)的點也能夠貢獻答案,所以對於不在一棵子樹上的兩點的貢獻爲\(cnt[1]*cnt[2]*2+cnt[0]*cnt[0]\)code

而在同一棵子樹上的兩點由於會加多到他們\(lca\)的權值貢獻,因此須要剪掉這部分的貢獻。get

總體的時間複雜度爲:\(\mathcal{O}(nlogn)\)it

代碼:

#include <bits/stdc++.h>
#define maxn 20005
using namespace std;
struct Node{
    int to,next,val;
}q[maxn<<1];
int head[maxn],cnt=0;
int siz[maxn],Size,dep[maxn],root,minr,CNT[10],res=0,n;
bool vis[maxn];
void init(){
    memset(head,-1, sizeof(head));
    cnt=0;
}
void add_edge(int from,int to,int val){
    q[cnt].to=to;
    q[cnt].val=val;
    q[cnt].next=head[from];
    head[from]=cnt++;
}
void getroot(int x,int fa){
    siz[x]=1;
    int ret=0;
    for(int i=head[x];i!=-1;i=q[i].next){
        int to=q[i].to;
        if(to==fa||vis[to]) continue;
        getroot(to,x);
        ret=max(ret,siz[to]);
        siz[x]+=siz[to];
    }
    ret=max(ret,Size-siz[x]);
    if(ret<minr) minr=ret,root=x;
}
void getdep(int x,int fa){
    CNT[dep[x]]++;
    for(int i=head[x];i!=-1;i=q[i].next){
        int to=q[i].to;
        if(to==fa||vis[to]) continue;
        dep[to]=(dep[x]+q[i].val)%3;
        getdep(to,x);
    }
}
int cal(int x,int pre){
    dep[x]=pre;
    memset(CNT,0,sizeof(CNT));
    getdep(x,x);
    return CNT[0]*CNT[0]+CNT[1]*CNT[2]*2;
}
int dfs(int x){
    res+=cal(x,0);
    vis[x]=1;
    for(int i=head[x];i!=-1;i=q[i].next){
        int to=q[i].to;
        if(vis[to]) continue;
        res-=cal(to,q[i].val);
        minr=n,Size=siz[to];
        getroot(to,-1);
        dfs(root);
    }
}
int main()
{
    scanf("%d",&n);
    init();
    for(int i=1;i<n;i++){
        int from,to,val;
        scanf("%d%d%d",&from,&to,&val);
        add_edge(from,to,val%3);
        add_edge(to,from,val%3);
    }
    Size=minr=n;
    getroot(1,-1);
    dfs(root);
    int tmp=n*n;
    int gcd=__gcd(tmp,res);
    printf("%d/%d\n",res/gcd,tmp/gcd);
    return 0;
}
相關文章
相關標籤/搜索