[U53204] 樹上揹包的優化

題目連接
本文旨在介紹樹上揹包的優化。
可見例題,例題中 N , M [ 1 , 100000 ] N,M \in [1,100000] 的數據量讓 O ( n m 2 ) O(nm^2) 的樸素樹上揹包T到飛起,咱們須要考慮優化。
我的會將各類優化講到極限(固然是本蒟蒻的極限)。
根據一番學習,我也認爲上下界優化最簡單易理解……
上下界優化這位神犇的博客至關不錯了:戳我%他
我也口胡兩句吧。
普通作法:html

for (j=m+1;j>=1;--j)//枚舉揹包容量
	for (k=1;k<j;++k)//枚舉在子樹中選擇多少
		f[u][j]=max(f[u][j],f[u][k]+f[v][j-k]);

那麼size優化很是簡單好想:node

for (j=min(m+1,size[u]);j>=1;--j)//枚舉揹包容量
	for (k=1;k<j&&k<=size[v];++k)//枚舉在子樹中選擇多少
		f[u][j]=max(f[u][j],f[u][k]+f[v][j-k]);

道理也很簡單,選完就那麼多,確定不能枚舉到超過的。
因而能AC這道題,用時 17 s 17s
但再想一想,咱們選擇的 k k 的下界其實也是會被約束的。
由於選到 j j 的總容量的時候,假定前面的所有取完, k k 都必需要到達一個值才能知足條件。
例子:
s i z e [ u ] = s i z e [ s o n 1 ] + s i z e [ s o n 2 ] + s i z e [ s o n 3 ] size[u] = size[son1] + size[son2] + size[son3]
咱們枚舉時,好比 j = s i z e [ s o n 1 ] + s i z e [ s o n 2 ] + a j = size[son1] + size[son2] + a 的狀況,
咱們至少要在 s o n 3 son3 中取a個節點才能達到此容量。
所以就能獲得上下界優化:web

void dfs(int u)
{
    siz[u]=1;
    f[u][1]=a[u];
    int i,j,k,v;
    for (i=head[u];i;i=nxt[i])
    {
        v=to[i];
        dfs(v);
        for (j=min(m+1,siz[u]+siz[v]);j>=2;--j)//這裏作了小改動,由於1的更新確定沒有意義
            for (k=max(1,j-siz[u]);k<=siz[v]&&k<j;++k)
                f[u][j]=max(f[u][j],f[u][j-k]+f[v][k]);
        siz[u]+=siz[v];
    }
}

這裏對 s i z e size 數組的更新作了特殊處理,能夠更方便地獲得前面全部子樹的節點數總和。因而更進一步,達到了12s的成績。
那麼還能不能更快呢?實際上是能夠的。
咱們發現內層循環須要2個判斷語句,有什麼辦法縮成一個?
固然能夠開臨時變量來存,但咱們甚至能夠換一種dp方式!(思路來源於某位神犇,他的代碼用了刷表法無師自通地進行了 O ( n m ) O(nm) 優化致使過去「指點」的我轉爲「%%%」狀態)數組

刷表法

刷表法怎麼寫呢?其實也很簡單:app

void dfs(int u)
{
    siz[u]=1;
    f[u][1]=a[u];
    int i,j,k,v;
    for (i=head[u];i;i=nxt[i])
    {
        v=to[i];
        dfs(v);
        for (j=min(m,siz[u]);j>=1;--j)//在以前子樹&&根中選擇的節點數,這裏要取1是由於確定要取根節點
            for (k=1;k<=siz[v]&&j+k<=m+1;++k)//在當前子樹取得節點數
                f[u][j+k]=max(f[u][j+k],f[u][j]+f[v][k]);
        siz[u]+=siz[v];
    }
}

這時候,咱們就能夠將內層循環的兩個判斷語句合爲一個了:svg

