號稱是noip2016最噁心的題html
基本上用了一天來搞明白+給sy講明白(可能還沒講明白node
具體思路是真的不想寫了(快吐了c++
若是要看,參見洛谷P1600 每天愛跑步——題解spa
雖然這樣很差但我真的不想寫了code
直接放代碼:htm
#include<bits/stdc++.h> #include<vector> using namespace std; inline int read() { int ans=0; char last=' ',ch=getchar(); while(ch>'9'||ch<'0') last=ch,ch=getchar(); while(ch>='0'&&ch<='9') ans=(ans<<1)+(ans<<3)+ch-'0',ch=getchar(); if(last=='-') ans=-ans; return ans; } const int mxn=300000; int n,m; vector<int> lcau[mxn],len[mxn]; struct nd{ int i,v; }; vector<nd> lcav[mxn]; int dis[mxn]; struct node { int to,nxt; }e[mxn<<1]; int w[mxn]; int ans[mxn]; int ecnt,head[mxn]; void add(int from,int to) { ++ecnt; e[ecnt].to=to; e[ecnt].nxt=head[from]; head[from]=ecnt; } int dep[mxn],fa[mxn],siz[mxn],son[mxn],top[mxn]; void dfs1(int u,int f) { dep[u]=dep[f]+1; fa[u]=f; siz[u]=1; int maxn=-1; for(int i=head[u],v;i;i=e[i].nxt) { v=e[i].to; if(v==f) continue; dfs1(v,u); siz[u]+=siz[v]; if(maxn<siz[v]) { maxn=siz[v]; son[u]=v; } } } void dfs2(int u,int f) { if(son[f]==u) top[u]=top[f]; else top[u]=u; if(son[u]) dfs2(son[u],u); for(int i=head[u],v;i;i=e[i].nxt) { v=e[i].to; if(v==f||v==son[u]) continue; dfs2(v,u); } } int lca(int x,int y) { while(top[x]!=top[y]) { if(dep[top[x]]>dep[top[y]]) x=fa[top[x]]; else y=fa[top[y]]; } if(x==y) return x; if(dep[x]>dep[y]) swap(x,y); return x; } int t1[mxn],t2[mxn<<1]; int st[mxn]; void dfs(int u) { int bf=t1[dep[u]+w[u]]+t2[w[u]-dep[u]+mxn]; for(int i=head[u],v;i;i=e[i].nxt) { v=e[i].to; if(v==fa[u]) continue; dfs(v); } if(st[u]) t1[dep[u]]+=st[u]; if(len[u].size()) for(int i=0;i<len[u].size();i++) { int Dis=len[u][i]; t2[Dis-dep[u]+mxn]++; } ans[u]=t1[dep[u]+w[u]]+t2[w[u]-dep[u]+mxn]-bf; if(lcau[u].size()) { for(int i=0;i<lcau[u].size();i++) { int start=lcau[u][i]; int enddd=lcav[u][i].v; int num=lcav[u][i].i; if(dep[u]+w[u]==dep[start]&&dis[num]-dep[enddd]+dep[u]==w[u]) ans[u]--; t1[dep[start]]--; t2[dis[num]-dep[enddd]+mxn]--; } } } int main() { n=read(); m=read(); for(int i=1,u,v;i<n;i++) { u=read(); v=read(); add(u,v); add(v,u); } for(int i=1;i<=n;i++) w[i]=read(); int l; dfs1(1,0); dfs2(1,0); for(int i=1,S,T;i<=m;i++) { S=read(); T=read(); st[S]++; int f=lca(S,T); lcau[f].push_back(S); lcav[f].push_back(nd{i,T}); dis[i]=dep[S]+dep[T]-(dep[f]<<1); len[T].push_back(dis[i]); } dfs(1); for(int i=1;i<=n;i++) printf("%d ",ans[i]); return 0; }