UOJ 58 (樹上帶修改的莫隊)

UOJ 58 糖果公園

Problem :
給一棵n個點的樹,每一個點上有一種顏色,對於一條路徑上的點,若 i 顏色第 j 次出現對該路徑權值的貢獻爲 w[i] * c[j], 每次詢問一條路徑的權值,或者修改某個點的顏色。
Solution :
樹上的帶修改的莫隊。
使用dfs序來對左右端點進行分塊。
第一關鍵字分塊排序左端點,第二關鍵字分塊排序右端點,第三關鍵字排序詢問順序。c++

用S(v, u)表明 v到u的路徑上的結點的集合。
用root來表明根結點,用lca(v, u)來表明v、u的最近公共祖先。
那麼
S(v, u) = S(root, v) xor S(root, u) xor lca(v, u)
其中xor是集合的對稱差。
簡單來講就是節點出現兩次消掉。spa

lca很討厭,因而再定義
T(v, u) = S(root, v) xor S(root, u)
觀察將curV移動到targetV先後T(curV, curU)變化:
T(curV, curU) = S(root, curV) xor S(root, curU)
T(targetV, curU) = S(root, targetV) xor S(root, curU)
取對稱差:
T(curV, curU) xor T(targetV, curU)= (S(root, curV) xor S(root, curU)) xor (S(root, targetV) xor S(root, curU))
因爲對稱差的交換律、結合律:
T(curV, curU) xor T(targetV, curU)= S(root, curV) xor S(root, targetV)
兩邊同時xor T(curV, curU):
T(targetV, curU)= T(curV, curU) xor S(root, curV) xor S(root, targetV)
發現最後兩項很爽……哇哈哈
T(targetV, curU)= T(curV, curU) xor T(curV, targetV)
(有公式恐懼症的不要走啊 T_T)code

也就是說,更新的時候,xor T(curV, targetV)就好了。
即,對curV到targetV路徑(除開lca(curV, targetV))上的結點,將它們的存在性取反便可。排序

引用自vfkci

#include <bits/stdc++.h>
using namespace std;

const int INF = 1e9 + 7;
const int N = 1000008;
int n, m, q, block_size, block_num;

vector <int> vec[N];
int cv[N], cw[N], cnt[N], a[N], b[N], vis[N];
int belong[N], num, tot1, tot2;
int fa[N][20], dep[N];
stack <int> st;
long long ans[N], sum;

struct query1
{
    int u, v, ub, vb, x, id;
    query1(){}
    query1(int u, int v, int x, int id) : u(u), v(v), x(x), id(id)
    {
        ub = belong[u]; 
        vb = belong[v];
    }
    bool operator < (const query1 &b) const
    {
        if (ub != b.ub) return ub < b.ub;
        if (vb != b.vb) return vb < b.vb;
        return x < b.x;
    }
}q1[N];
struct query2
{
    int pos, x, y;
    query2(){}
    query2(int pos, int x, int y) : pos(pos), x(x), y(y){}
}q2[N];
void stpop(int &num)
{
    ++block_num;
    for (int i = 1; i <= num; ++i)
    {
        int p = st.top(); st.pop();
        belong[p] = block_num;
    }
    num = 0;
}
int dfs(int u)
{
    int num = 1;
    st.push(u);
    for (auto v : vec[u])
    {
        if (v == fa[u][0]) continue;
        fa[v][0] = u; dep[v] = dep[u] + 1;
        num += dfs(v);
        if (num >= block_size) stpop(num);
    }
    return num;
}
int get_lca(int u, int v)
{
    if (dep[u] < dep[v]) swap(u, v);
    int d = dep[u] - dep[v];
    for (int i = 19; i >= 0; --i)
        if (d & (1 << i))
            u = fa[u][i];
    if (u == v) return u;
    for (int i = 19; i >= 0; --i)
        if (fa[u][i] != fa[v][i])
        {
            u = fa[u][i];
            v = fa[v][i];
        }
    return fa[u][0];
}
void init()
{
    block_size = pow(n, 2.0 / 3); block_num = 0;
    for (int i = 1; i <= m; ++i) cin >> cv[i];
    for (int i = 1; i <= n; ++i) cin >> cw[i];
    for (int i = 1; i <= n; ++i) vec[i].clear();
    for (int i = 1; i <  n; ++i)
    {
        int u, v; cin >> u >> v;
        vec[u].push_back(v);
        vec[v].push_back(u);
    }
    fa[1][0] = 0; dep[1] = 1;
    int num = dfs(1);
    if (num != 0) stpop(num);
    assert(st.empty());
    for (int i = 1; i < 20; ++i)
        for (int j = 1; j <= n; ++j)
            fa[j][i] = fa[fa[j][i - 1]][i - 1];
    for (int i = 1; i <= n; ++i) cin >> a[i], b[i] = a[i];
    tot1 = tot2 = 0;
    for (int i = 1; i <= q; ++i)
    {
        int t, x, y; cin >> t >> x >> y;
        if (t == 0)
        {
            q2[++tot2] = query2(x, b[x], y);
            b[x] = y;
        }
        else
        {
            ++tot1;
            q1[tot1] = query1(x, y, tot2, tot1);
        }
    }
    sort(q1 + 1, q1 + tot1 + 1);    
    for (int i = 1; i <= n; ++i) vis[i] = 0;
    for (int i = 1; i <= m; ++i) cnt[i] = 0;
}
void update(int pos)
{
    if (vis[pos])
    {   
        sum -= (long long)cw[cnt[a[pos]]] * cv[a[pos]];
        cnt[a[pos]]--;
    }
    else
    {
        cnt[a[pos]]++;
        sum += (long long)cw[cnt[a[pos]]] * cv[a[pos]];
    }
    vis[pos] ^= 1;
}
void change(int pos, int x)
{
    if (vis[pos])
    {
        update(pos);
        a[pos] = x;
        update(pos);
    }
    else a[pos] = x;
}
void work(int u, int v)
{
    int lca = get_lca(u, v);
    while (u != lca)
    {
        update(u);
        u = fa[u][0];
    }
    while (v != lca)
    {
        update(v);
        v = fa[v][0];
    }
}
void solve()
{
    sum = 0;
    for (int i = 1; i <= q1[1].x; ++i) change(q2[i].pos, q2[i].y);
    work(q1[1].u, q1[1].v);
    update(get_lca(q1[1].u, q1[1].v));
    ans[q1[1].id] = sum;
    update(get_lca(q1[1].u, q1[1].v));
    for (int i = 2, u = q1[1].u, v = q1[1].v, x = q1[1].x; i <= tot1; u = q1[i].u, v = q1[i].v, x = q1[i].x, ++i)
    {
        for (int j = x + 1; j <= q1[i].x; ++j) change(q2[j].pos, q2[j].y);
        for (int j = x; j >= q1[i].x + 1; --j) change(q2[j].pos, q2[j].x);
        work(u, q1[i].u);
        work(v, q1[i].v);
        update(get_lca(q1[i].u, q1[i].v));
        ans[q1[i].id] = sum;
        update(get_lca(q1[i].u, q1[i].v));  
    }
    for (int i = 1; i <= tot1; ++i) cout << ans[i] << endl;
}
int main()
{   
    cin.sync_with_stdio(0);
    while (cin >> n >> m >> q)
    {
        init();
        solve();
    }
}
相關文章
相關標籤/搜索