題目連接ios
定義一次點分治的複雜度是全部分治中心分治時的子樹大小之和。c++
給定一棵樹,問全部點等機率被選作重心,點分治的指望複雜度。git
根據指望的線性性,答案等價於每一個點在點分樹上的深度指望之和。數組
思路是從點對的角度考慮某一個點是否會產生貢獻。
\[ E(depth[x])=\sum_{y=1}^n P(x\in subtree[y]) \]ide
也就是 \(x\) 在點分樹上在 \(1\dots n\) 的子樹中的機率和。spa
考慮點分樹上 \(y\) 是 \(x\) 的祖先的條件,要求 \(x\) 和 \(y\) 構成的這條鏈上第一個在點分治過程當中被刪除的點是 \(y\) ,因爲鏈上被選中的機率相等,所以這個機率爲 \(\frac{1}{dist(x,y) + 1}\)。code
因此答案爲
\[ \sum_{x=1}^n\sum_{j=1}^n \frac{1}{dis(i,j) + 1}=\sum_{len = 0}^n \frac{cnt[i]}{i + 1} \]排序
所以須要點分治求長度爲 \(i\) 的路徑條數 \(cnt[i]\) ,注意到合併的時候是卷積的形式。ip
不考慮重複路徑,把子樹 dfs 一遍,直接本身進行卷積,再去掉子樹內重複計數的路徑便可。get
每一層最差以本身的 \(size\) 做爲長度進行卷積,所以複雜度爲 \(\mathcal O(n\log^2 n)\)
#include <cmath> #include <cstdio> #include <cctype> #include <cstdlib> #include <cstring> #include <iostream> #include <algorithm> #define N 65537 #define mod 998244353 using namespace std; typedef long long ll; inline int rd() { int x = 0; char c = getchar(); while (!isdigit(c)) c = getchar(); while (isdigit(c)) { x = x * 10 + (c ^ 48); c = getchar(); } return x; } inline void print(ll x) { int y = 10, len = 1; while(y <= x) {y *= 10; ++len;} while(len--) {y /= 10; putchar(x / y + 48); x %= y;} putchar('\n'); } inline int fpow(int x, int t = mod - 2) { int res = 1; while (t) { if (t & 1) res = 1ll * res * x % mod; x = 1ll * x * x % mod; t >>= 1; } return res; } int mxlen = (1 << 16), w[2][N], rev[N]; inline int mo(int x) { return x >= mod ? x - mod : x; } inline void init() { int per = fpow(3, (mod - 1) / mxlen); int invper = fpow(per); w[0][0] = w[1][0] = 1; for (int i = 1; i < mxlen; ++i) { w[0][i] = 1ll * w[0][i - 1] * per % mod; w[1][i] = 1ll * w[1][i - 1] * invper % mod; } } inline int Rev(int n) { int len = 1, bit = 0; while (len <= n) len <<= 1, ++bit; for (int i = 0; i < len; ++i) rev[i] = ((rev[i >> 1] >> 1) | ((i & 1) << (bit - 1))); return len; } inline void NTT(int *f, int len, int o) { for (int i = 0; i < len; ++i) if (i > rev[i]) swap(f[i], f[rev[i]]); for (int i = 1; i < len; i <<= 1) { int wn = mxlen / (i << 1); for (int j = 0; j < len; j += (i << 1)) { int nw = 0, x, y; for (int k = 0; k < i; ++k, nw += wn) { x = f[j + k]; y = 1ll * w[o][nw] * f[i + j + k] % mod; f[j + k] = mo(x + y); f[i + j + k] = mo(x - y + mod); } } } if (o == 1) { int invl = fpow(len); for (int i = 0; i < len; ++i) f[i] = 1ll * f[i] * invl % mod; } } bool vis[N]; int n, m, tot, totn, mx, rt, mxd; int bkt[N], cnt[N], sz[N], hd[N]; struct edge{int to, nxt;} e[N << 1]; inline void add(int u, int v) { e[++tot].to = v; e[tot].nxt = hd[u]; hd[u] = tot; e[++tot].to = u; e[tot].nxt = hd[v]; hd[v] = tot; } void getrt(int u, int fa) { sz[u] = 1; int mxs = 0; for (int i = hd[u], v; i; i = e[i].nxt) if ((v = e[i].to) != fa && !vis[v]) { getrt(v, u); sz[u] += sz[v]; mxs = max(mxs, sz[v]); } mxs = max(mxs, totn - sz[u]); if (mxs < mx) {mx = mxs; rt = u;} } void getsz(int u, int fa) { sz[u] = 1; for (int i = hd[u], v; i; i = e[i].nxt) if ((v = e[i].to) != fa && !vis[v]) { getsz(v, u); sz[u] += sz[v]; } } void dfs(int u, int fa, int dep) { ++bkt[dep]; mxd = max(mxd, dep); for (int i = hd[u], v; i; i = e[i].nxt) if ((v = e[i].to) != fa && !vis[v]) dfs(v, u, dep + 1); } inline void mul(int *a, int len, int o) { len = Rev(len << 1); NTT(a, len, 0); for (int i = 0; i < len; ++i) a[i] = 1ll * a[i] * a[i] % mod; NTT(a, len, 1); if (o > 0) for (int i = 0; i < len; ++i) cnt[i + 1] += a[i]; else for (int i = 0; i < len; ++i) cnt[i + 3] -= a[i]; for (int i = 0; i < len; ++i) a[i] = 0; } inline void calc(int u, int o) { mxd = 0; dfs(u, 0, 0); mul(bkt, mxd, o); } void divide(int u) { vis[u] = 1; calc(u, 1); for (int i = hd[u], v; i; i = e[i].nxt) if (!vis[v = e[i].to]) { calc(v, -1); getsz(v, u); totn = mx = sz[v]; rt = v; getrt(v, 0); divide(rt); } } int main() { init(); n = rd(); for (int i = 1; i < n; ++i) add(rd() + 1, rd() + 1); mx = totn = n; getrt(1, 0); divide(rt); double ans = 0.0; for (int i = 1; i <= n + 1; ++i) ans += (double) cnt[i] / i; printf("%.4lf", ans); return 0; }
在點分治求路徑條數時,咱們嘗試用按秩合併的思路去搞,也就是將子樹按照最深深度排序,而後逐個合併計算答案。
開始的時候只有 \(bkt[0]=1\),而後按順序卷每個子樹求出來的計數數組 \(bktson\) 。
把貢獻直接計算,而後再將 \(bktson\) 按位加到 \(bkt\) 上。
考慮複雜度,將子樹按照深度從小到大排序後,每次卷積獲得的新的鏈長不會超過新合併的子樹深度的二倍,因此每次卷積的數組長度爲 \(mxdep[v]\) 的,且每一個位置只會和其父節點卷積一次,所以總複雜度爲 \(\mathcal O(n\log^2 n)\)
#include <cmath> #include <cstdio> #include <cctype> #include <cstdlib> #include <cstring> #include <iostream> #include <algorithm> #define N 65537 #define mod 998244353 using namespace std; typedef long long ll; inline int rd() { int x = 0; char c = getchar(); while (!isdigit(c)) c = getchar(); while (isdigit(c)) { x = x * 10 + (c ^ 48); c = getchar(); } return x; } inline void print(ll x) { int y = 10, len = 1; while(y <= x) {y *= 10; ++len;} while(len--) {y /= 10; putchar(x / y + 48); x %= y;} putchar('\n'); } inline int fpow(int x, int t = mod - 2) { int res = 1; while (t) { if (t & 1) res = 1ll * res * x % mod; x = 1ll * x * x % mod; t >>= 1; } return res; } int mxlen = (1 << 16), w[2][N], rev[N]; inline int mo(int x) { return x >= mod ? x - mod : x; } inline void init() { int per = fpow(3, (mod - 1) / mxlen); int invper = fpow(per); w[0][0] = w[1][0] = 1; for (int i = 1; i < mxlen; ++i) { w[0][i] = 1ll * w[0][i - 1] * per % mod; w[1][i] = 1ll * w[1][i - 1] * invper % mod; } } inline int Rev(int n) { int len = 1, bit = 0; while (len <= n) len <<= 1, ++bit; for (int i = 0; i < len; ++i) rev[i] = ((rev[i >> 1] >> 1) | ((i & 1) << (bit - 1))); return len; } inline void NTT(int *f, int len, int o) { for (int i = 0; i < len; ++i) if (i > rev[i]) swap(f[i], f[rev[i]]); for (int i = 1; i < len; i <<= 1) { int wn = mxlen / (i << 1); for (int j = 0; j < len; j += (i << 1)) { int nw = 0, x, y; for (int k = 0; k < i; ++k, nw += wn) { x = f[j + k]; y = 1ll * w[o][nw] * f[i + j + k] % mod; f[j + k] = mo(x + y); f[i + j + k] = mo(x - y + mod); } } } if (o == 1) { int invl = fpow(len); for (int i = 0; i < len; ++i) f[i] = 1ll * f[i] * invl % mod; } } bool vis[N]; double ans = 0.0; int n, m, tot, totn, mx, rt; int bkt[N], sz[N], hd[N]; struct edge{int to, nxt;} e[N << 1]; inline void add(int u, int v) { e[++tot].to = v; e[tot].nxt = hd[u]; hd[u] = tot; e[++tot].to = u; e[tot].nxt = hd[v]; hd[v] = tot; } void getrt(int u, int fa) { sz[u] = 1; int mxs = 0; for (int i = hd[u], v; i; i = e[i].nxt) if ((v = e[i].to) != fa && !vis[v]) { getrt(v, u); sz[u] += sz[v]; mxs = max(mxs, sz[v]); } mxs = max(mxs, totn - sz[u]); if (mxs < mx) {mx = mxs; rt = u;} } void getsz(int u, int fa) { sz[u] = 1; for (int i = hd[u], v; i; i = e[i].nxt) if ((v = e[i].to) != fa && !vis[v]) { getsz(v, u); sz[u] += sz[v]; } } int res[N], tmp[N]; inline int mul(int *a, int *b, int lena, int lenb) { int len = Rev(lenb << 1); for (int i = 0; i < lena; ++i) res[i] = a[i]; for (int i = lena; i < len; ++i) res[i] = 0; for (int i = 0; i < lenb; ++i) tmp[i] = b[i]; for (int i = lenb; i < len; ++i) tmp[i] = 0; NTT(res, len, 0); NTT(tmp, len, 0); for (int i = 0; i < len; ++i) res[i] = 1ll * res[i] * tmp[i] % mod; NTT(res, len, 1); for (int i = 0; i < len; ++i) ans += 2.0 * res[i] / (i + 1); return len; } int mxd[N], s[N], bkts[N]; inline bool cmp(int x, int y) {return mxd[x] < mxd[y];} int dfs(int u, int fa, int dep) { int resd = dep; for (int i = hd[u], v; i; i = e[i].nxt) if ((v = e[i].to) != fa && !vis[v]) resd = max(resd, dfs(v, u, dep + 1)); return resd; } void dfs2(int u, int fa, int dep) { ++bkts[dep]; for (int i = hd[u], v; i; i = e[i].nxt) if ((v = e[i].to) != fa && !vis[v]) dfs2(v, u, dep + 1); } void divide(int u) { vis[u] = 1; s[0] = 0; for (int i = hd[u], v; i; i = e[i].nxt) if (!vis[v = e[i].to]) { s[++s[0]] = v; mxd[v] = dfs(v, u, 1); } sort(s + 1, s + 1 + s[0], cmp); bkt[0] = 1; int nowlen = 1; for (int i = 1, v; i <= s[0]; ++i) { dfs2(v = s[i], 0, 1); nowlen = mul(bkt, bkts, nowlen, mxd[v] + 1); for (int i = 0; i <= mxd[v]; ++i) { bkt[i] += bkts[i]; bkts[i] = 0; } } for (int i = 0; i <= nowlen; ++i) bkt[i] = 0; for (int i = hd[u], v; i; i = e[i].nxt) if (!vis[v = e[i].to]) { getsz(v, u); totn = mx = sz[v]; rt = v; getrt(v, 0); divide(rt); } } int main() { init(); n = rd(); for (int i = 1; i < n; ++i) add(rd() + 1, rd() + 1); mx = totn = n; getrt(1, 0); divide(rt); printf("%.4lf", ans + n); return 0; }