【HDU4661】Message Passing-思惟+樹形DP+組合數學

測試地址:Message Passing
題目大意: n n 我的,每一個人知道一條獨一無二的信息,每次能夠選擇一我的,向與他有關係的一我的傳遞全部他已經知道的信息,關係網是樹狀的,目標是讓全部人都知道全部的信息,問有多少種傳遞信息的方案,使得傳遞的次數最少。
作法: 本題須要用到思惟+樹形DP+組合數學。
首先,顯然傳遞次數的下限是 2 ( n 1 ) 2(n-1) ,那麼咱們能不能到達這個下限呢?若是能,怎麼到達?其實容易觀察出,咱們能夠先把全部信息都彙集在某一我的,而後再讓信息從這我的開始傳遞到全部人,這樣傳遞次數就能達到下界,能夠證實全部最優解都知足這樣的過程。
因而如今就是求以某我的 x x 爲中間的匯集點時,總的方案數是多少。咱們發現,若是以 x x 爲根,從兒子向父親連邊,那麼從各點傳遞信息到 x x 的方案數,就等於拓撲序的數量。同理,從父親向兒子連邊時拓撲序的數量,就等於從 x x 傳遞信息到各點的方案數。那麼根據乘法原理,整個過程的方案數只要把這兩個方案數乘起來便可。易知,一個圖和其反圖的拓撲序數量是相同的,因此咱們只要求從兒子向父親連邊時的方案數 a n s x ans_x 便可。
咱們先隨便選一個點爲根,假設是 1 1 ,那麼咱們能夠用樹形DP求出 a n s 1 ans_1 。具體來講,用 f ( i ) f(i) 表示以點 i i 爲根的子樹的方案數,那麼在合併方案時,點 i i 必定是最後選,而它的各子樹中又要知足各自的順序,因此方案數其實是: C s i z ( i ) 1 s i z ( s o n 1 ) C s i z ( i ) 1 s i z ( s o n 1 ) s i z ( s o n 2 ) . . . f ( s o n i ) C_{siz(i)-1}^{siz(son_1)}\cdot C_{siz(i)-1-siz(son_1)}^{siz(son_2)}\cdot...\cdot \prod f(son_i) ,其中 s i z ( x ) siz(x) 表示以 x x 爲根子樹中的點數,把組合數拆開簡化,這個式子能夠寫成:
f ( i ) = s i z ( i ) ! f ( s o n i ) s i z ( s o n i ) ! f(i)=siz(i)!\cdot \prod \frac{f(son_i)}{siz(son_i)!}
預處理階乘就能夠 O ( 1 ) O(1) 轉移了,這樣咱們就能 O ( n ) O(n) 地算出 a n s 1 ans_1 了。
但若是計算每一個 a n s x ans_x 都要 O ( n ) O(n) ,總的時間複雜度也是受不了的,所以咱們使用經典的方法——換根。仍是先以 1 1 爲整棵樹的根,令 f a ( x ) fa(x) x x 的父親,假設 a n s f a ( x ) ans_{fa(x)} 已經算出,如何計算 a n s x ans_x
要計算這個東西,首先當以 x x 爲根時,它在以 1 1 爲根的樹中的兒子如今依然是它的兒子,區別就是多了一棵子樹,這棵子樹的形態和在以 f a ( x ) fa(x) 爲根的樹中去掉以 x x 爲根的子樹相同。所以,咱們只要在 a n s f a ( x ) ans_{fa(x)} 中消掉以 x x 爲根子樹的貢獻,而後將這個值做爲正常子樹對 a n s x ans_x 轉移便可,由於要求某個 f ( i ) f(i) 的逆元,因此轉移是 O ( log M o d ) O(\log Mod) 的。這也是爲何上面的轉移式子要寫成那樣的緣由:更加容易看出應該消掉哪一個部分的貢獻。因而DP的複雜度就是 O ( n log M o d ) O(n\log Mod) 了,這個問題就解決了。
如下是本人代碼:php

