[GXOI/GZOI2019]舊詞(樹上差分+樹剖)

前置芝士:[LNOI2014]LCA

要是這題放HNOI就行了c++

原題:\(\sum_{l≤i≤r}dep[LCA(i,z)]\)算法

這題:\(\sum_{i≤r}dep[LCA(i,z)]^k\)優化

對於原題,咱們須要把每一個詢問拆成1~l-1 & 1~r再進行差分(因此這題幫咱們省去了一個步驟ui

先考慮\(k=1\)原題

咱們先轉化題意spa

\(dep[lca]\)\(\\)==\(\\)\(dis[1][lca]+1\)\(\\)==\(\\)\(lca->1\)的點數debug

因此咱們每個點(x)對答案的貢獻(\(dep[lca(x, z)]\)),就是他們到根節點的公共路徑的點數code

因而,對於每個點,咱們只須要把1->x的鏈上加1便可get

對於每個詢問,咱們只須要求出1->z的鏈上的和便可it

這一點咱們能夠利用樹剖\(/LCT\)解決class

可是直接作是\(O(N^2*\)樹剖\(\LCT)\)的,咱們考慮莫隊

這樣複雜度變成了\(O(N\sqrt{N}*\)樹剖\(\LCT)\)

什麼?你以爲這個算法還不夠優秀?因此咱們來考慮優化莫隊

莫隊的\(\sqrt{N}\)是怎麼來的?不停的移動左右端點

可是這道題的左端點是固定的\((1)\),因此只須要移動右端點便可,而右端點不須要動來動去,只須要日後掃一遍便可,複雜度是\(O(N*\)樹剖\(\LCT)\)

代碼的話能夠參考[LNOI2014]LCA

考慮k!=1

咱們爲何k=1的時候對於每一個點是\(1->x\)路徑上+1?

這個1的本質是樹上差分,即:\((dep[x]+1)^1-dep[x]^1 = 1\)

因此咱們只須要把1改爲k便可

因此如今問題變成了:給定一個序列,每個點有兩個權值\((a, b)\),每個點的點權爲\(a*b\),支持a權值區間加1和區間查詢

由於b不會改變,因此咱們考慮線段樹

把線段樹的每個節點新弄一個權值,爲\(\sum_{l≤i≤r} b\),每次更新區間的時候用這個權值*sum便可

#include<bits/stdc++.h>
using namespace std;
#define il inline
#define re register
#define debug printf("Now is Line : %d\n",__LINE__)
#define file(a) freopen(#a".in","r",stdin);freopen(#a".out","w",stdout)
#define int long long
#define inf 123456789
#define mod 998244353
il int read() {
    re int x = 0, f = 1; re char c = getchar();
    while(c < '0' || c > '9') { if(c == '-') f = -1; c = getchar();}
    while(c >= '0' && c <= '9') x = x * 10 + c - 48, c = getchar();
    return x * f;
}
#define rep(i, s, t) for(re int i = s; i <= t; ++ i)
#define Next(i, u) for(re int i = head[u]; i; i = e[i].next)
#define mem(k, p) memset(k, p, sizeof(k))
#define ls k * 2
#define rs k * 2 + 1
#define _ 500005
struct edge {int v, next;}e[_];
struct ques {
    int u, z, id;
    il bool operator < (const ques x) const {return u < x.u;}
}q[_];
int n, m, k, val[_ << 2], sum[_ << 2], mi[_], head[_], cnt, ans[_], tag[_ << 2], now;
int fa[_], dep[_], size[_], son[_], top[_], seg[_], col, rev[_];
il void add(int u, int v) {
    e[++ cnt] = (edge) {v, head[u]}, head[u] = cnt;
}
il int qpow(int a, int b) {
    int r = 1;
    while(b) {
        if(b & 1) r = 1ll * r * a % mod;
        b >>= 1, a = 1ll * a * a % mod;
    }
    return r;
}
il void dfs1(int u) {
    dep[u] = dep[fa[u]] + 1, size[u] = 1;
    Next(i, u) {
        if(e[i].v == fa[u]) continue;
        dfs1(e[i].v), size[u] += size[e[i].v];
        if(size[son[u]] < size[e[i].v]) son[u] = e[i].v;
    }
}
il void dfs2(int u, int fr) {
    top[u] = fr, seg[u] = ++ col, rev[col] = u;
    if(son[u]) dfs2(son[u], fr);
    Next(i, u) if(e[i].v != son[u] && e[i].v != fa[u]) dfs2(e[i].v, e[i].v);
}
il void build(int k, int l, int r) {
    if(l == r) return (void)(val[k] = (mi[dep[rev[l]]] - mi[dep[rev[l]] - 1] + mod) % mod);
    int mid = (l + r) >> 1;
    build(ls, l, mid), build(rs, mid + 1, r), val[k] = (val[ls] + val[rs]) % mod;
}
il void pushdown(int k) {
    if(!tag[k]) return;
    sum[ls] = (sum[ls] + ((tag[k] * val[ls]) % mod)) % mod;
    sum[rs] = (sum[rs] + ((tag[k] * val[rs]) % mod)) % mod;
    tag[ls] += tag[k], tag[rs] += tag[k], tag[k] = 0;
}
il void change(int k, int l, int r, int ll, int rr) {
    if(l > rr || ll > r) return;
    if(l >= ll && r <= rr) {sum[k] = (sum[k] + val[k]) % mod, ++ tag[k]; return;}
    int mid = (l + r) >> 1;
    pushdown(k), change(ls, l, mid, ll, rr), change(rs, mid + 1, r, ll, rr);
    sum[k] = (sum[ls] + sum[rs]) % mod;
}
il int query(int k, int l, int r, int ll, int rr) {
    if(l > rr || ll > r) return 0;
    if(l >= ll && r <= rr) return sum[k];
    int mid = (l + r) >> 1;
    pushdown(k);
    return (query(ls, l, mid, ll, rr) + query(rs, mid + 1, r, ll, rr)) % mod;
}
il int query(int u) {
    int ans = 0;
    while(top[u]) ans = (ans + query(1, 1, n, seg[top[u]], seg[u])) % mod, u = fa[top[u]];
    return ans;
}
il void change(int u) {
    while(top[u]) change(1, 1, n, seg[top[u]], seg[u]), u = fa[top[u]];
}
signed main() {
    n = read(), m = read(), k = read() % (mod - 1), now = 1;
    rep(i, 1, n) mi[i] = qpow(i, k);
    rep(i, 2, n) fa[i] = read(), add(fa[i], i);
    rep(i, 1, m) q[i].id = i, q[i].u = read(), q[i].z = read();
    sort(q + 1, q + m + 1), dfs1(1), dfs2(1, 1), build(1, 1, n);
    rep(i, 1, n) {
        change(i);
        while(i == q[now].u) ans[q[now].id] = query(q[now].z), ++ now;
    }
    rep(i, 1, m) printf("%lld\n", ans[i]);
    return 0;
}
相關文章
相關標籤/搜索