凸優化小結

本文參考自 Wearry 在集訓的講解《DP及其優化》。git

簡介

凸優化解決的是一類選擇剛好 \(K\) 個某種物品的最優化問題 , 通常來講這樣的題目在不考慮物品數量限制的條件下會有一個隱性的圖像 , 表示選擇的物品數量與問題最優解之間的關係 .api

每一個點就是選了 \(K\) 個物品的最優Dp值。(答案)也就是 \((K, f(K))\)優化

問題可以用凸優化解決還須要知足圖像是凸的 , 直觀地理解就是選的物品越多的狀況下多選一個物品 , 最優解的增加速度會變慢 .spa

解法

解決凸優化類型的題目能夠採用二分的方法 , 即二分隱性凸殼上最優值所在點的斜率 , 而後忽略剛好 \(K\) 個的限制作一次原問題 .debug

這樣每次選擇一個物品的時候要多付出斜率大小的代價 , 就可以根據最優狀況下選擇的物品數量來判斷二分的斜率與實際最優值的斜率的大小關係 .code

理論上這個斜率必定是整數 , 因爲題目性質可能會出現二分不出這個數的狀況 , 這時就須要一些實現上的技巧保證可以找到這個最優解 .blog

由於相鄰兩個點橫下標差 \(1\) (多選一個),縱座標都是整數。(對於大部分的題目最優解都是整數)。排序

這個也就是 CTSC 上講的 帶權二分 啦。get

例題

UOJ #104. 【APIO2014】Split the sequence

題意

將一個長爲 \(n\) 的序列分紅 \(k+1\) 個塊,每次分割獲得分割處 左邊的和 與 右邊的和 乘積的分數。

保證序列中每一個數非負。最後須要最大化分數,須要求出任意一組方案。

\(2 \le n \le 10^5, 1 \le k \le \min \{n - 1, 200\}\)

題解

直接作斜率優化是 \(O(nk)\) 的,那個十分 簡單 ,注意細節就好了。能夠參考 個人代碼

雖然已通過了這題了,可是有更好的作法。也就是對於 \(k \le n - 1\) 也就是 \(k,n\) 同級的時候有更好的作法。

考慮前面講的凸優化,咱們考慮二分那個斜率,也就是分數的增加率。

假設二分的值爲 \(mid\) ,至關於轉化成沒有分段次數的限制,可是每次分段都要額外付出 \(mid\) 的代價 , 求最大化收益的前提下分段數是多少 .

具體化來講,就例如上圖,那個上凸殼就是答案的圖像,咱們當前二分的那個斜率的直線就是那條紅線。

咱們當前是最大化 \(f(x) - x\times mid\)

那麼咱們考慮把紅線向上不斷平移,那麼最後接觸到的點就是這條直線與上凸殼的切點。此時答案最大。

那麼咱們算出的分段數就是 \(x\) ,也就是切點的下標。而後比較一下 \(x\)\(k\) 的關係,判斷應該向哪邊移動。

而後最後獲得斜率算出的方案就是最優方案了。

我沒有寫 但據說細節特別多,輸出方案很噁心。若是想寫的話,能夠看下 UOJ 最快的代碼,來自同屆大佬 yww 的。

這個複雜度就是 \(O(n \log w)\) 的,十分優秀。

CF739E Gosha is hunting

題意

你要抓神奇寶貝! 如今一共有 \(n\) 只神奇寶貝。 你有 \(a\) 個『寶貝球』和 \(b\) 個『超級球』。 『寶貝球』抓到第 \(i\) 只神奇寶貝的機率是 \(p_i\) ,『超級球』抓到的機率則是 \(u_i\) 。 不能往同一只神奇寶貝上使用超過一個同種的『球』,可是能夠往同一只上既使用『寶貝球』又使用『超級球』(都抓到算一個)。 請合理分配每一個球抓誰,使得你抓到神奇寶貝的總個數指望最大,並輸出這個值。

\(n \le 2000\)

題解

不難發現用的球越多,指望增加率越低。這是很好理解的,一開始確定選更優的神奇寶貝球,而後再選較劣的神奇寶貝球。

