原文連接www.cnblogs.com/zhouzhendong/p/UOJ470.htmlhtml
作完情報中心來看這個題忽然發現兩題有類似之處而後就會作了。c++
首先,咱們考慮將全部答案點對分爲兩類。git
第一種狀況很是簡單,這裏不加贅述。spa
對於第二種狀況,咱們首先考慮簡單作法:code
考慮對於每個節點分開處理。htm
按照某一種順序枚舉它的子樹,對於全部「一端在當前子樹內,另外一端在當前子樹以前的子樹」的路徑,咱們求它們的貢獻。blog
接下來提到的「虛樹「中默認加入當前節點。get
考慮對當前子樹內路徑端點創建虛樹,而後在虛樹上 dfs。對於虛樹上的一個節點,它在另一個子樹中有相同語言的節點就是它在虛樹上的子樹中的全部端點的另外一端點構成的虛樹大小。it
一個節點的子樹中全部端點對應的點構成的虛樹能夠由兒子節點的虛樹合併而來。class
若是事先將虛樹內的節點存在 set 中,則能夠在關於點數較少的虛樹的複雜度內合併兩棵虛樹,具體地說是 size * log(n) 。
考慮使用 DSU on tree,咱們能夠獲得一個 $O(n\log ^ 3n)$ 的作法。
注意到,在不少問題裏,線段樹合併均可以處理樹上啓發式合併的問題,並且複雜度都會降低。這裏也相似,考慮合併兩個 dfs序 分別獨立的虛樹時,只須要特殊考慮 dfs序 小的虛樹的 dfs序最大節點和 dfs序 大的虛樹的dfs序最小節點到根的路徑交便可。
因而,咱們考慮採用線段樹合併維護子樹虛樹 size,因爲線段樹合併中須要求 LCA,因此咱們考慮用 ST表 來求 LCA,作到單次詢問 $O(1)$,便可獲得一個總時間複雜度 $O((n+m)\log n)$ 的作法。
#include <bits/stdc++.h> #define clr(x) memset(x,0,sizeof x) #define For(i,a,b) for (int i=(a);i<=(b);i++) #define Fod(i,b,a) for (int i=(b);i>=(a);i--) #define fi first #define se second #define pb(x) push_back(x) #define mp(x,y) make_pair(x,y) #define outval(x) cerr<<#x" = "<<x<<endl #define outtag(x) cerr<<"---------------"#x"---------------"<<endl #define outarr(a,L,R) cerr<<#a"["<<L<<".."<<R<<"] = ";\ For(_x,L,R)cerr<<a[_x]<<" ";cerr<<endl; using namespace std; typedef long long LL; LL read(){ LL x=0,f=0; char ch=getchar(); while (!isdigit(ch)) f|=ch=='-',ch=getchar(); while (isdigit(ch)) x=(x<<1)+(x<<3)+(ch^48),ch=getchar(); return f?-x:x; } const int N=100005*2; int n,m; vector <int> e[N]; struct cha{ int x,y,lca; int xf,yf; }a[N]; int depth[N],fa[N][20]; int ett[N],c=0,I[N]; void dfs(int x,int pre,int d){ depth[x]=d,fa[x][0]=pre; For(i,1,19) fa[x][i]=fa[fa[x][i-1]][i-1]; ett[I[x]=++c]=x; for (int y : e[x]) if (y!=pre) dfs(y,x,d+1),ett[++c]=x; } int st[N][20],Log[N]; int min_dep(int x,int y){ return depth[x]<depth[y]?x:y; } void Get_ST(){ For(i,2,c) Log[i]=Log[i>>1]+1; For(i,1,c){ st[i][0]=ett[i]; For(j,1,19){ st[i][j]=st[i][j-1]; if (i-(1<<(j-1))>0) st[i][j]=min_dep(st[i][j],st[i-(1<<(j-1))][j-1]); } } } int LCA(int x,int y){ x=I[x],y=I[y]; if (x>y) swap(x,y); int d=Log[y-x+1]; return min_dep(st[x+(1<<d)-1][d],st[y][d]); } int Dis(int x,int y){ return depth[x]+depth[y]-2*depth[LCA(x,y)]; } namespace Seg{ const int S=N*20*2; int sz[S],lp[S],rp[S],ls[S],rs[S]; int cnt=0; void pushup(int rt){ if (!sz[ls[rt]]&&!sz[rs[rt]]) sz[rt]=lp[rt]=rp[rt]=0; else if (!sz[rs[rt]]) sz[rt]=sz[ls[rt]],lp[rt]=lp[ls[rt]],rp[rt]=rp[ls[rt]]; else if (!sz[ls[rt]]) sz[rt]=sz[rs[rt]],lp[rt]=lp[rs[rt]],rp[rt]=rp[rs[rt]]; else { sz[rt]=sz[ls[rt]]+sz[rs[rt]]-depth[LCA(rp[ls[rt]],lp[rs[rt]])]; lp[rt]=lp[ls[rt]],rp[rt]=rp[rs[rt]]; } } void Ins(int &rt,int L,int R,int x){ if (!rt) rt=++cnt,sz[rt]=ls[rt]=rs[rt]=lp[rt]=rp[rt]=0; if (L==R){ lp[rt]=rp[rt]=x,sz[rt]=depth[x]; return; } int mid=(L+R)>>1; if (I[x]<=mid) Ins(ls[rt],L,mid,x); else Ins(rs[rt],mid+1,R,x); pushup(rt); } int Merge(int x,int y,int L,int R){ if (!x||!y) return x|y; if (L==R) return x; int mid=(L+R)>>1,rt=++cnt; ls[rt]=Merge(ls[x],ls[y],L,mid); rs[rt]=Merge(rs[x],rs[y],mid+1,R); pushup(rt); return rt; } } int go_son(int x,int f){ Fod(i,19,0) if (depth[x]-(1<<i)>depth[f]) x=fa[x][i]; return x; } LL ans=0; vector <int> qid[N]; int up[N]; bool cmp_qid(int x,int y){ return I[a[x].xf]<I[a[y].xf]; } bool cmpI(int x,int y){ return I[x]<I[y]; } int rt[N]; void Solve(int x,int *id,int n){ static int t[N],st[N]; int tc=0,top=0; For(i,0,n-1) t[++tc]=a[id[i]].x; t[++tc]=x; sort(t+1,t+tc+1,cmpI); tc=unique(t+1,t+tc+1)-t-1; For(i,1,tc) rt[t[i]]=0; For(i,0,n-1) Seg::Ins(rt[a[id[i]].x],1,c,a[id[i]].y); For(_,1,tc){ int i=t[_]; if (top){ int lca=LCA(i,st[top]); while (depth[st[top]]>depth[lca]){ int now=st[top]; if (depth[st[top-1]]>=depth[lca]){ ans+=(LL)(depth[now]-depth[st[top-1]])*(Seg::sz[rt[now]]-depth[x]); rt[st[top-1]]=Seg::Merge(rt[st[top-1]],rt[now],1,c); top--; } else { ans+=(LL)(depth[now]-depth[lca])*(Seg::sz[rt[now]]-depth[x]); rt[lca]=rt[now]; st[top]=lca; break; } } } st[++top]=i; } while (top>1){ int now=st[top]; ans+=(LL)(depth[now]-depth[st[top-1]])*(Seg::sz[rt[now]]-depth[x]); rt[st[top-1]]=Seg::Merge(rt[st[top-1]],rt[now],1,c); top--; } } void Solve(int x,int pre){ for (int y : e[x]) if (y!=pre) Solve(y,x),up[x]=max(up[x],up[y]-1); ans+=up[x]; sort(qid[x].begin(),qid[x].end(),cmp_qid); int s=(int)qid[x].size(); for (int i=0,j;i<s;i=j+1){ for (j=i;j+1<s&&I[a[qid[x][i]].xf]==I[a[qid[x][j+1]].xf];j++); Solve(x,&qid[x][i],j-i+1); } } int main(){ n=read(),m=read(); For(i,1,n-1){ int x=read(),y=read(); e[x].pb(y),e[y].pb(x); } dfs(1,0,0); Get_ST(); For(i,1,m){ int x=a[i].x=read(),y=a[i].y=read(),lca=a[i].lca=LCA(x,y); up[x]=max(up[x],depth[x]-depth[lca]); up[y]=max(up[y],depth[y]-depth[lca]); if (x!=lca&&y!=lca){ a[i].xf=go_son(x,lca); a[i].yf=go_son(y,lca); if (I[a[i].xf]<I[a[i].yf]) swap(a[i].xf,a[i].yf),swap(a[i].x,a[i].y); qid[lca].pb(i); } } Solve(1,0); cout<<ans<<endl; return 0; }