[ BZOJ 3451 ] Normal

Description

題目連接ios

定義一次點分治的複雜度是全部分治中心分治時的子樹大小之和。c++

給定一棵樹,問全部點等機率被選作重心,點分治的指望複雜度。git

Solution

根據指望的線性性,答案等價於每一個點在點分樹上的深度指望之和。數組

思路是從點對的角度考慮某一個點是否會產生貢獻。
\[ E(depth[x])=\sum_{y=1}^n P(x\in subtree[y]) \]ide

也就是 \(x\) 在點分樹上在 \(1\dots n\) 的子樹中的機率和。spa

考慮點分樹上 \(y\)\(x\) 的祖先的條件,要求 \(x\)\(y\) 構成的這條鏈上第一個在點分治過程當中被刪除的點是 \(y\) ,因爲鏈上被選中的機率相等,所以這個機率爲 \(\frac{1}{dist(x,y) + 1}\)code

因此答案爲
\[ \sum_{x=1}^n\sum_{j=1}^n \frac{1}{dis(i,j) + 1}=\sum_{len = 0}^n \frac{cnt[i]}{i + 1} \]排序

所以須要點分治求長度爲 \(i\) 的路徑條數 \(cnt[i]\) ,注意到合併的時候是卷積的形式。ip

容斥作法

不考慮重複路徑,把子樹 dfs 一遍,直接本身進行卷積,再去掉子樹內重複計數的路徑便可。get

每一層最差以本身的 \(size\) 做爲長度進行卷積,所以複雜度爲 \(\mathcal O(n\log^2 n)\)

#include <cmath>
#include <cstdio>
#include <cctype>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <algorithm>
#define N 65537
#define mod 998244353
using namespace std;
typedef long long ll;
 
inline int rd() {
  int x = 0;
  char c = getchar();
  while (!isdigit(c)) c = getchar();
  while (isdigit(c)) {
    x = x * 10 + (c ^ 48); c = getchar();
  }
  return x;
}
 
inline void print(ll x) {
  int y = 10, len = 1;
  while(y <= x) {y *= 10; ++len;}
  while(len--) {y /= 10; putchar(x / y + 48); x %= y;}
  putchar('\n');
}
 
inline int fpow(int x, int t = mod - 2) {
  int res = 1;
  while (t) {
    if (t & 1) res = 1ll * res * x % mod;
    x = 1ll * x * x % mod; t >>= 1;
  }
  return res;
}
 
int mxlen = (1 << 16), w[2][N], rev[N];
 
inline int mo(int x) {
  return x >= mod ? x - mod : x;
}
 
inline void init() {
  int per = fpow(3, (mod - 1) / mxlen);
  int invper = fpow(per);
  w[0][0] = w[1][0] = 1;
  for (int i = 1; i < mxlen; ++i) {
    w[0][i] = 1ll * w[0][i - 1] * per % mod;
    w[1][i] = 1ll * w[1][i - 1] * invper % mod;
  }
}
 
inline int Rev(int n) {
  int len = 1, bit = 0;
  while (len <= n) len <<= 1, ++bit;
  for (int i = 0; i < len; ++i)
    rev[i] = ((rev[i >> 1] >> 1) | ((i & 1) << (bit - 1)));
  return len;
}
 
inline void NTT(int *f, int len, int o) {
  for (int i = 0; i < len; ++i)
    if (i > rev[i]) swap(f[i], f[rev[i]]);
  for (int i = 1; i < len; i <<= 1) {
    int wn = mxlen / (i << 1);
    for (int j = 0; j < len; j += (i << 1)) {
      int nw = 0, x, y;
      for (int k = 0; k < i; ++k, nw += wn) {
        x = f[j + k];
        y = 1ll * w[o][nw] * f[i + j + k] % mod;
        f[j + k] = mo(x + y);
        f[i + j + k] = mo(x - y + mod);
      }
    }
  }
  if (o == 1) {
    int invl = fpow(len);
    for (int i = 0; i < len; ++i) f[i] = 1ll * f[i] * invl % mod;
  }
}
 
bool vis[N];
 
int n, m, tot, totn, mx, rt, mxd;
 
int bkt[N], cnt[N], sz[N], hd[N];
 
struct edge{int to, nxt;} e[N << 1];
 
inline void add(int u, int v) {
  e[++tot].to = v; e[tot].nxt = hd[u]; hd[u] = tot;
  e[++tot].to = u; e[tot].nxt = hd[v]; hd[v] = tot;
}
 