void dfs(int now)
{
    size[now] = 1;
    f[now][1] = w[now];
    int v;
    for (int p = head[now]; p; p = lines[p].next)
    {
        v = lines[p].to;
        dfs(v);
        for (int j = min(size[now], m); j; --j)
            for (int k = min(size[v], m + 1 - j); k; --k)
                f[now][j + k] = max(f[now][j + k], f[now][j] + f[v][k]);
        size[now] += size[v];
    }
}

省去了一個判斷,對常數的優化仍是不可小覷的。函數

下標映射

對於例題,因爲 n , m n,m 過大,開二維確定開不下,確定要扁平化爲一維。
由於有一個超級源點,所以揹包最大容量其實爲 m + 1 m+1 ,而 [ 0 , m + 1 ] [0,m+1] 間有 m + 2 m+2 個位置。
故有:oop

inline int pos(const int &x,const int &y)
{
	return x * (m+2) + y;//注意此處x可能爲0
}

可是事實上,每次都計算這個pos帶來了大量的計算。多大量呢?
當初用填表法時,我將這個函數換成了 d e f i n e define ,總時間從 12 s 12s 提高到了 8 s 8s
顯然由於這個 p o s pos 反覆計算,消耗了大量的時間。
那麼是否還有比宏定義更優的方法呢?我翻了翻最優解,除了題目做者本人在調整數據規模時的弱數據AC外,第一位是一位名爲WarlockAkk的神犇,用時僅 4.2 s 4.2s
這到底是何等黑魔法?我點開源碼開始膜拜,因而看到:學習

bfo(i,0,n+1){
        d[i]=spa+idx;
		idx+=m+2;
	}

這是什麼意思呢? d d 是一個 i n t int* 的數組,因而我恍然大悟:優化

能夠預處理出一個映射數組,將二維的對映射數組的訪問映射到一維的保存數組中。

具體實現方式:

int dp[100001000];
int *f[MAXN]; //f[i][j] points to the dp arr.
int k, pointer = 0;
    f[0] = &dp[0]; //special
    for (int i = 1; i <= n; ++i)
    {
        pointer += m + 2;
        f[i] = &dp[pointer]; //special
    }

咱們將這兩行代碼插入到讀入的循環中,就能夠獲得映射數組 f f ,咱們就能直接用 f [ i ] [ j ] f[i][j] 來訪問了!
而且由於 f [ i ] f[i] 存的索引直接加上 j j 就能獲得地址,咱們實際上避免了兩個大數的乘法,而使其變成了加法。
舉例:
原先訪問方式:
d p [ x ( m + 2 ) + y ] dp[x * (m+2) + y ] 進行了一次乘法一次加法
解析一下就是:

return dp + (x * (m+2) + y);

而如今的訪問方式:
( f [ x ] + y ) (f[x] + y)
解析一下就是:

return (f + x) + y;

效率提高至關顯著。
同時注意咱們的預處理方式:

int pointer = 0;
pointer += m + 2;

寫成加法的形式,與乘法形式對比:

pointer = (m + 2) * i;

效率如何很顯然了。
那麼下標映射後到底有多快呢?
有多快呢?
咱們看結論吧。

總結

填表法 填表法 with O(2) 刷表法 刷表法 with O(2) 下標映射 + 刷表法 with O(2)
8 s 8s 7.5 s 7.5s 7 s 7s 6.5 s 6.5s 2.4 s 2.4s

能夠發現,吸氧對於這種狀況提高不明顯。
而下標映射 快、極快、巨快!
最優解
所以在卡常優化時咱們能夠多想一想使用指針等玄學進行優化,每每會有意想不到的提高。
l o w e r _ b o u n d lower\_bound 等函數直接使用迭代器等……
That’s all.

Code