這就意味着這個隱性的圖像是上凸的,咱們能夠相似於上題的套路,咱們二分那個斜率。

而後咱們就能夠忽略個數的限制了。但此處這裏有兩個變量,那麼咱們二分套二分就好了。

假設當前二分的是 \(mid\) ,那麼咱們每次選擇一個神奇寶貝球就要付出 \(mid\) 的代價。

而後求出最大化收益時須要選多少個神奇寶貝球就好了,這個能夠用一個很容易的 dp 求出。

但注意兩個同時選的時候,機率應該是 \(p_a + p_b - p_a \times p_b\)

但此時有一個重要的細節,就是二分到最後斜率求出的答案不必定是正確的。

可是在其中若是咱們二分到 最優解要選的球和我最後用的球同樣的話,那麼這樣就是一個最優的可行解。

至於緣由?無可奉告!

彷佛是可能有三點共線的狀況,此時選的個數有問題。而且最後須要用給你的個數,不能用求出的個數。

代碼

具體看看代碼。。。反正我也不知道爲何這麼多特殊狀況。

#include <bits/stdc++.h>

#define For(i, l, r) for(register int i = (l), i##end = (int)(r); i <= i##end; ++i)
#define Fordown(i, r, l) for(register int i = (r), i##end = (int)(l); i >= i##end; --i)
#define Set(a, v) memset(a, v, sizeof(a))
#define Cpy(a, b) memcpy(a, b, sizeof(a))
#define debug(x) cout << #x << ": " << x << endl
#define DEBUG(...) fprintf(stderr, __VA_ARGS__)

using namespace std;

inline bool chkmax(double &a, double b) {return b > a ? a = b, 1 : 0;}

inline int read() {
    int x = 0, fh = 1; char ch = getchar();
    for (; !isdigit(ch); ch = getchar()) if (ch == '-') fh = -1;
    for (; isdigit(ch); ch = getchar()) x = (x << 1) + (x << 3) + (ch ^ 48);
    return x * fh;
}

void File() {
#ifdef zjp_shadow
    freopen ("E.in", "r", stdin);
    freopen ("E.out", "w", stdout);
#endif
}

const double eps = 1e-10;

const int N = 2010;

int n, a, b;

double pa[N], pb[N]; int usea, useb; double f;

void Calc(double costa, double costb) {
    f = 0; usea = useb = 0;
    For (i, 1, n) {
        int cura = 0, curb = 0; double res = 0;
        if (chkmax(res, pa[i] - costa)) cura = 1, curb = 0;
        if (chkmax(res, pb[i] - costb)) cura = 0, curb = 1;
        if (chkmax(res, pa[i] + pb[i] - pa[i] * pb[i] - (costa + costb))) cura = curb = 1;
        usea += cura; useb += curb; f += res;
    }
}

int main () {

    File();

    n = read(); a = read(); b = read();
    For (i, 1, n) scanf("%lf", &pa[i]);
    For (i, 1, n) scanf("%lf", &pb[i]);

    double la = 0, ra = 1, lb, rb;
    while (la + eps < ra) {
        double mida = (la + ra) / 2.0; lb = 0, rb = 1;
        while (lb + eps < rb) {
            double midb = (lb + rb) / 2.0;
            Calc(mida, midb);
            if (useb == b) {lb = midb; break; }
            if (useb < b) rb = midb; else lb = midb;
        }
        if (usea == a) { la = mida; break; }
        if (usea < a) ra = mida; else la = mida;
    }
    Calc(la, lb);
    printf ("%.10lf\n", f + la * a + lb * b);

    return 0;
}

LOJ #2478. 「九省聯考 2018」林克卡特樹

題意

LOJ #2478. 「九省聯考 2018」林克卡特樹

請點上面連接qwq 題意很好理解的。(但要認真看題)

題解

題意等價於,剛好選 \(k\) 條鏈, 使得他們的長度和最大。

咱們一樣可使用凸優化對於這個來進行優化。

二分那個斜率 \(mid\) ,每次選擇多一條鏈就要減去 \(mid\) ,最後求使得答案最優的時候,須要分紅幾段。

但這些都不是重點,重點是如何求出答案最優的時候有多少段。

