有一棵樹,每條邊都有一個邊權,如今你要修改邊權,使得修改後根到全部葉子的距離相等。c++
要求全部邊權非負。函數
修改的代價爲$\lvert$每條邊修改前的邊權$-$修改後的邊權$\rvert$之和。spa
$n+m\leq 300000$code
容易發現,設 $f(x)$ 爲根到全部葉子的距離爲 $x$ 時的最小代價,那麼 $f(x)$是一個下凸函數,而且每一段都是線性的。get
考慮一個點 $u$ 從兒子 $v$ 轉移過來。這個過程分兩步:string
把 $v$ 的凸包加上 $u\to v$ 這條邊:it
要從 $f(x)$ 轉移到 $f'(x)$io
假設原來 $f(x)$ 的最小值是在 $[l,r]$ 時取到的,那麼:function
$x\leq l$:$f'(x)=f(x)+w$:最優方案是把這條邊的長度減到 $0$(由於邊權不能是負數)class
$l\leq x\leq l+w$:$f'(x)=f(l)+w-(x-l)$:把這條邊的代價減掉$w-(x-l)$
$l+w\leq x\leq r+w$:$f'(x)=f(l)$:這條邊的代價不須要變
$x\geq r+w$:$f'(x)=f(l)+(x-R)-w$:把這條邊的代價減掉$(x-r)-w$
那麼就是把 $[l,r]$ 這段往右平移,把 $[0,l]$ 這段往上平移,加入一段斜率爲 $1$ 的直線和一段斜率爲 $-1$的直線。
考慮怎麼維護這個凸包。
能夠發現相鄰兩段的斜率之差爲 $1$,因此只須要維護凸包上相鄰兩個線段交點的橫座標便可。
還能夠發現凸包最右邊那條直線的斜率就是這個點的兒子個數。
因此直接把最右邊兒子個數 $-1$ 條個交點彈掉就能找到 $[l,r]$ 了。
把兩個凸包合併:
直接把全部交點相加就行了。
那麼要怎麼計算答案呢?
先找到 $[l,r]$,而後對於左邊的每個交點 $v$,它的貢獻就是 $-v$。
直接相加就行了。
能夠用可合併堆實現,複雜度爲 $O((n+m)\log (n+m))$
可是我懶。
#include<cstdio> #include<cstring> #include<algorithm> #include<cstdlib> #include<ctime> #include<utility> #include<cmath> #include<functional> #include<queue> using namespace std; typedef long long ll; typedef unsigned long long ull; typedef pair<int,int> pii; typedef pair<ll,ll> pll; void sort(int &a,int &b) { if(a>b) swap(a,b); } void open(const char *s) { #ifndef ONLINE_JUDGE char str[100]; sprintf(str,"%s.in",s); freopen(str,"r",stdin); sprintf(str,"%s.out",s); freopen(str,"w",stdout); #endif } int rd() { int s=0,c,b=0; while(((c=getchar())<'0'||c>'9')&&c!='-'); if(c=='-') { c=getchar(); b=1; } do { s=s*10+c-'0'; } while((c=getchar())>='0'&&c<='9'); return b?-s:s; } void put(int x) { if(!x) { putchar('0'); return; } static int c[20]; int t=0; while(x) { c[++t]=x%10; x/=10; } while(t) putchar(c[t--]+'0'); } int upmin(int &a,int b) { if(b<a) { a=b; return 1; } return 0; } int upmax(int &a,int b) { if(b>a) { a=b; return 1; } return 0; } int n,m; ll w[300010]; int f[300010]; int d[300010]; priority_queue<ll> q[300010]; int main() { open("loj2568"); scanf("%d%d",&n,&m); ll ans=0; for(int i=2;i<=n+m;i++) { scanf("%d%lld",&f[i],&w[i]); ans+=w[i]; d[f[i]]++; } for(int i=n+m;i>=2;i--) { ll l=0,r=0; if(i<=n) { while(--d[i]) q[i].pop(); l=q[i].top(); q[i].pop(); r=q[i].top(); q[i].pop(); } q[i].push(l+w[i]); q[i].push(r+w[i]); if(q[i].size()>q[f[i]].size()) q[i].swap(q[f[i]]); while(!q[i].empty()) q[f[i]].push(q[i].top()),q[i].pop(); } while(d[1]--) q[1].pop(); while(!q[1].empty()) ans-=q[1].top(),q[1].pop(); printf("%lld\n",ans); return 0; }