Problem :
給一棵n個點的樹,每一個點上有一種顏色,對於一條路徑上的點,若 i 顏色第 j 次出現對該路徑權值的貢獻爲 w[i] * c[j], 每次詢問一條路徑的權值,或者修改某個點的顏色。
Solution :
樹上的帶修改的莫隊。
使用dfs序來對左右端點進行分塊。
第一關鍵字分塊排序左端點,第二關鍵字分塊排序右端點,第三關鍵字排序詢問順序。c++
用S(v, u)表明 v到u的路徑上的結點的集合。
用root來表明根結點,用lca(v, u)來表明v、u的最近公共祖先。
那麼
S(v, u) = S(root, v) xor S(root, u) xor lca(v, u)
其中xor是集合的對稱差。
簡單來講就是節點出現兩次消掉。spa
lca很討厭,因而再定義
T(v, u) = S(root, v) xor S(root, u)
觀察將curV移動到targetV先後T(curV, curU)變化:
T(curV, curU) = S(root, curV) xor S(root, curU)
T(targetV, curU) = S(root, targetV) xor S(root, curU)
取對稱差:
T(curV, curU) xor T(targetV, curU)= (S(root, curV) xor S(root, curU)) xor (S(root, targetV) xor S(root, curU))
因爲對稱差的交換律、結合律:
T(curV, curU) xor T(targetV, curU)= S(root, curV) xor S(root, targetV)
兩邊同時xor T(curV, curU):
T(targetV, curU)= T(curV, curU) xor S(root, curV) xor S(root, targetV)
發現最後兩項很爽……哇哈哈
T(targetV, curU)= T(curV, curU) xor T(curV, targetV)
(有公式恐懼症的不要走啊 T_T)code
也就是說,更新的時候,xor T(curV, targetV)就好了。
即,對curV到targetV路徑(除開lca(curV, targetV))上的結點,將它們的存在性取反便可。排序
引用自vfkci
#include <bits/stdc++.h> using namespace std; const int INF = 1e9 + 7; const int N = 1000008; int n, m, q, block_size, block_num; vector <int> vec[N]; int cv[N], cw[N], cnt[N], a[N], b[N], vis[N]; int belong[N], num, tot1, tot2; int fa[N][20], dep[N]; stack <int> st; long long ans[N], sum; struct query1 { int u, v, ub, vb, x, id; query1(){} query1(int u, int v, int x, int id) : u(u), v(v), x(x), id(id) { ub = belong[u]; vb = belong[v]; } bool operator < (const query1 &b) const { if (ub != b.ub) return ub < b.ub; if (vb != b.vb) return vb < b.vb; return x < b.x; } }q1[N]; struct query2 { int pos, x, y; query2(){} query2(int pos, int x, int y) : pos(pos), x(x), y(y){} }q2[N]; void stpop(int &num) { ++block_num; for (int i = 1; i <= num; ++i) { int p = st.top(); st.pop(); belong[p] = block_num; } num = 0; } int dfs(int u) { int num = 1; st.push(u); for (auto v : vec[u]) { if (v == fa[u][0]) continue; fa[v][0] = u; dep[v] = dep[u] + 1; num += dfs(v); if (num >= block_size) stpop(num); } return num; } int get_lca(int u, int v) { if (dep[u] < dep[v]) swap(u, v); int d = dep[u] - dep[v]; for (int i = 19; i >= 0; --i) if (d & (1 << i)) u = fa[u][i]; if (u == v) return u; for (int i = 19; i >= 0; --i) if (fa[u][i] != fa[v][i]) { u = fa[u][i]; v = fa[v][i]; } return fa[u][0]; } void init() { block_size = pow(n, 2.0 / 3); block_num = 0; for (int i = 1; i <= m; ++i) cin >> cv[i]; for (int i = 1; i <= n; ++i) cin >> cw[i]; for (int i = 1; i <= n; ++i) vec[i].clear(); for (int i = 1; i < n; ++i) { int u, v; cin >> u >> v; vec[u].push_back(v); vec[v].push_back(u); } fa[1][0] = 0; dep[1] = 1; int num = dfs(1); if (num != 0) stpop(num); assert(st.empty()); for (int i = 1; i < 20; ++i) for (int j = 1; j <= n; ++j) fa[j][i] = fa[fa[j][i - 1]][i - 1]; for (int i = 1; i <= n; ++i) cin >> a[i], b[i] = a[i]; tot1 = tot2 = 0; for (int i = 1; i <= q; ++i) { int t, x, y; cin >> t >> x >> y; if (t == 0) { q2[++tot2] = query2(x, b[x], y); b[x] = y; } else { ++tot1; q1[tot1] = query1(x, y, tot2, tot1); } } sort(q1 + 1, q1 + tot1 + 1); for (int i = 1; i <= n; ++i) vis[i] = 0; for (int i = 1; i <= m; ++i) cnt[i] = 0; } void update(int pos) { if (vis[pos]) { sum -= (long long)cw[cnt[a[pos]]] * cv[a[pos]]; cnt[a[pos]]--; } else { cnt[a[pos]]++; sum += (long long)cw[cnt[a[pos]]] * cv[a[pos]]; } vis[pos] ^= 1; } void change(int pos, int x) { if (vis[pos]) { update(pos); a[pos] = x; update(pos); } else a[pos] = x; } void work(int u, int v) { int lca = get_lca(u, v); while (u != lca) { update(u); u = fa[u][0]; } while (v != lca) { update(v); v = fa[v][0]; } } void solve() { sum = 0; for (int i = 1; i <= q1[1].x; ++i) change(q2[i].pos, q2[i].y); work(q1[1].u, q1[1].v); update(get_lca(q1[1].u, q1[1].v)); ans[q1[1].id] = sum; update(get_lca(q1[1].u, q1[1].v)); for (int i = 2, u = q1[1].u, v = q1[1].v, x = q1[1].x; i <= tot1; u = q1[i].u, v = q1[i].v, x = q1[i].x, ++i) { for (int j = x + 1; j <= q1[i].x; ++j) change(q2[j].pos, q2[j].y); for (int j = x; j >= q1[i].x + 1; --j) change(q2[j].pos, q2[j].x); work(u, q1[i].u); work(v, q1[i].v); update(get_lca(q1[i].u, q1[i].v)); ans[q1[i].id] = sum; update(get_lca(q1[i].u, q1[i].v)); } for (int i = 1; i <= tot1; ++i) cout << ans[i] << endl; } int main() { cin.sync_with_stdio(0); while (cin >> n >> m >> q) { init(); solve(); } }