說到樹的鏈剖,大多數人都會首先想到重鏈剖分。的確,目前重鏈剖分在OI中有更加多樣化的應用,但它大多時候是替代不了長鏈剖分的。c++
重鏈剖分是把size最大的兒子當成重兒子,顧名思義長鏈剖分就是把 len (到葉子節點的距離) 最長的兒子當成重兒子。算法
因爲是和深度有關的算法,長鏈剖分經常使用於優化一些和深度有關的dp或其餘算法。數組
具體按照蒟蒻的理解來講,就是相似啓發式合併的那種感受,每一個點爲根的子樹均可以當作一條最長的鏈上支出了一些叉,而咱們想把這棵子樹捋成只有一條鏈,畢竟鏈多清晰明瞭,還好合並。因而,咱們先遞歸下去重兒子,處理出子樹中這條最長的鏈的信息,而後再枚舉輕兒子遞歸下去,反上來的時候把以這個輕兒子開頭的鏈信息合併到最長鏈上。學習
那麼問題來了,優化
蒟蒻覺得,那種與路徑長度有關,且能以每一個深度爲狀態,相同深度之間更新沒有順序的問題應該就能夠吧。spa
能夠發現,每次向上合併都是拿一條兒子子樹的最長鏈合併到父親子樹的最長鏈上,時空代價都是
$$ \sum_{len(長鏈)}=O(n) $$指針
注意這裏的dp原本是要開二維數組的(f(u)(j)表示u這棵子樹,距離u爲j的點有多少個),想想,爲何空間並不會炸呢?code
能夠發現父親節點u的 f 數組一開始就是重兒子節點v的 f 數組右移一位獲得的,因此咱們能夠最開始把重兒子 f 數組的指針賦成 f[u]+1.遞歸
對於g也是同理。內存
具體實現方式請看代碼及註釋吧。
#include<bits/stdc++.h> using namespace std; #define N 200005 #define ll long long int n,gr,h[N],nxt[N],to[N]; inline void tu(int x,int y){to[++gr]=y,nxt[gr]=h[x],h[x]=gr;} int len[N]{-1},son[N]; void dfs1(int u,int fa){ len[u]=0; for(int i=h[u];i;i=nxt[i]){ if(to[i]==fa)continue; dfs1(to[i],u); if(len[son[u]]<len[to[i]])son[u]=to[i],len[u]=len[son[u]]+1; } } ll mry[N*3],*f[N],*g[N],*tot=mry,ans;//sor // //g(i)表示在子樹外離當前點距離爲i的點能夠和子樹內多少對點組成答案 inline void get_memory(int u,int siz){ f[u]=tot;tot+=(siz<<1)+2;g[u]=tot;tot+=siz+2;//+2是必需的 } void dfs2(int u,int fa){ if(son[u]){ f[son[u]]=f[u]+1;g[son[u]]=g[u]-1; dfs2(son[u],u); } f[u][0]=1; ans+=f[u][0]*g[u][0];//g[u][0]實際上==g[son[u]][1] for(int i=h[u];i;i=nxt[i]){ int d=to[i]; if(d==fa||d==son[u])continue; get_memory(d,len[d]); dfs2(d,u); for(int j=0;j<=len[d];++j){ ans+=f[d][j]*g[u][j+1]; if(j)ans+=g[d][j]*f[u][j-1]; } for(int j=0;j<=len[d];++j){ g[u][j+1]+=f[u][j+1]*f[d][j]; if(j)g[u][j-1]+=g[d][j];//折了一下 f[u][j+1]+=f[d][j]; } } } int main(){ scanf("%d",&n); for(int i=1,x,y;i<n;++i)scanf("%d%d",&x,&y),tu(x,y),tu(y,x); dfs1(1,0); get_memory(1,len[1]); dfs2(1,0); printf("%lld\n",ans); return 0; }
一眼看出來的部分略去不說,咱們須要解決的就是求max{邊數在L和R之間的路徑邊權和}.
因爲轉移時要在一段區間取max,這時候咱們發現咱們轉移的時候須要一棵線段樹,像上一道題同樣動態用指針開內存的方式顯然不能接受。
咱們選擇另外一種方式,即有順序地遍歷全樹得到dfs序,使得每條長鏈上的點dfs序是一段連續的區間。
這樣,dfs序上在不超過len(長鏈)的範圍下加多少,就是深度向下走多少。即可以輕鬆轉移了。
#include<bits/stdc++.h> using namespace std; #define il inline #define rep(i,a,b) for(register int i=(a);i<=(b);++i) #define dwn(i,a,b) for(register int i=(a);i>=(b);--i) #define lc (x<<1) #define rc (x<<1|1) typedef double db; typedef long long ll; const int N = 200005; const db eps = 1e-5; int n,gr,h[N],nxt[N],to[N],w[N],lwr,upp; inline void tu(int x,int y,int v){to[++gr]=y,nxt[gr]=h[x],h[x]=gr,w[gr]=v;} int len[N],son[N];// int dfn[N],tim,fa[N],sonw[N]; struct Tre{ db tr[N<<2]; void modify(int p,int L,int R,db v,int x){ if(L==p&&R==p){tr[x]=max(tr[x],v);return;}//注意要不斷取max! 而不是賦值! int mid=(L+R)>>1; if(p<=mid)modify(p,L,mid,v,lc); else modify(p,mid+1,R,v,rc); tr[x]=max(tr[lc],tr[rc]); } db query(int l,int r,int L,int R,int x){ if(l==L&&r==R)return tr[x]; int mid=(L+R)>>1; if(r<=mid)return query(l,r,L,mid,lc); else if(l>mid)return query(l,r,mid+1,R,rc); else return max(query(l,mid,L,mid,lc),query(mid+1,r,mid+1,R,rc)); } }T; void dfs1(int u,int f){ len[u]=0;fa[u]=f; for(int i=h[u];i;i=nxt[i]){ if(to[i]==f)continue; dfs1(to[i],u); if(len[son[u]]<len[to[i]])son[u]=to[i],len[u]=len[son[u]]+1,sonw[u]=w[i]; } } db ans; void dfs2(int u){ dfn[u]=++tim; if(!son[u])return; dfs2(son[u]); for(int i=h[u];i;i=nxt[i]){ if(to[i]==fa[u]||to[i]==son[u])continue; dfs2(to[i]); } } db dis[N],now[N]; void solve(int u,db x){ T.modify(dfn[u],1,n,dis[u],1); if(son[u]){ dis[son[u]]=dis[u]+sonw[u]-x; solve(son[u],x); } for(int i=h[u];i;i=nxt[i]){ if(to[i]==son[u]||to[i]==fa[u])continue; int d=to[i]; dis[d]=dis[u]+w[i]-x; solve(d,x); rep(j,0,len[d]){ now[j]=T.query(dfn[d]+j,dfn[d]+j,1,n,1); if(j+1<=upp&&j+len[u]+1>=lwr){ ans=max(ans,now[j]+T.query(dfn[u]+max(0,lwr-j-1),dfn[u]+min(upp-j-1,len[u]),1,n,1)-2*dis[u]); } } rep(j,0,len[d]){ T.modify(dfn[u]+j+1,1,n,now[j],1); } } if(len[u]>=lwr) ans=max(ans,T.query(dfn[u]+lwr,dfn[u]+min(len[u],upp),1,n,1)-dis[u]); // } bool check(db x){ memset(T.tr,0xc2,sizeof(T.tr)); ans=-1e7; solve(1,x); return (ans>=0); } int main(){ scanf("%d",&n); scanf("%d%d",&lwr,&upp); int a,b,c,mx=0; rep(i,1,n-1)scanf("%d%d%d",&a,&b,&c),tu(a,b,c),tu(b,a,c),mx=max(mx,c); len[0]=-1; dfs1(1,0),dfs2(1); db l=0,r=mx,mid; while(l+eps<r){ mid=(l+r)/2.0; if(check(mid))l=mid; else r=mid; } if(check(r))printf("%.3lf\n",r); else printf("%.3lf\n",l); return 0; }