#include <bits/stdc++.h> using namespace std; typedef long long ll; const ll mod=1000000007; int T,n,first[1000010],tot,fa[1000010]={0},siz[1000010]; ll fac[1000010],inv[1000010],invfac[1000010]; ll f[1000010],ans[1000010]; struct edge { int v,next; }e[2000010]; void insert(int a,int b) { e[++tot].v=b; e[tot].next=first[a]; first[a]=tot; } ll power(ll a,ll b) { ll s=1,ss=a; while(b) { if (b&1) s=s*ss%mod; ss=ss*ss%mod;b>>=1; } return s; } void dp1(int v) { siz[v]=f[v]=1; for(int i=first[v];i;i=e[i].next) if (e[i].v!=fa[v]) { fa[e[i].v]=v; dp1(e[i].v); siz[v]+=siz[e[i].v]; f[v]=f[v]*invfac[siz[e[i].v]]%mod*f[e[i].v]%mod; } f[v]=f[v]*fac[siz[v]-1]%mod; } void dp2(int v) { if (v!=1) { ll faf=ans[fa[v]]; int fasiz=n-siz[v]; faf=faf*fac[siz[v]]%mod*power(f[v],mod-2)%mod; faf=faf*invfac[n-1]%mod*fac[fasiz-1]%mod; ans[v]=f[v]*invfac[fasiz]%mod*faf%mod; ans[v]=ans[v]*invfac[siz[v]-1]%mod*fac[n-1]%mod; } else ans[v]=f[v]; for(int i=first[v];i;i=e[i].next) if (e[i].v!=fa[v]) dp2(e[i].v); } int main() { scanf("%d",&T); while(T--) { scanf("%d",&n); tot=0; for(int i=1;i<=n;i++) first[i]=0; fac[0]=fac[1]=inv[1]=invfac[0]=invfac[1]=1; for(ll i=2;i<=n;i++) { fac[i]=fac[i-1]*i%mod; inv[i]=(mod-mod/i)*inv[mod%i]%mod; invfac[i]=invfac[i-1]*inv[i]%mod; } for(int i=1;i<n;i++) { int a,b; scanf("%d%d",&a,&b); insert(a,b),insert(b,a); } dp1(1); dp2(1); ll totans=0; for(int i=1;i<=n;i++) totans=(totans+ans[i]*ans[i])%mod; printf("%lld\n",totans); } return 0; } #include <bits/stdc++.h> using namespace std; typedef long long ll; const ll mod=1000000007; int T,n,first[1000010],tot,fa[1000010]={0},siz[1000010]; ll fac[1000010],inv[1000010],invfac[1000010]; ll f[1000010],ans[1000010]; struct edge { int v,next; }e[2000010]; void insert(int a,int b) { e[++tot].v=b; e[tot].next=first[a]; first[a]=tot; } ll power(ll a,ll b) { ll s=1,ss=a; while(b) { if (b&1) s=s*ss%mod; ss=ss*ss%mod;b>>=1; } return s; } void dp1(int v) { siz[v]=f[v]=1; for(int i=first[v];i;i=e[i].next) if (e[i].v!=fa[v]) { fa[e[i].v]=v; dp1(e[i].v); siz[v]+=siz[e[i].v]; f[v]=f[v]*invfac[siz[e[i].v]]%mod*f[e[i].v]%mod; } f[v]=f[v]*fac[siz[v]-1]%mod; } void dp2(int v) { if (v!=1) { ll faf=ans[fa[v]]; int fasiz=n-siz[v]; faf=faf*fac[siz[v]]%mod*power(f[v],mod-2)%mod; faf=faf*invfac[n-1]%mod*fac[fasiz-1]%mod; ans[v]=f[v]*invfac[fasiz]%mod*faf%mod; ans[v]=ans[v]*invfac[siz[v]-1]%mod*fac[n-1]%mod; } else ans[v]=f[v]; for(int i=first[v];i;i=e[i].next) if (e[i].v!=fa[v]) dp2(e[i].v); } int main() { scanf("%d",&T); while(T--) { scanf("%d",&n); tot=0; for(int i=1;i<=n;i++) first[i]=0; fac[0]=fac[1]=inv[1]=invfac[0]=invfac[1]=1; for(ll i=2;i<=n;i++) { fac[i]=fac[i-1]*i%mod; inv[i]=(mod-mod/i)*inv[mod%i]%mod; invfac[i]=invfac[i-1]*inv[i]%mod; } for(int i=1;i<n;i++) { int a,b; scanf("%d%d",&a,&b); insert(a,b),insert(b,a); } dp1(1); dp2(1); ll totans=0; for(int i=1;i<=n;i++) totans=(totans+ans[i]*ans[i])%mod; printf("%lld\n",totans); } return 0; }
相關文章
相關標籤/搜索