void getrt(int u, int fa) {
  sz[u] = 1;
  int mxs = 0;
  for (int i = hd[u], v; i; i = e[i].nxt)
    if ((v = e[i].to) != fa && !vis[v]) {
      getrt(v, u);
      sz[u] += sz[v];
      mxs = max(mxs, sz[v]);
    }
  mxs = max(mxs, totn - sz[u]);
  if (mxs < mx) {mx = mxs; rt = u;}
}
 
void getsz(int u, int fa) {
  sz[u] =  1;
  for (int i = hd[u], v; i; i = e[i].nxt)
    if ((v = e[i].to) != fa && !vis[v]) {
      getsz(v, u); sz[u] += sz[v];
    }
}
 
void dfs(int u, int fa, int dep) {
  ++bkt[dep]; mxd = max(mxd, dep);
  for (int i = hd[u], v; i; i = e[i].nxt)
    if ((v = e[i].to) != fa && !vis[v]) dfs(v, u, dep + 1);
}
 
inline void mul(int *a, int len, int o) {
  len = Rev(len << 1);
  NTT(a, len, 0);
  for (int i = 0; i < len; ++i) a[i] = 1ll * a[i] * a[i] % mod;
  NTT(a, len, 1);
  if (o > 0) for (int i = 0; i < len; ++i) cnt[i + 1] += a[i];
  else for (int i = 0; i < len; ++i) cnt[i + 3] -= a[i];
  for (int i = 0; i < len; ++i) a[i] = 0;
}
 
inline void calc(int u, int o) {
  mxd = 0;
  dfs(u, 0, 0);
  mul(bkt, mxd, o);
}
 
void divide(int u) {
  vis[u] = 1;
  calc(u, 1);
  for (int i = hd[u], v; i; i = e[i].nxt)
    if (!vis[v = e[i].to]) {
      calc(v, -1);
      getsz(v, u);
      totn = mx = sz[v]; rt = v;
      getrt(v, 0); divide(rt);
    }
}
 
int main() {
  init();
  n = rd();
  for (int i = 1; i < n; ++i) add(rd() + 1, rd() + 1);
  mx = totn = n;
  getrt(1, 0); divide(rt);
  double ans = 0.0;
  for (int i = 1; i <= n + 1; ++i) ans += (double) cnt[i] / i;
  printf("%.4lf", ans);
  return 0;
}

子樹按秩合併作法

在點分治求路徑條數時,咱們嘗試用按秩合併的思路去搞,也就是將子樹按照最深深度排序,而後逐個合併計算答案。

開始的時候只有 \(bkt[0]=1\),而後按順序卷每個子樹求出來的計數數組 \(bktson\)

把貢獻直接計算,而後再將 \(bktson\) 按位加到 \(bkt\) 上。

考慮複雜度,將子樹按照深度從小到大排序後,每次卷積獲得的新的鏈長不會超過新合併的子樹深度的二倍,因此每次卷積的數組長度爲 \(mxdep[v]\) 的,且每一個位置只會和其父節點卷積一次,所以總複雜度爲 \(\mathcal O(n\log^2 n)\)

#include <cmath>
#include <cstdio>
#include <cctype>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <algorithm>
#define N 65537
#define mod 998244353
using namespace std;
typedef long long ll;
 
inline int rd() {
  int x = 0;
  char c = getchar();
  while (!isdigit(c)) c = getchar();
  while (isdigit(c)) {
    x = x * 10 + (c ^ 48); c = getchar();
  }
  return x;
}
 
inline void print(ll x) {
  int y = 10, len = 1;
  while(y <= x) {y *= 10; ++len;}
  while(len--) {y /= 10; putchar(x / y + 48); x %= y;}
  putchar('\n');
}
 
inline int fpow(int x, int t = mod - 2) {
  int res = 1;
  while (t) {
    if (t & 1) res = 1ll * res * x % mod;
    x = 1ll * x * x % mod; t >>= 1;
  }
  return res;
}
 
int mxlen = (1 << 16), w[2][N], rev[N];
 
inline int mo(int x) {
  return x >= mod ? x - mod : x;
}
 
inline void init() {
  int per = fpow(3, (mod - 1) / mxlen);
  int invper = fpow(per);
  w[0][0] = w[1][0] = 1;
  for (int i = 1; i < mxlen; ++i) {
    w[0][i] = 1ll * w[0][i - 1] * per % mod;
    w[1][i] = 1ll * w[1][i - 1] * invper % mod;
  }
}
 
inline int Rev(int n) {
  int len = 1, bit = 0;
  while (len <= n) len <<= 1, ++bit;
  for (int i = 0; i < len; ++i)
    rev[i] = ((rev[i >> 1] >> 1) | ((i & 1) << (bit - 1)));
  return len;
}
 
