【CF1218E】Product Tuples

題目大意:給定一個長度爲 \(N\) 的序列,求從序列中選出 \(K\) 個數的集合乘積之和是多少。ios

題解:
因爲是選出 \(K\) 個數字組成的集合,可知對於要計算的 \(K\) 元組來講是沒有標號的,而元組是由序列中 \(N\) 個數字組合而成的。所以,將要求的元組看做組合對象,該組合對象是由 \(N\) 個不一樣種類的組合對象組成的,且組合對象是沒有標號的,所以採用普通生成函數計算便可。
對於第 \(i\) 個數的普通生成函數爲 \[(1 + a_ix)\],所以,原組合對象的生成函數是\[\prod\limits_{i = 1}^{n}(1+a_ix)\]。能夠經過分治乘法來進行計算,時間複雜度爲 \(O(nlogn)\)c++

代碼以下函數

#include <bits/stdc++.h>

using namespace std;

typedef long long LL;

const int mod = 998244353, g = 3, ig = 332748118;

LL fpow(LL a, LL b, LL c) {
  LL ret = 1 % c;
  for (; b; b >>= 1, a = a * a % mod) if (b & 1) ret = ret * a % mod;
  return ret;
}

void ntt(vector<LL> &v, vector<int> &rev, int opt) {
  int tot = v.size();
  for (int i = 0; i < tot; i++) if (i < rev[i]) swap(v[i], v[rev[i]]);
  for (int mid = 1; mid < tot; mid <<= 1) {
    LL wn = fpow(opt == 1 ? g : ig, (mod - 1) / (mid << 1), mod);
    for (int j = 0; j < tot; j += mid << 1) {
      LL w = 1;
      for (int k = 0; k < mid; k++) {
        LL x = v[j + k], y = v[j + mid + k] * w % mod;
        v[j + k] = (x + y) % mod, v[j + mid + k] = (x - y + mod) % mod;
        w = w * wn % mod;
      }
    }
  }
  if (opt == -1) {
    LL itot = fpow(tot, mod - 2, mod);
    for (int i = 0; i < tot; i++) v[i] = v[i] * itot % mod;
  }
}

vector<LL> convolution(vector<LL> &a, int cnta, vector<LL> &b, int cntb, const function<LL(LL, LL)> &calc) {
  int bit = 0, tot = 1;
  while (tot <= 2 * max(cnta, cntb)) bit++, tot <<= 1;
  vector<int> rev(tot);
  for (int i = 0; i < tot; i++) rev[i] = rev[i >> 1] >> 1 | (i & 1) << (bit - 1);
  vector<LL> foo(tot), bar(tot);
  for (int i = 0; i < cnta; i++) foo[i] = a[i];
  for (int i = 0; i < cntb; i++) bar[i] = b[i];
  ntt(foo, rev, 1), ntt(bar, rev, 1);
  for (int i = 0; i < tot; i++) foo[i] = calc(foo[i], bar[i]);
  ntt(foo, rev, -1);
  return foo;
}

int main() {
  //freopen("data.in", "r", stdin);
  ios::sync_with_stdio(false);
  cin.tie(0), cout.tie(0);
  int n, K;
  cin >> n >> K;
  vector<LL> a(n);
  for (int i = 0; i < n; i++) {
    cin >> a[i];
  }
  int m;
  cin >> m;
  while (m--) {
    int opt;
    cin >> opt;
    vector<LL> b = a;
    if (opt == 1) {
      int q, x, y;
      cin >> q >> x >> y;
      x--;
      b[x] = y;
      for (int i = 0; i < n; i++) {
        b[i] = (q - b[i] + mod) % mod;
      }
    } else {
      int q, l, r, d;
      cin >> q >> l >> r >> d;
      l--, r--;
      for (int i = l; i <= r; i++) {
        b[i] = (b[i] + d) % mod;
      }
      for (int i = 0; i < n; i++) {
        b[i] = (q - b[i] + mod) % mod;
      }
    }
    function<vector<LL>(int, int)> solve = [&](int l, int r) {
      if (l == r) {
        return vector<LL> {1, b[l]};
      }
      int mid = l + r >> 1;
      vector<LL> ls = solve(l, mid);
      vector<LL> rs = solve(mid + 1, r);
      return convolution(ls, mid - l + 2, rs, r - mid + 1, [&](LL a, LL b) {
        return a * b % mod;
      });
    };
    vector<LL> ans = solve(0, n - 1);
    cout << ans[K] << endl;
  }
  return 0;
}
相關文章
相關標籤/搜索