看上去各類二分答案+貪心啊ios
不是很可作spa
只會 O(n^2) :每次找最深的點並向上跳 mid 步,將跳到的點燃,這樣套個倍增,預先把點按深度排序的話 複雜度大概是 O(n^2*log^2) 的.net
看了一發題解依然看不懂,因而頹這題頹了一下午...code
先引用 這裏 的一句話:選擇一些點代價相同的話通常是貪心,代價不一樣通常是 dpblog
此題比較麻煩的就是須要在 O(n) 的時間內驗證答案排序
log 大概是不太可行了,沒有什麼騷操做預處理連點都掃不全rem
考慮 dfs 一遍,(如下說的點大部分都是須要覆蓋的點get
對於一個點的來講,它的子樹中若是有一個沒被覆蓋的點距它距離 > mid 了,這確定是不合法的,很容易想到咱們每次判斷最深未覆蓋點的距離,若 = mid 則把當前點點燃string
如今考慮距離小於 mid 的未覆蓋點,我須要記錄 x 的子樹中距離 x 最近的點燃的點燒到 x 後還能再燒多長,這樣可讓子樹之間互相更新,還能夠根據這個距離來更新 x it
以上信息均可以自底向上更新,因此咱們記錄 dep[x] 表示子樹中最深的未覆蓋點的距離, rem[x] 表示子樹中最淺的點燃的點燒到 x 後還能再燒多長,就能夠進行上邊的操做了
這裏我把初值設爲 -1 爲了方便區分這個點是不是須要覆蓋的點,可否向上更新信息
注意再更新到了根的時候,根的信息須要在 dfs 外額外判斷,因爲可能 dep[Root] < mid ,就是說點燃根的某個祖先來更新根的子樹中的點,因此判斷一下是否須要把根點燃便可
寫了一個萬惡的特判WA了很久...
代碼:
#include<algorithm> #include<iostream> #include<cstdlib> #include<cstring> #include<cctype> #include<cstdio> #include<cmath> using namespace std; const int MAXN = 300005; struct EDGE{ int nxt, to; EDGE(int NXT = 0, int TO = 0) {nxt = NXT; to = TO;} }edge[MAXN << 1]; int n, m, totedge; int head[MAXN], rem[MAXN], dep[MAXN]; bool has[MAXN], is[MAXN], dis[MAXN]; inline void add(int x, int y) { edge[++totedge] = EDGE(head[x], y); head[x] = totedge; return; } void predfs(int x, int fa) { for(int i = head[x]; i; i = edge[i].nxt) if(edge[i].to != fa) { int y = edge[i].to; predfs(y, x); has[x] |= has[y]; } return; } int dfs(int x, int fa, int mid) { int tot = 0; for(int i = head[x]; i; i = edge[i].nxt) if(edge[i].to != fa && has[edge[i].to]) { int y = edge[i].to; tot += dfs(y, x, mid); if(dep[y] + 1 == mid && rem[x] != mid) { rem[x] = mid; ++tot; dep[y] = -1; } rem[x] = max(rem[x], rem[y] - 1); } for(int i = head[x]; i; i = edge[i].nxt) if(edge[i].to != fa && has[edge[i].to]) { int y = edge[i].to; if(dep[y] + 1 <= rem[x]) { dep[y] = -1; continue; } if(~dep[y]) dep[x] = max(dep[x], dep[y] + 1); } if(is[x]) dep[x] = max(dep[x], 0); if(dep[x] == 0 && rem[x] >= 0) dep[x] = -1; return tot; } inline bool chk(int mid) { for(int i = 1; i <= n; ++i) rem[i] = dep[i] = -1; int tmp = dfs(1, 0, mid); if(~dep[1]) ++tmp; return (tmp <= m); } inline void hfs(int l, int r) { int mid = ((l + r) >> 1); while(l < r) { mid = ((l + r) >> 1); if(chk(mid)) r = mid; else l = mid + 1; } printf("%d\n", l); } int main() { scanf("%d%d", &n, &m); register int xx, yy = 0; for(int i = 1; i <= n; ++i) { scanf("%d", &xx); yy += xx; rem[i] = dep[i] = -1; has[i] = is[i] = xx; } if(m >= yy) { //辣雞特判毀我青春 puts("0"); return 0; } for(int i = 1; i < n; ++i) { scanf("%d%d", &xx, &yy); add(xx, yy); add(yy, xx); } predfs(1, 0); hfs(1, (n - m) / m + 1); return 0; }