傳送門
這題考試的時候以爲這個做法的時間複雜度是假的,\(n \geqslant 1000\) 的部分就直接瞎寫了個特殊性質上去,結果這個「假」複雜度的做法能拿到 60pts……
首先,70pts 的做法可以枚舉起點,每次跑一遍 dfs
令 \(g[i][j]\) 表示以 \(i\) 爲起點、撒 \(j\) 次麪包屑得到的最大收益即可
這部分代碼:
// Exam-time program: enumerate every start node and run a DFS from it (~70pts),
// plus several abandoned attempts kept for reference. Only task2::solve() is
// called from main(); the other namespaces are unused.
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 100010
#define ll long long
#define ld long double
#define usd unsigned
#define ull unsigned long long
//#define int long long
// Buffered character reader over stdin. The byte is widened through unsigned char so
// isdigit() never receives a negative value; EOF is returned once input runs out.
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:(unsigned char)*p1++)
char buf[1<<21], *p1=buf, *p2=buf;

// Reads one (optionally negative) decimal integer. `c` is an int so EOF is detected
// reliably, and the skip loop stops at EOF instead of spinning forever.
inline int read() {
    int ans=0, f=1;
    int c=getchar();
    while (c!=EOF && !isdigit(c)) {if (c=='-') f=-f; c=getchar();}
    while (isdigit(c)) {ans=(ans<<3)+(ans<<1)+(c^48); c=getchar();}
    return ans*f;
}

int n, V;                 // node count; breadcrumb budget (must stay small enough for the [105] arrays)
// `cnt` instead of `size`: with `using namespace std;` a global `size` is ambiguous
// with std::size under C++17 and fails to compile.
int p[N], head[N], cnt;
struct edge{int to, next; bool vis;} e[N<<1];

// Appends the directed edge s->t to s's adjacency list.
inline void add(int s, int t) {edge* k=&e[++cnt]; k->to=t; k->next=head[s]; head[s]=cnt;}

// Soft time limit of 1.6 s CPU time, independent of the platform's CLOCKS_PER_SEC
// (the previous literal 1600000 silently assumed CLOCKS_PER_SEC == 1000000).
inline bool near_time_limit() { return clock() >= 1.6 * CLOCKS_PER_SEC; }

// Exhaustive search over paths starting at node 1. none[] marks nodes whose pigeons
// are already counted; e[].vis remembers which marks this call must undo.
namespace force{
    ll ans;
    bool none[N];
    // v2 = breadcrumbs still available, sum = value accumulated so far.
    void dfs(int u, int fa, int v2, ll sum) {
        if (v2<=0) {ans=max(ans, sum); return ;}
        bool cge=0;
        if (!none[u]) none[u]=1, cge=1;
        // Continue without dropping a breadcrumb at u.
        for (int i=head[u],v; i; i=e[i].next) {
            v = e[i].to;
            if (v!=fa) dfs(v, u, v2, sum);
        }
        // Drop at u: collect the not-yet-counted neighbours.
        for (int i=head[u]; i; i=e[i].next)
            if (!none[e[i].to]) {sum+=p[e[i].to]; none[e[i].to]=1; e[i].vis=1;}
        for (int i=head[u],v; i; i=e[i].next) {
            v = e[i].to;
            if (v!=fa) dfs(v, u, v2-1, sum);
        }
        // Undo this call's marks before returning.
        for (int i=head[u]; i; i=e[i].next)
            if (e[i].vis) {e[i].vis=0; none[e[i].to]=0;}
        if (cge) none[u]=0;
    }
    void solve() {
        dfs(1, 0, V, 0);
        printf("%lld\n", ans);
        exit(0);
    }
}

// Attempt: 4-state DP per node, re-run from every root until the time limit.
// NOTE(review): the meaning of the 4 states is not documented in the source.
namespace task1{
    ll dp[N][105][4], ans;
    void dfs(int u, int fa) {
        ll sum=0;
        bool leaf=1;
        for (int i=head[u]; i; i=e[i].next) {
            sum+=p[e[i].to];
            if (e[i].to!=fa) leaf=0;
        }
        memset(dp[u], 0, sizeof(dp[u]));
        dp[u][V][3]=p[fa];
        if (leaf) return ;
        for (int i=head[u],v; i; i=e[i].next) {
            v = e[i].to;
            if (v!=fa) {
                dfs(v, u);
                for (int s=V; s>=0; --s) {
                    dp[u][s][0] = max(dp[u][s][0], max(dp[v][s+1][2], dp[v][s+1][3]));
                    dp[u][s][1] = max(dp[u][s][1], max(dp[v][s][0], dp[v][s][1]));
                    if (s>0) {
                        dp[u][s][2] = max(dp[u][s][2], sum-p[v]+max(dp[v][s+1][2], dp[v][s+1][3]));
                        dp[u][s][3] = max(max(dp[u][s][3], sum-p[v]+max(dp[v][s][0], dp[v][s][1])), sum);
                    }
                }
            }
        }
    }
    void solve() {
        for (int i=1; i<=n; ++i) {
            // Out of time: print the best answer found so far.
            if (near_time_limit()) {printf("%lld\n", ans); exit(0);}
            dfs(i, 0);
            for (int j=0; j<=V; ++j)
                ans=max(ans, max(max(dp[i][j][0], dp[i][j][1]), max(dp[i][j][2], dp[i][j][3])));
        }
        printf("%lld\n", ans);
        exit(0);
    }
}

