首先能夠把題目轉化一下:把樹拆成若干條鏈,每條鏈的顏色爲其所在的樹的顏色,而後排放全部的鏈成環,求使得相鄰位置顏色不一樣的排列方案數。c++
而後本題分爲兩個部分:將一棵樹分爲1~n條不相交的鏈的方案數;將這些鏈安排順序使得不存在兩條相鄰的鏈來自同一棵樹。ide
第一部分顯然能夠O(n2)樹形DP,f[i][j][0/1/2]表示i及其子樹j條鏈,i向兒子連出0/1/2條邊的方案數,而後直接揹包DP便可。看似O(n3)的樹形揹包DP實際上是O(n2)的。證實複雜度:其實DP時只循環到sz[u]/sz[v]便可,而後能夠把每一個轉移視爲兒子v內子樹的每一個節點和節點u內v外節點組成的點對,因而所有DP完就是枚舉了全部的點對,複雜度顯然O(n2)。函數
第二部分,考慮n個點的樹劃分紅i條鏈的方案是f[i],若是不考慮環只考慮鏈其對應的指數生成函數爲Σf[i]i!(Σ(-1)i-jC(i-1,i-j)xj/j!),其中i∈[1,n],j∈[1,i]。拓展到環上,欽定一棵樹做爲開頭,若是該顏色有i條鏈,則被算了i次,而後其指數生成函數爲:Σf[i](i-1)!(Σ(-1)i-jC(i-1,i-j)xj-1/(j-1)!),其中i∈[1,n],j∈[1,i]。減去首尾同色後,生成函數是這樣的:Σf[i](i-1)!(Σ(-1)i-jC(i-1,i-j)xj-2/(j-2)!),其中i∈[2,n],j∈[2,i]。而後暴力卷積便可。spa
![](http://static.javashuo.com/static/loading.gif)
![](http://static.javashuo.com/static/loading.gif)
#include<bits/stdc++.h> using namespace std; const int N=5005,mod=998244353; int n,m,sum,ans,fac[N],inv[N],sz[N],f[N][N][3],g[N],tmp[N][3],dp[310][N],b[N]; vector<int>G[N]; int qpow(int a,int b) { int ret=1; while(b) { if(b&1)ret=1ll*ret*a%mod; a=1ll*a*a%mod,b>>=1; } return ret; } void dfs(int u,int fa) { sz[u]=1,f[u][1][0]=1; for(int i=0;i<G[u].size();i++) if(G[u][i]!=fa) { int v=G[u][i]; dfs(v,u); for(int j=0;j<=sz[u]+sz[v];j++)tmp[j][0]=tmp[j][1]=tmp[j][2]=0; for(int j=1;j<=sz[u];j++) for(int k=1;k<=sz[v];k++) { tmp[j+k][0]=(tmp[j+k][0]+1ll*f[u][j][0]*(f[v][k][0]+2ll*f[v][k][1]+2ll*f[v][k][2]))%mod; tmp[j+k-1][1]=(tmp[j+k-1][1]+1ll*f[u][j][0]*(f[v][k][0]+f[v][k][1]))%mod; tmp[j+k][1]=(tmp[j+k][1]+1ll*f[u][j][1]*(f[v][k][0]+2ll*f[v][k][1]+2ll*f[v][k][2]))%mod; tmp[j+k-1][2]=(tmp[j+k-1][2]+1ll*f[u][j][1]*(f[v][k][0]+f[v][k][1]))%mod; tmp[j+k][2]=(tmp[j+k][2]+1ll*f[u][j][2]*(f[v][k][0]+2ll*f[v][k][1]+2ll*f[v][k][2]))%mod; } sz[u]+=sz[v]; for(int j=1;j<=sz[u];j++)f[u][j][0]=tmp[j][0],f[u][j][1]=tmp[j][1],f[u][j][2]=tmp[j][2]; } } int C(int a,int b){return a<b?0:1ll*fac[a]*inv[b]%mod*inv[a-b]%mod;} int S(int a,int b){return (!a&&!b)?1:1ll*fac[a]*C(a-1,a-b)%mod;} int main() { fac[0]=1;for(int i=1;i<=5000;i++)fac[i]=1ll*fac[i-1]*i%mod; for(int i=0;i<=5000;i++)inv[i]=qpow(fac[i],mod-2); scanf("%d",&m); dp[0][0]=1; for(int p=1;p<=m;p++) { scanf("%d",&n); for(int i=1;i<=n;i++)G[i].clear(); for(int i=1,x,y;i<n;i++)scanf("%d%d",&x,&y),G[x].push_back(y),G[y].push_back(x); for(int i=1;i<=n;i++) for(int j=1;j<=n;j++) f[i][j][0]=f[i][j][1]=f[i][j][2]=0; dfs(1,0); memset(g,0,sizeof g); for(int i=1;i<=n;i++)g[i]=(f[1][i][0]+2ll*f[1][i][1]+2ll*f[1][i][2])%mod; if(p!=m) { memset(b,0,sizeof b); for(int j=1;j<=n;j++) if(g[j])for(int k=0,t=1;k<=j;k++,t=mod-t) b[j-k]=(b[j-k]+1ll*t*S(j,j-k)%mod*g[j])%mod; for(int i=0;i<=sum;i++) if(dp[p-1][i])for(int j=0;j<=n;j++) dp[p][i+j]=(dp[p][i+j]+1ll*C(i+j,j)*b[j]%mod*dp[p-1][i])%mod; } else{ memset(b,0,sizeof b); for(int j=1;j<=n;j++) if(g[j])for(int k=0,t=1;k<j;k++,t=mod-t) b[j-1-k]=(b[j-1-k]+1ll*t*S(j-1,j-k-1)%mod*g[j])%mod; for(int i=0;i<=sum;i++) if(dp[p-1][i])for(int j=0;j<=n;j++) ans=(ans+1ll*C(i-2+j,j)*b[j]%mod*dp[p-1][i])%mod; } sum+=n; } printf("%d",ans); }