咱們令 dp[u][0/1/2]\(u\) 這個點,向子樹中延伸出 \(0,1,2\) 條鏈。

轉移的話,枚舉一下它從和哪一個兒子的鏈相連,計算一下分的段數便可。

爲了方便計算段數,在鏈的底部統計上段數,因此合併兩條鏈的時候須要減去一段,而且把權值加回來 \(mid\)

記得要統計上別的子樹的答案!!先掛下 \(dp\) 的代碼吧。

利用 std :: pair<ll, int> 寫的更加方便,第一維表示答案,第二維表示段數。

typedef pair<ll, int> PLI;
#define res first
#define num second
#define mp make_pair

inline PLI operator + (const PLI &lhs, const PLI &rhs) {
    return mp(lhs.res + rhs.res, lhs.num + rhs.num);
}

PLI f[N][3]; ll del;
void Dp(int u = 1, int fa = 0) {
    f[u][0] = mp(0, 0);
    f[u][1] = mp(- del, 1);
    f[u][2] = mp(- inf, 0);

    for (register int i = Head[u]; i; i = Next[i]) {
        register int v = to[i]; if (v == fa) continue ; Dp(v, u);
        PLI tmp = max(f[v][0], max(f[v][1], f[v][2]));

        chkmax(f[u][2], f[u][2] + tmp);
        chkmax(f[u][2], f[u][1] + f[v][1] + mp(val[i] + del, -1));

        chkmax(f[u][1], f[u][1] + tmp);
        chkmax(f[u][1], f[u][0] + f[v][1] + mp(val[i], 0));
        chkmax(f[u][1], f[u][0] + f[v][0] + mp(- del, 1));

        chkmax(f[u][0], f[u][0] + tmp);
    }
}

而後又會有三點共線的狀況,也就是對於選擇連續幾個答案都是相同的。

咱們發現,利用 std :: pair<ll, int> 的運算符 < ,會在第一維答案相同時優先第二維段數小的在前。

因此咱們更新答案的時候就須要在 \(use > k\) 也就是需求大於供給 通貨膨脹 的時候進行更新,否則答案可能更新不到。

若是 \(use = k\) 那麼就能夠直接退出輸出答案就行啦。

代碼

#include <bits/stdc++.h>

#define For(i, l, r) for(register int i = (l), i##end = (int)(r); i <= i##end; ++i)
#define Fordown(i, r, l) for(register int i = (r), i##end = (int)(l); i >= i##end; --i)
#define Set(a, v) memset(a, v, sizeof(a))
#define Cpy(a, b) memcpy(a, b, sizeof(a))
#define debug(x) cout << #x << ": " << x << endl
#define DEBUG(...) fprintf(stderr, __VA_ARGS__)

using namespace std;

typedef long long ll;
template<typename T> inline bool chkmax(T &a, T b) {return b > a ? a = b, 1 : 0;}

namespace pb_ds
{   
    namespace io
    {
        const int MaxBuff = 1 << 15;
        const int Output = 1 << 23;
        char B[MaxBuff], *S = B, *T = B;
#define getc() ((S == T) && (T = (S = B) + fread(B, 1, MaxBuff, stdin), S == T) ? 0 : *S++)
        char Out[Output], *iter = Out;
        inline void flush()
        {
            fwrite(Out, 1, iter - Out, stdout);
            iter = Out;
        }
    }

    inline int read()
    {
        using namespace io;
        register char ch; register int ans = 0; register bool neg = 0;
        while(ch = getc(), (ch < '0' || ch > '9') && ch != '-')     ;
        ch == '-' ? neg = 1 : ans = ch - '0';
        while(ch = getc(), '0' <= ch && ch <= '9') ans = ans * 10 + ch - '0';
        return neg ? -ans : ans;
    }
};

using namespace pb_ds;

void File () {
#ifdef zjp_shadow
    freopen ("2478.in", "r", stdin);
    freopen ("2478.out", "w", stdout);
#endif
}

const int N = 3e5 + 1e3, M = N << 1;

int Head[N], Next[M], to[M], val[M], e = 0;
inline void add_edge(int u, int v, int w) {
    to[++ e] = v; Next[e] = Head[u]; Head[u] = e; val[e] = w;
}

