#565. 「LibreOJ Round #10」mathematican 的二進制(指望 + 分治NTT)

題面

戳這裏,題意簡單易懂c++

題解

首先咱們發現,操做是能夠不考慮順序的,由於每次操做會加一個 \(1\) ,每次進位會減小一個 \(1\) ,咱們就能夠考慮最後 \(1\) 的個數(也就是最後的和),以及成功操做次數,就好了。git

而後根據指望的線性性,咱們能夠從低到高按位考慮貢獻。函數

考慮一個遞推:\(f(i, j)\) 表示從後往前第 \(i\) 位總共被改變 \(j\) 次的機率,那麼有兩種轉移:優化

  • 進位:\(\displaystyle f(i - 1, j) \to f(i, \lfloor \frac j 2 \rfloor)\)
  • 操做:對於第 \(i\) 位每一個機率爲 \(p\) 的操做, \(f(i, j - 1)p + f(i, j)(1 - p) \to f(i, j)\)

注意要先轉移第一個,再轉移第二個,由於第二個能夠對於第一個操做上來的數位進行貢獻。spa

而後這樣直接實現就是 \(O(m^2)\) 的,可是咱們能夠考慮優化,對於第二個容易觀察就是乘上了 \(P(x) = \prod_i p_ix + (1-p_i)\) 這個生成函數。debug

顯然這個咱們能夠利用 分治 \(NTT\) 來解決。code

進位的話咱們能夠暴力進位,由於每一個操做對於 分治 \(NTT\) 的貢獻能夠放縮成一個等比級數:\(\displaystyle \sum_{i = 0} ^ {\infty} 2^{-i} = 2\)get

因此最後時間複雜度就是 \(O(m \log^2 m)\) 的。it

總結

對於一些指望題,能夠考慮指望的線性性,以及試試操做順序是否不影響答案。io

而後考慮 \(NTT\) 優化機率生成函數就行啦。

代碼

#include <bits/stdc++.h>

#define For(i, l, r) for(register int i = (l), i##end = (int)(r); i <= i##end; ++i)
#define Fordown(i, r, l) for(register int i = (r), i##end = (int)(l); i >= i##end; --i)
#define Set(a, v) memset(a, v, sizeof(a))
#define Cpy(a, b) memcpy(a, b, sizeof(a))
#define debug(x) cout << #x << ": " << (x) << endl
#define DEBUG(...) fprintf(stderr, __VA_ARGS__)
#define pb push_back

using namespace std;

typedef long long ll;

template<typename T> inline bool chkmin(T &a, T b) {return b < a ? a = b, 1 : 0;}
template<typename T> inline bool chkmax(T &a, T b) {return b > a ? a = b, 1 : 0;}

inline int read() {
    int x(0), sgn(1); char ch(getchar());
    for (; !isdigit(ch); ch = getchar()) if (ch == '-') sgn = -1;
    for (; isdigit(ch); ch = getchar()) x = (x * 10) + (ch ^ 48);
    return x * sgn;
}

void File() {
#ifdef zjp_shadow
    freopen ("565.in", "r", stdin);
    freopen ("565.out", "w", stdout);
#endif
}

const int N = 200100, Mod = 998244353;

ll fpm(ll x, int power) {
    ll res = 1;
    for (; power; power >>= 1, (x *= x) %= Mod)
        if (power & 1) (res *= x) %= Mod;
    return res;
}

inline int Add(int a, int b) { return (a += b) >= Mod ? a - Mod : a; }

template<int Maxn>
struct Number_Theoretic_Transfrom {

    const int g = 3;

    ll powg[Maxn + 5], invpowg[Maxn + 5]; int rev[Maxn + 5];

    void NTT_Init() {
        for (int i = 2; i <= Maxn; i <<= 1)
            invpowg[i] = fpm((powg[i] = fpm(g, (Mod - 1) / i)), Mod - 2);
    }

    int len;
    void NTT(ll P[], int opt) {
        For (i, 0, len - 1) 
            if (i < rev[i]) swap(P[i], P[rev[i]]);
        for (int i = 2, p = 1; i <= len; p = i, i <<= 1) {
            ll Wi = opt == 1 ? powg[i] : invpowg[i];
            for (int j = 0; j < len; j += i) {
                ll x = 1;
                For (k, 0, p - 1) {
                    ll u = P[j + k], v = P[j + k + p] * x % Mod;
                    P[j + k] = Add(u, v);
                    P[j + k + p] = Add(u, Mod - v);
                    (x *= Wi) %= Mod;
                }
            }
        }
        if (!~opt) {
            ll invn = fpm(len, Mod - 2);
            For (i, 0, len - 1) (P[i] *= invn) %= Mod;
        }
    }

    ll A[Maxn + 5], B[Maxn + 5];
    void Mult(int *a, int *b, int *c, int lena, int lenb) {
        int cnt = 0;
        for (len = 1; len <= lena + lenb; len <<= 1) ++ cnt;
        For (i, 0, len - 1) 
            rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (cnt - 1));
        For (i, 0, len - 1)
            A[i] = i <= lena ? a[i] : 0, B[i] = i <= lenb ? b[i] : 0;
        NTT(A, 1); NTT(B, 1); For (i, 0, len - 1) (A[i] *= B[i]) %= Mod; NTT(A, -1);
        For (i, 0, lena + lenb) c[i] = A[i];
    }

};

Number_Theoretic_Transfrom<1 << 20> NTT;

int n, m;

int pool[N << 5], *ptr = pool;

struct Poly {

    int *a, len;

    Poly(int l) { a = ptr; ptr += (len = l) + 1; }

    void Out() {
        debug(len);
        For (i, 1, n) printf ("%d%c", a[i], i == iend ? '\n' : ' ');
    }

};

inline bool operator < (const Poly &lhs, const Poly &rhs) {
    return lhs.len > rhs.len;
}

vector<int> Base[N];

Poly I(0);
Poly Calc(int id) {
    if (!(bool)Base[id].size()) return I;

    priority_queue<Poly> P;
    for (int prob : Base[id]) {
        Poly cur(1);
        cur.a[1] = prob;
        cur.a[0] = Mod + 1 - prob;
        P.push(cur);
    }

    while (P.size() > 1) {
        Poly a = P.top(); P.pop();
        Poly b = P.top(); P.pop();
        Poly c(a.len + b.len);
        NTT.Mult(a.a, b.a, c.a, a.len, b.len);
        P.push(c);
    }

    return P.top();
}

int main () {

    File();

    NTT.NTT_Init();

    n = read(); m = read();
    For (i, 1, m) {
        int pos = read(), x = read(), y = read();
        Base[pos].pb(1ll * x * fpm(y, Mod - 2) % Mod);
    }

    ll ans = 0;

    Poly res(0); res.a[0] = 1;
    For (i, 0, n + 20) {
        Poly tmp = Calc(i);
        if (tmp.len) {
            NTT.Mult(res.a, tmp.a, res.a, res.len, tmp.len);
            res.len = res.len + tmp.len;
        }

        For (j, 0, res.len) {
            ans = (ans + 1ll * res.a[j] * j) % Mod;
            int temp = res.a[j]; res.a[j] = 0;
            (res.a[j >> 1] += temp) %= Mod;
        }
        res.len >>= 1;
    }

    printf ("%lld\n", ans);

    return 0;

}
相關文章
相關標籤/搜索