inline void NTT(int *f, int len, int o) {
  for (int i = 0; i < len; ++i)
    if (i > rev[i]) swap(f[i], f[rev[i]]);
  for (int i = 1; i < len; i <<= 1) {
    int wn = mxlen / (i << 1);
    for (int j = 0; j < len; j += (i << 1)) {
      int nw = 0, x, y;
      for (int k = 0; k < i; ++k, nw += wn) {
        x = f[j + k];
        y = 1ll * w[o][nw] * f[i + j + k] % mod;
        f[j + k] = mo(x + y);
        f[i + j + k] = mo(x - y + mod);
      }
    }
  }
  if (o == 1) {
    int invl = fpow(len);
    for (int i = 0; i < len; ++i) f[i] = 1ll * f[i] * invl % mod;
  }
}
 
bool vis[N];
 
double ans = 0.0;
 
int n, m, tot, totn, mx, rt;
 
int bkt[N], sz[N], hd[N];
 
struct edge{int to, nxt;} e[N << 1];
 
inline void add(int u, int v) {
  e[++tot].to = v; e[tot].nxt = hd[u]; hd[u] = tot;
  e[++tot].to = u; e[tot].nxt = hd[v]; hd[v] = tot;
}
 
void getrt(int u, int fa) {
  sz[u] = 1;
  int mxs = 0;
  for (int i = hd[u], v; i; i = e[i].nxt)
    if ((v = e[i].to) != fa && !vis[v]) {
      getrt(v, u);
      sz[u] += sz[v];
      mxs = max(mxs, sz[v]);
    }
  mxs = max(mxs, totn - sz[u]);
  if (mxs < mx) {mx = mxs; rt = u;}
}
 
void getsz(int u, int fa) {
  sz[u] =  1;
  for (int i = hd[u], v; i; i = e[i].nxt)
    if ((v = e[i].to) != fa && !vis[v]) {
      getsz(v, u); sz[u] += sz[v];
    }
}
 
int res[N], tmp[N];
 
inline int mul(int *a, int *b, int lena, int lenb) {
  int len = Rev(lenb << 1);
  for (int i = 0; i < lena; ++i) res[i] = a[i];
  for (int i = lena; i < len; ++i) res[i] = 0;
  for (int i = 0; i < lenb; ++i) tmp[i] = b[i];
  for (int i = lenb; i < len; ++i) tmp[i] = 0;
  NTT(res, len, 0); NTT(tmp, len, 0);
  for (int i = 0; i < len; ++i) res[i] = 1ll * res[i] * tmp[i] % mod;
  NTT(res, len, 1);
  for (int i = 0; i < len; ++i) ans += 2.0 * res[i] / (i + 1);
  return len;
}
 
int mxd[N], s[N], bkts[N];
 
inline bool cmp(int x, int y) {return mxd[x] < mxd[y];}
 
int dfs(int u, int fa, int dep) {
  int resd = dep;
  for (int i = hd[u], v; i; i = e[i].nxt)
    if ((v = e[i].to) != fa && !vis[v]) resd = max(resd, dfs(v, u, dep + 1));
  return resd;
}
 
void dfs2(int u, int fa, int dep) {
  ++bkts[dep];
  for (int i = hd[u], v; i; i = e[i].nxt)
    if ((v = e[i].to) != fa && !vis[v]) dfs2(v, u, dep + 1);
}
 
void divide(int u) {
  vis[u] = 1;
  s[0] = 0;
  for (int i = hd[u], v; i; i = e[i].nxt)
    if (!vis[v = e[i].to]) {
      s[++s[0]] = v;
      mxd[v] = dfs(v, u, 1);
    }
  sort(s + 1, s + 1 + s[0], cmp);
  bkt[0] = 1;
  int nowlen = 1;
  for (int i = 1, v; i <= s[0]; ++i) {
    dfs2(v = s[i], 0, 1);
    nowlen = mul(bkt, bkts, nowlen, mxd[v] + 1);
    for (int i = 0; i <= mxd[v]; ++i) {
      bkt[i] += bkts[i]; bkts[i] = 0;
    }
  }
  for (int i = 0; i <= nowlen; ++i) bkt[i] = 0;
  for (int i = hd[u], v; i; i = e[i].nxt)
    if (!vis[v = e[i].to]) {
      getsz(v, u);
      totn = mx = sz[v]; rt = v;
      getrt(v, 0); divide(rt);
    }
}
 
int main() {
  init();
  n = rd();
  for (int i = 1; i < n; ++i) add(rd() + 1, rd() + 1);
  mx = totn = n;
  getrt(1, 0); divide(rt);
  printf("%.4lf", ans + n);
  return 0;
}
相關文章
相關標籤/搜索