inline void Add(int u, int v, int w) {
    add_edge(u, v, w); add_edge(v, u, w);
}

typedef long long ll;
const ll inf = 1e18;

typedef pair<ll, int> PLI;
#define res first
#define num second
#define mp make_pair

inline PLI operator + (const PLI &lhs, const PLI &rhs) {
    return mp(lhs.res + rhs.res, lhs.num + rhs.num);
}

PLI f[N][3]; ll del;
void Dp(int u = 1, int fa = 0) {
    f[u][0] = mp(0, 0);
    f[u][1] = mp(- del, 1);
    f[u][2] = mp(- inf, 0);

    for (register int i = Head[u]; i; i = Next[i]) {
        register int v = to[i]; if (v == fa) continue ; Dp(v, u);
        PLI tmp = max(f[v][0], max(f[v][1], f[v][2]));

        chkmax(f[u][2], f[u][2] + tmp);
        chkmax(f[u][2], f[u][1] + f[v][1] + mp(val[i] + del, -1));

        chkmax(f[u][1], f[u][1] + tmp);
        chkmax(f[u][1], f[u][0] + f[v][1] + mp(val[i], 0));
        chkmax(f[u][1], f[u][0] + f[v][0] + mp(- del, 1));

        chkmax(f[u][0], f[u][0] + tmp);
    }
}

int n, k, use; PLI ans;

void Calc(ll cur) {
    ans = mp(-inf, 0); del = cur; Dp(); 
    For (i, 0, 2) chkmax(ans, f[1][i]); use = ans.num;
}

ll Ans;
int main () {

    File();

    n = read(), k = read() + 1;
    For (i, 1, n - 1) {
        register int u = read(), v = read(), w = read(); Add(u, v, w);
    }

    ll l = -1e6, r = 8e7;
    while (l <= r) {
        ll mid = (l + r) >> 1;
        Calc(mid);
        if (use == k) return printf ("%lld\n", ans.res + mid * k), 0;
        if (use < k) r = mid - 1;
        else l = mid + 1, Ans = ans.res + mid * k;
    }
    printf ("%lld\n", Ans);

    return 0;

}

LOJ #566. 「LibreOJ Round #10」yanQval 的生成樹

題意

戳進去 >> #566. 「LibreOJ Round #10」yanQval 的生成樹

題意簡單明瞭 qwq

題解

首先,顯然有 \(\mu\) 是這些數的中位數。

而後咱們就很容易想到考慮枚舉中位數 \(mid\) ,而後在 \(w_i < mid\) (白邊)與 \(w_i \ge mid\) (黑邊)分別選 \(\displaystyle \lfloor \frac{n - 1}{2} \rfloor\) 條邊,組成最大生成樹。

這個就顯然能夠進行凸優化了,二分斜率 \(k\) ,把白邊權值 \(+k\) ,而後作最大生成樹,看選出白邊的數量與需求的關係就好了。

這樣就獲得了一個很好的 \(O(nm \log w ~\alpha (n))\) 的作法啦。(注意此處須要預處理排序,才能達到這個複雜度)

而後這樣顯然不夠,咱們繼續考慮以前的權值是什麼。白邊的權值爲 \(mid + k - w_i\) ,黑邊的爲 \(w_i - mid\) 。同時加上一個 \(mid\) 不會改變,那麼就是 \(2\times mid + k - w_i\)\(w_i\) 。咱們令 \(C=2\times mid + k\) ,那麼白邊爲 \(C - w_i\) ,黑邊爲 \(w_i\)

嘗試一下二分 \(C\) ,而後直接判斷呢?這樣看起來很不真實,但倒是對的。

這樣能夠保證在最大生成樹上 \(< mid\)\(\ge mid\) 都各有一半。爲何呢?由於你考慮不存在,那麼多的一邊存在換到另一邊會更優的狀況。

具體看官方解釋:

