淺談線段樹中加與乘標記的下放

感受題解裏面對加和乘標記下放的順序講的不是很清楚,要麼是直接沒說,要麼是一句話帶過node

若是想看1080P高清無碼證實的能夠報洛谷冬令營省選班,去看第一天的回放233ui


假設咱們一個節點爲\([val,mul,add]\),其中\(val\)表明該節點的權值,\(mul\)爲乘法標記,\(add\)爲加法標記spa

那麼咱們有兩種表示方式,

code

  • 第一種:先加再乘

此時該節點爲\((val+add)*mul\)get

當再遇到一個\([\_mul,\_add]\)的標記時,it

此時節點爲\([(val+add)*mul+\_add]*\_mul\)io

把式子展開並從新化爲\((val+add')*mul'\)的形式
(也就是提出\(mul*\_mul\)這一項)得class

\((val+add+\frac{\_add}{mul})*mul*\_mul\)date

咱們發現這裏有個除法,會損失不少精度di

所以咱們換一個思路


  • 第二種:先乘再加

此時該節點爲\((val*mul)+add\)

當再遇到一個\([\_mul,\_add]\)的標記時,

此時節點爲\([(val*mul)+add]*\_mul+\_add\)

把式子展開並從新化爲\((val*mul')+add'\)的形式

\(val*mul*\_mul+add*\_mul+\_add\)

咱們發現這樣不須要除法,所以咱們選用第二種




其實線段樹標記的下放通常都是這個套路

建議你們作完這道題後再去作一下這道題


放一下醜陋的代碼

#include<cstdio>
#include<cmath>
#include<algorithm>
#define ls k<<1
#define rs k<<1|1
#define int long long
using namespace std;
const int MAXN = 1e6 + 10;
inline int read() {
    char c = getchar(); int x = 0, f = 1;
    while (c < '0' || c > '9') {if (c == '-')f = -1; c = getchar();}
    while (c >= '0' && c <= '9') {x = x * 10 + c - '0'; c = getchar();}
    return x * f;
}
int N, M, mod;
struct node {
    int mul, add, sum, l, r, siz;
} T[MAXN];
void update(int k) {
    T[k].sum = (T[ls].sum % mod + T[rs].sum % mod) % mod;
}
void ps(int x, int f) {
    T[x].mul = (T[x].mul % mod * T[f].mul % mod) % mod;
    T[x].add = (T[x].add * T[f].mul) % mod;
    T[x].add = (T[x].add + T[f].add) % mod;
    T[x].sum = (T[x].sum % mod * T[f].mul % mod) % mod;
    T[x].sum = (T[x].sum + T[f].add % mod * T[x].siz) % mod;
}
void pushdown(int k) {
    if (T[k].add == 0 && T[k].mul == 1) return ;
    ps(ls, k);
    ps(rs, k);
    T[k].add = 0;
    T[k].mul = 1;
}
void Build(int k, int ll, int rr) {
    T[k].l = ll; T[k].r = rr; T[k].siz = rr - ll + 1; T[k].mul = 1;
    if (ll == rr) {
        T[k].sum = read() % mod;
        return ;
    }
    int mid = ll + rr >> 1;
    Build(ls, ll, mid);
    Build(rs, mid + 1, rr);
    update(k);
}
void IntervalMul(int k, int ll, int rr, int val) {
    if (ll <= T[k].l && T[k].r <= rr) {
        T[k].sum = (T[k].sum * val) % mod;
        T[k].mul = (T[k].mul * val) % mod;
        T[k].add = (T[k].add * val) % mod;
        return ;
    }
    pushdown(k);
    int mid = T[k].l + T[k].r >> 1;
    if (ll <= mid) IntervalMul(ls, ll, rr, val);
    if (rr > mid)  IntervalMul(rs, ll, rr, val);
    update(k);
}
void IntervalAdd(int k, int ll, int rr, int val) {
    if (ll <= T[k].l && T[k].r <= rr) {
        T[k].sum = (T[k].sum + T[k].siz * val) % mod;
        T[k].add = (T[k].add + val) % mod;
        return ;
    }
    pushdown(k);
    int mid = T[k].l + T[k].r >> 1;
    if (ll <= mid) IntervalAdd(ls, ll, rr, val);
    if (rr > mid)  IntervalAdd(rs, ll, rr, val);
    update(k);
}
int IntervalSum(int k, int ll, int rr) {
    int ans = 0;
    if (ll <= T[k].l && T[k].r <= rr) {
        ans = (ans + T[k].sum) % mod;
        return ans;
    }
    pushdown(k);
    int mid = T[k].l + T[k].r >> 1;
    if (ll <= mid) ans = (ans + IntervalSum(ls, ll, rr)) % mod;
    if (rr > mid)  ans = (ans + IntervalSum(rs, ll, rr)) % mod;
    return ans % mod;
}
main() {
#ifdef WIN32
    freopen("a.in", "r", stdin);
#endif
    N = read(); M = read(); mod = read();
    Build(1, 1, N);
    while (M--) {
        int opt = read();
        if (opt == 1) {
            int l = read(), r = read(), val = read() % mod;
            IntervalMul(1, l, r, val);
        } else if (opt == 2) {
            int l = read(), r = read(), val = read() % mod;
            IntervalAdd(1, l, r, val);
        } else if (opt == 3) {
            int l = read(), r = read();
            printf("%lld\n", IntervalSum(1, l, r) % mod);
        }
    }
    return 0;
}
相關文章
相關標籤/搜索