// Submitted version: for every start node i, run a DFS rooted at i and take the best
// f/g value over all nodes; stops early near the time limit.
namespace task2{
    ll f[N][105], g[N][105], ans;
    void dfs(int u, int fa) {
        ll sum=0;
        for (int i=head[u]; i; i=e[i].next) sum+=p[e[i].to];
        memset(f[u], 0, sizeof(f[u]));
        f[u][1]=sum;
        for (int i=head[u],v; i; i=e[i].next) {
            v = e[i].to;
            if (v==fa) continue;
            memset(g[v], 0, sizeof(g[v]));   // g[] rows are reused across roots
            dfs(v, u);
            for (int j=1; j<=V; ++j)
                g[u][j] = max(g[u][j], max(g[v][j], sum-p[fa]+g[v][j+1]));
        }
    }
    void solve() {
        for (int i=1; i<=n; ++i) {
            // Out of time: print the best answer found so far.
            if (near_time_limit()) {printf("%lld\n", ans); exit(0);}
            memset(g[i], 0, sizeof(g[i]));
            dfs(i, 0);
            for (int j=1; j<=n; ++j)
                for (int k=0; k<=V; ++k) ans=max(ans, max(f[j][k], g[j][k]));
        }
        printf("%lld\n", ans);
        exit(0);
    }
}

// Unfinished single-DFS attempt that tracks the top-2 f/g values per breadcrumb count
// and the child each came from. Debug printing to stdout has been removed so it
// cannot corrupt the judged output.
// NOTE(review): maxn/maxi take ~6.7 KB of stack per recursion level; a long chain
// would likely overflow the default stack.
namespace task{
    ll f[N][105], g[N][105], ans;
    void dfs(int u, int fa) {
        ll sum=0;
        for (int i=head[u]; i; i=e[i].next) sum+=p[e[i].to];
        f[u][1]=sum;
        ll maxn[105][4], maxi[105][4];
        memset(maxn, 0, sizeof(maxn));
        memset(maxi, 0, sizeof(maxi));
        for (int i=head[u],v; i; i=e[i].next) {
            v = e[i].to;
            if (v==fa) continue;
            for (int j=1; j<=V; ++j) {
                g[v][j] = max(g[v][j], max(g[u][j], sum-p[fa]+g[u][j-1]));
                if (g[v][j]>=maxn[j][0]) {maxn[j][1]=maxn[j][0]; maxi[j][1]=maxi[j][0]; maxn[j][0]=g[v][j]; maxi[j][0]=v;}
                else if (g[v][j]>maxn[j][1]) maxn[j][1]=g[v][j], maxi[j][1]=v;
            }
            dfs(v, u);
            for (int j=1; j<=V; ++j) {
                f[u][j] = max(f[u][j], max(f[v][j], sum-p[v]+f[v][j-1]));
                if (f[v][j]>=maxn[j][2]) {maxn[j][3]=maxn[j][2]; maxi[j][3]=maxi[j][2]; maxn[j][2]=f[v][j]; maxi[j][2]=v;}
                else if (f[v][j]>maxn[j][3]) maxn[j][3]=f[v][j], maxi[j][3]=v;
            }
        }
        for (int j=0; j<V; ++j) {
            // Combine a g-value and an f-value that come from different children.
            if (maxi[j+1][0]!=maxi[j][2]) ans=max(ans, maxn[j+1][0]+maxn[j][2]);
            else {
                if (maxi[j+1][1]!=maxi[j][2]) ans=max(ans, maxn[j+1][1]+maxn[j][2]);
                if (maxi[j+1][0]!=maxi[j][3]) ans=max(ans, maxn[j+1][0]+maxn[j][3]);
            }
            ans=max(ans, max(maxn[j][0], maxn[j][2]));
        }
    }
    void solve() {
        dfs(1, 0);
        printf("%lld\n", ans);
        exit(0);
    }
}