#pragma GCC target("avx")
#pragma GCC optimize(2)
#pragma GCC optimize(3)
#pragma GCC optimize("Ofast")
#pragma GCC optimize("inline")
#pragma GCC optimize("-fgcse")
#pragma GCC optimize("-fgcse-lm")
#pragma GCC optimize("-fipa-sra")
#pragma GCC optimize("-ftree-pre")
#pragma GCC optimize("-ftree-vrp")
#pragma GCC optimize("-fpeephole2")
#pragma GCC optimize("-ffast-math")
#pragma GCC optimize("-fsched-spec")
#pragma GCC optimize("unroll-loops")
#pragma GCC optimize("-falign-jumps")
#pragma GCC optimize("-falign-loops")
#pragma GCC optimize("-falign-labels")
#pragma GCC optimize("-fdevirtualize")
#pragma GCC optimize("-fcaller-saves")
#pragma GCC optimize("-fcrossjumping")
#pragma GCC optimize("-fthread-jumps")
#pragma GCC optimize("-funroll-loops")
#pragma GCC optimize("-fwhole-program")
#pragma GCC optimize("-freorder-blocks")
#pragma GCC optimize("-fschedule-insns")
#pragma GCC optimize("inline-functions")
#pragma GCC optimize("-ftree-tail-merge")
#pragma GCC optimize("-fschedule-insns2")
#pragma GCC optimize("-fstrict-aliasing")
#pragma GCC optimize("-fstrict-overflow")
#pragma GCC optimize("-falign-functions")
#pragma GCC optimize("-fcse-skip-blocks")
#pragma GCC optimize("-fcse-follow-jumps")
#pragma GCC optimize("-fsched-interblock")
#pragma GCC optimize("-fpartial-inlining")
#pragma GCC optimize("no-stack-protector")
#pragma GCC optimize("-freorder-functions")
#pragma GCC optimize("-findirect-inlining")
#pragma GCC optimize("-fhoist-adjacent-loads")
#pragma GCC optimize("-frerun-cse-after-loop")
#pragma GCC optimize("inline-small-functions")
#pragma GCC optimize("-finline-small-functions")
#pragma GCC optimize("-ftree-switch-conversion")
#pragma GCC optimize("-foptimize-sibling-calls")
#pragma GCC optimize("-fexpensive-optimizations")
#pragma GCC optimize("-funsafe-loop-optimizations")
#pragma GCC optimize("inline-functions-called-once")
#pragma GCC optimize("-fdelete-null-pointer-checks")
#include <cstdio>
using namespace std;
const int MAXN = 100100;
inline int max(const int &a, const int &b) { return a > b ? a : b; }
inline int min(const int &a, const int &b) { return a < b ? a : b; }
char buf[100000], *p1 = buf, *p2 = buf;
#define nc() p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 100000, stdin), p1 == p2) ? EOF : *p1++
template <typename T>
inline void read(T &r)
{
    static char c;r = 0;
    for (c = nc(); c > '9' || c < '0'; c = nc());
    for (; c >= '0' && c <= '9'; r = (r << 1) + (r << 3) + (c ^ 48), c = nc());
}
struct node
{
    int to, next;
    node() {}
    node(const int &_to, const int &_next) : to(_to), next(_next) {}
} lines[MAXN];
int head[MAXN];
void add(const int &x, const int &y)
{
    static int tot = 0;
    lines[++tot] = node(y, head[x]), head[x] = tot;
}
int n, m;
int dp[100001000];
int *f[MAXN]; //f[i][j] points to the dp arr.
int size[MAXN], w[MAXN];
void dfs(int now)
{
    int v;
    size[now] = 1;
    f[now][1] = w[now];
    for (int p = head[now]; p; p = lines[p].next)
    {
        v = lines[p].to;
        dfs(v);
        for (int i = min(size[now], m); i; --i)
            for (int j = min(size[v], m + 1 - i); j; --j)
                f[now][i + j] = max(f[now][i + j], f[now][i] + f[v][j]);
        size[now] += size[v];
    }
}
int main()
{
    read(n);
    read(m);
    int k, pointer = 0;
    f[0] = &dp[0]; //special
    for (int i = 1; i <= n; ++i)
    {
        pointer += m + 2;
        f[i] = &dp[pointer]; //special
        read(k);
        add(k, i); //we can set the point(0) into a vitual node,which is the root of the tree
        read(w[i]);
    }
    dfs(0);
    printf("%d", f[0][m + 1]);
    return 0;
}~~~
相關文章
相關標籤/搜索