首先對於 \(M\) 若是最大生成樹 \(T(M)\) 含有黑邊 \(w_1-M\) 和白邊 \(M-w_2\) 且 \(w_1<w_2\) ,顯然交換兩條邊爲 \(w_2-M,M-w_1\) 更優(由於黑白邊對應重合,交換老是可行的)。故全部黑邊對應的 \(w\) 必然大於全部白邊。那麼若是最大生成樹含有 \(w< M\) 的黑邊或 \(w\ge M\) 的白邊,必然只含一種,不妨設爲黑邊。那麼設最小黑邊本來的權值爲 \(w'\) ,取 \(M'=w'\) ,能夠發現其他邊的權值之和不變,而這條黑邊的權值從 \(w'-M<0\) 變成了 \(0\) ,增長了,故獲得了一棵更大的生成樹,因此這必定不是全局最大生成樹。又因爲方案數有限全局最大生成樹(或者 \(n-2\) 條邊生成森林)必定存在,其必然僅含有 \(w\ge M\) 的黑邊和 \(w<M\) 的白邊。

那麼咱們就除掉一個 \(O(n)\) 的複雜度啦。具體看代碼實現qwq

\(n\) 爲偶數其實也是沒問題的,由於你總會選到中位數,不影響答案。

代碼

#include <bits/stdc++.h>

#define For(i, l, r) for(register int i = (l), i##end = (int)(r); i <= i##end; ++i)
#define Fordown(i, r, l) for(register int i = (r), i##end = (int)(l); i >= i##end; --i)
#define Set(a, v) memset(a, v, sizeof(a))
#define Cpy(a, b) memcpy(a, b, sizeof(a))
#define debug(x) cout << #x << ": " << x << endl
#define DEBUG(...) fprintf(stderr, __VA_ARGS__)

using namespace std;

typedef long long ll;

inline bool chkmin(int &a, int b) {return b < a ? a = b, 1 : 0;}
inline bool chkmax(int &a, int b) {return b > a ? a = b, 1 : 0;}

inline int read() {
    int x = 0, fh = 1; char ch = getchar();
    for (; !isdigit(ch); ch = getchar()) if (ch == '-') fh = -1;
    for (; isdigit(ch); ch = getchar()) x = (x << 1) + (x << 3) + (ch ^ 48);
    return x * fh;
}

void File() {
#ifdef zjp_shadow
    freopen ("566.in", "r", stdin);
    freopen ("566.out", "w", stdout);
#endif
}

const int N = 2e5 + 1e3, M = 5e5 + 1e3;

int n, m;

namespace Union_Set {

    int fa[N], Size[N];

    void Init(int maxn) { For (i, 1, maxn) fa[i] = i, Size[i] = 0; }

    int find(int x) { return x == fa[x] ? x : fa[x] = find(fa[x]); }

    inline bool Union(int x, int y) {
        int rtx = find(x), rty = find(y);
        if (rtx == rty) return false;
        if (Size[rtx] < Size[rty]) swap(rtx, rty);
        Size[rtx] += Size[rty]; fa[rty] = rtx; return true;
    }

}

struct Edge {

    int u, v, w;

    inline bool operator < (const Edge &rhs) const { return w > rhs.w; }

} lt[M];

ll ans, res; int use, need;
void Work(int lim) {
    Union_Set :: Init(n); res = use = 0;
    for (register int L = 1, R = m, cur = 0; L <= R; ) {
        Edge add; register bool choose = false;
        if (lt[L].w >= lim - lt[R].w) add = lt[L ++];
        else add = lt[R --], choose = true, add.w = lim - add.w;

        if (Union_Set :: Union(add.u, add.v)) {
            res += add.w; if (choose) ++ use;
            if (++ cur == need << 1) break;
        }
    }
    res -= 1ll * lim * need;
}

int main () {

    File();

    n = read(); m = read(); need = (n - 1) >> 1; if (!need) return puts("0"), 0;
    For (i, 1, m)
        lt[i] = (Edge) {read(), read(), read()};
    sort(lt + 1, lt + m + 1);

    int l = 0, r = min(lt[1].w * 2 + 1, (int) 1e9);
    while (l <= r) {
        int mid = (l + r) >> 1; Work(mid);
        if (use == need) return printf ("%lld\n", res), 0;
        if (use < need) l = mid + 1, ans = res; else r = mid - 1;
    }
    printf ("%lld\n", ans);

    return 0;
}
相關文章
相關標籤/搜索