signed main() {
    #ifdef DEBUG
    freopen("1.in", "r", stdin);
    #endif
    n=read(); V=read();
    for (int i=1; i<=n; ++i) p[i]=read();
    for (int i=1,u,v; i<n; ++i) {
        u=read(); v=read();
        add(u, v); add(v, u);
    }
    task2::solve();
    return 0;
}
然後考慮如何不枚舉起點
那就需要換根DP了:令 \(f[i][j]\) 表示從以 \(i\) 爲根的子樹中某點出發、向上走到 \(i\) 爲止,撒 \(j\) 次麪包屑的最大收益;\(g[i][j]\) 表示從 \(i\) 的父親走到 \(i\)、再走進 \(i\) 的子樹中,撒 \(j\) 次麪包屑的最大收益
轉移的時候要特別注意前後順序
首先方程有了:\(ans=\max(ans,\ f[u][j]+g[to][v-j])\),其中 \(to\) 是 \(u\) 的兒子,\(v\) 是麪包屑總數
而我們要在同一次遍歷中更新 \(ans, f[u][j], g[u][j]\)
由於 \(f\) 和 \(g\) 肯定不能來自同一棵子樹,所以 \(f\) 只能用之前遍歷過的子樹中的值,因此要先更新 \(ans\),再轉移 \(f\)、\(g\)
發現這樣只是用一個 \(g\) 去匹配它左邊的所有 \(f\),顯然不夠,所以還要逆序枚舉一遍兒子
挺有思維量的,做了巨久……還因爲變量名重了沒看出來,陷入高度自閉
Code:
// Single-DFS tree DP (rooted at node 1) for the breadcrumb path problem.
//
//   sum[u]  - total of p[] over u's neighbours. Dropping a breadcrumb at u is worth
//             sum[u] minus p[] of the node the path arrived from (nothing subtracted
//             when u is where the path starts).
//   f[u][j] - best value of a path that starts somewhere in u's subtree and walks up
//             to end at u, using at most j breadcrumbs.
//   g[u][j] - best value of a path that enters u from its parent and continues down
//             into u's subtree, using at most j breadcrumbs.
//
// A path that bends at u is scored as f[u][j] + g[to][V-j], where f[u] only holds the
// children merged before `to`, so the two halves come from different subtrees.
// Children are merged once in adjacency order and once in reverse order, so every
// ordered pair of distinct children is tried.
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 100010
#define ll long long
//#define int long long

char buf[1<<21], *p1=buf, *p2=buf;
// Buffered character reader over stdin. The byte is widened through unsigned char so
// isdigit() never receives a negative value; EOF is returned once input runs out.
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:(unsigned char)*p1++)

// Reads one (optionally negative) decimal integer. `c` is an int so EOF is detected
// reliably, and the skip loop stops at EOF instead of spinning forever.
inline int read() {
    int ans=0, f=1;
    int c=getchar();
    while (c!=EOF && !isdigit(c)) {if (c=='-') f=-f; c=getchar();}
    while (isdigit(c)) {ans=(ans<<3)+(ans<<1)+(c^48); c=getchar();}
    return ans*f;
}

// Node count and breadcrumb budget. The budget is named V so loop locals named `v`
// can no longer shadow it. It must stay small enough for the [105] DP rows.
int n, V;
// `cnt` instead of `size`: with `using namespace std;` a global `size` is ambiguous
// with std::size under C++17 and fails to compile.
int head[N], cnt, sta[N], top;
ll p[N], sum[N], ans, f[N][105], g[N][105];
struct edge{int to, next;} e[N<<1];

// Appends the directed edge s->t to s's adjacency list.
inline void add(int s, int t) {edge* k=&e[++cnt]; k->to=t; k->next=head[s]; head[s]=cnt;}

// Resets f[u]/g[u] (for j >= 1) to the one-node path {u} with a breadcrumb dropped at u.
inline void init_node(int u, int fa) {
    for (int j=1; j<=V; ++j) f[u][j]=sum[u], g[u][j]=sum[u]-p[fa];
}

// Pairs child `to` with the children of u merged so far (updating ans), then folds
// `to` into f[u]/g[u]. The ans update must come first so f[u] excludes `to` itself.
inline void merge_child(int u, int fa, int to) {
    for (int j=1; j<=V; ++j) {
        ans = max(ans, f[u][j]+g[to][V-j]);
        f[u][j] = max(f[u][j], max(f[to][j], f[to][j-1]+sum[u]-p[to]));  // drop at u, arrived from `to`
        g[u][j] = max(g[u][j], max(g[to][j], g[to][j-1]+sum[u]-p[fa]));  // drop at u, arrived from the parent
    }
}

// Computes sum[], f[], g[] for u's subtree. Updates ans with paths that end at u or
// that pass through u with their two halves in different child subtrees.
void dfs(int u, int fa) {
    for (int i=head[u]; i; i=e[i].next) {
        int to = e[i].to;
        sum[u] += p[to];
        if (to != fa) dfs(to, u);
    }
    // sta[] is shared by all recursion levels; only pop what this call pushes.
    int base = top;
    // Forward pass: each child is matched against the children before it.
    init_node(u, fa);
    for (int i=head[u]; i; i=e[i].next) {
        int to = e[i].to;
        if (to == fa) continue;
        sta[++top] = to;
        merge_child(u, fa, to);
    }
    ans = max(ans, max(f[u][V], g[u][V]));
    // Reverse pass: each child is matched against the children after it.
    init_node(u, fa);
    while (top > base) merge_child(u, fa, sta[top--]);
    ans = max(ans, max(f[u][V], g[u][V]));
}

signed main() {
    n=read(); V=read();
    for (int i=1; i<=n; ++i) p[i]=read();
    for (int i=1; i<n; ++i) {
        int a=read(), b=read();
        add(a, b); add(b, a);
    }
    dfs(1, 0);
    printf("%lld\n", ans);
    return 0;
}