有兩棵 \(n\) 個點的樹 \(T_1\) 和 \(T_2\)。spa
你要給每一個點一個權值嗎,要求每一個點的權值爲 \([1,y]\) 內的整數。code
對於一條同時出如今兩棵樹上的邊,這條邊的兩個端點的值相同。get
若 \(op=0\),則給你兩棵樹 \(T_1,T_2\),求方案數。it
若 \(op=1\),則給你一棵樹 \(T_1\),求對於全部 \(n^{n-2}\) 種 \(T_2\),方案數之和。io
若 \(op=2\),則求對於全部的 \(T_1,T_2\),求方案數之和。class
\(n\leq 100000\)map
新建一個圖 \(G\),把兩棵樹的公共邊加到 \(G\) 中。記 \(m\) 爲兩棵樹的公共邊數量。那麼答案就是 \(y^{n-m}\)。im
令 \(z=y^{-1}\),那麼答案就變成了 \(y^nz^m\)。也就是說,每有一條相同的邊,方案的貢獻就要 \(\times z\)。static
這個你們都會。di
\[ z^m=\sum_{i=0}^m\binom{m}{i}(z-1)^i \]
那麼能夠枚舉一個邊集 \(E\),計算有多少種生成樹包含 \(E\),而後把答案加上方案數 \(\times{(z-1)}^{\lvert E\rvert}\)。
記這 \(E\) 條邊造成了 \(m\) 個連通塊,這些連通塊的大小爲 \(a_1,a_2,\ldots,a_m\),那麼貢獻就是
\[ \begin{align} &{(z-1)}^{n-m}\sum_{\sum_{i=1}^md_i=2m-2}(m-2)!\prod_{i=1}^m\frac{a_i^{d_i}}{(d_i-1)!}\\ =&{(z-1)}^{n-m}n^{m-2}\prod_{i=1}^ma_i\\ \end{align} \]
\(\prod_{i=1}^ma_i\) 能夠當作是每一個連通塊內選一個點的方案數。這樣就能夠DP了。
時間複雜度:\(O(n)\)
枚舉兩棵樹的公共邊個數:
\[ \begin{align} s_n&=\sum_{i=1}^{n}{(z-1)}^{n-i}\sum_{\sum_{j=1}^ia_j=n}\frac{n!}{i!}(\prod_{j=1}^i\frac{a_j^{a_j-2}}{a_j!})(n^{i-2}\prod_{j=1}^ia_j)^2\\ &=\sum_{i=1}^{n}{(z-1)}^{n-i}\frac{n!n^{2i-4}}{i!}\sum_{\sum_{j=1}^ia_j=n}\prod_{j=1}^i\frac{a_j^{a_j}}{a_j!}\\ &=\sum_{i=1}^{n}{(z-1)}^{n-i}n^{2i-4}\sum_{\sum_{j=1}^ia_j=n}\prod_{j=1}^i\binom{(\sum_{k=1}^ja_k)-1}{a_j-1}{}a_j^{a_j}\\ \end{align} \]
記 \(f_l=\sum_{i=1}^{l}{(z-1)}^{-i}n^{2i}\sum_{\sum_{j=1}^ia_j=l}\prod_{j=1}^i\binom{(\sum_{k=1}^ja_k)-1}{a_j-1}{}a_j^{a_j}\)。
轉移時枚舉最後一塊的大小,有:
\[ f_i=\begin{cases} 1&,i=0\\ \sum_{j=1}^i\frac{(i-1)!n^2j^jf_{i-j}}{(i-j)!(j-1)!(z-1)}&,i>0 \end{cases} \]
直接DP是 \(O(n^2)\) 的。
記 \(g_i=\sum_{i\geq 1}\frac{n^2i^i}{(i-1)!(z-1)}\),\(F(x)\) 爲 \(f\) 的 EGF,\(G(x)\) 爲 \(g\) 的 OGF,那麼
\[ \begin{align} xF'(x)&=F(x)G(x)\\ \frac{F'(x)}{F(x)}&=\frac{G(x)}{x}\\ \ln F(x)&=\int \frac{G(x)}{x}\\ F(x)&=e^{\int \frac{G(x)}{x}} \end{align} \]
直接多項式 exp 就行了。
答案爲 \((z-1)^nn^{-4}f_n\)
時間複雜度:\(O(n\log n)\)
const ll p=998244353; ll fp(ll a,ll b) { ll s=1; for(;b;b>>=1,a=a*a%p) if(b&1) s=s*a%p; return s; } const int N=100010; int n,op; ll z,_z; ll ans; namespace solve0 { map<int,int> a[N]; void solve() { if(_z==1) { ans=1; return; } int x,y; for(int i=1;i<n;i++) { io::get(x); io::get(y); if(x>y) swap(x,y); a[x][y]++; } ans=1; for(int i=1;i<n;i++) { io::get(x); io::get(y); if(x>y) swap(x,y); if(a[x].count(y)) ans=ans*z%p; } } } namespace solve1 { vector<int> g[N]; ll f[N][2]; void dfs(int x,int fa) { f[x][0]=f[x][1]=1; for(auto v:g[x]) if(v!=fa) { dfs(v,x); ll s0=(f[x][0]*f[v][0]%p*z+f[x][0]*f[v][1]%p*n)%p; ll s1=(f[x][0]*f[v][1]%p*z+f[x][1]*f[v][0]%p*z+f[x][1]*f[v][1]%p*n)%p; f[x][0]=s0; f[x][1]=s1; } } void solve() { if(_z==1) { ans=fp(n,n-2); return; } int x,y; for(int i=1;i<n;i++) { io::get(x); io::get(y); g[x].push_back(y); g[y].push_back(x); } z--; dfs(1,0); ans=f[1][1]*fp(n,p-2)%p; } } namespace solve2 { const int N=270000; namespace ntt { const int W=262144; ll w[N]; int rev[N]; void init() { w[0]=1; ll s=fp(3,(p-1)/W); for(int i=1;i<W/2;i++) w[i]=w[i-1]*s%p; } void ntt(ll *a,int n,int t) { for(int i=1;i<n;i++) { rev[i]=(rev[i>>1]>>1)|(i&1?n>>1:0); if(rev[i]>i) swap(a[i],a[rev[i]]); } for(int i=2;i<=n;i<<=1) for(int j=0;j<n;j+=i) for(int k=0;k<i/2;k++) { ll u=a[j+k]; ll v=a[j+k+i/2]*w[W/i*k]; a[j+k]=(u+v)%p; a[j+k+i/2]=(u-v)%p; } if(t==-1) { reverse(a+1,a+n); ll inv=fp(n,p-2); for(int i=0;i<n;i++) a[i]=a[i]*inv%p; } } void mul(ll *a,ll *b,ll *c,int n,int m,int l) { static ll a1[N],a2[N]; int k=1; while(k<=n+m) k<<=1; memset(a1,0,sizeof(a1[0])*k); memset(a2,0,sizeof(a2[0])*k); memcpy(a1,a,sizeof(a1[0])*(n+1)); memcpy(a2,b,sizeof(a2[0])*(m+1)); ntt::ntt(a1,k,1); ntt::ntt(a2,k,1); for(int i=0;i<k;i++) a1[i]=a1[i]*a2[i]%p; ntt::ntt(a1,k,-1); memcpy(c,a1,sizeof(a1[0])*(l+1)); } void inv(ll *a,ll *b,int n) { if(n==1) { b[0]=fp(a[0],p-2); return; } inv(a,b,n>>1); static ll a1[N],a2[N]; memset(a1,0,sizeof(a1[0])*(n<<1)); memset(a2,0,sizeof(a2[0])*(n<<1)); memcpy(a1,a,sizeof(a1[0])*n); memcpy(a2,b,sizeof(a2[0])*(n>>1)); ntt(a1,n<<1,1); ntt(a2,n<<1,1); for(int i=0;i<n<<1;i++) a1[i]=a2[i]*(2-a1[i]*a2[i]%p)%p; ntt(a1,n<<1,-1); memcpy(b,a1,sizeof(a1[0])*n); } void ln(ll *a,ll *b,int n) { static ll a1[N],a2[N],a3[N]; for(int i=1;i<n;i++) a1[i-1]=a[i]*i%p; a1[n-1]=0; inv(a,a2,n); mul(a1,a2,a3,n-1,n-1,n-1); for(int i=1;i<n;i++) b[i]=a3[i-1]*fp(i,p-2)%p; b[0]=0; } void exp(ll *a,ll *b,int n) { if(n==1) { b[0]=1; return; } exp(a,b,n>>1); static ll a1[N],a2[N],a3[N]; memset(b+(n>>1),0,sizeof(b[0])*(n>>1)); ln(b,a3,n); memset(a1,0,sizeof(a1[0])*n); memset(a2,0,sizeof(a2[0])*n); memcpy(a1,b,sizeof(a1[0])*(n>>1)); for(int i=0;i<(n>>1);i++) a2[i]=a[(n>>1)+i]-a3[(n>>1)+i]; ntt(a1,n,1); ntt(a2,n,1); for(int i=0;i<n;i++) a1[i]=a1[i]*a2[i]%p; ntt(a1,n,-1); memcpy(b+(n>>1),a1,sizeof(a1[0])*(n>>1)); } } ll inv[N],fac[N],ifac[N]; ll f[N],g[N],w[N]; void solve() { if(_z==1) { ans=fp(n,n-2)*fp(n,n-2)%p; return; } z--; ntt::init(); fac[0]=fac[1]=ifac[0]=ifac[1]=inv[1]=1; for(int i=2;i<=n;i++) { fac[i]=fac[i-1]*i%p; inv[i]=-p/i*inv[p%i]%p; ifac[i]=ifac[i-1]*inv[i]%p; } ll ifacz=fp(z,p-2); // f[0]=1; // for(int i=1;i<=n;i++) // w[i]=fp(i,i); // for(int i=1;i<=n;i++) // for(int j=1;j<=i;j++) // f[i]=(f[i]+f[i-j]*fac[i-1]%p*ifac[i-j]%p*ifac[j-1]%p*n%p*n%p*w[j]%p*ifacz)%p; for(int i=1;i<=n;i++) g[i]=fp(i,i)*n%p*n%p*ifac[i-1]%p*ifacz%p*inv[i]%p; int k=1; while(k<=n) k<<=1; ntt::exp(g,f,k); ans=f[n]*fac[n]%p*fp(z,n)%p*fp(n,p-1-4)%p; } } int main() { freopen("tree.in","r",stdin); freopen("tree.out","w",stdout); io::get(n); io::get(_z); io::get(op); z=fp(_z,p-2); if(op==0) solve0::solve(); else if(op==1) solve1::solve(); else solve2::solve(); ans=ans*fp(_z,n)%p; ans=(ans%p+p)%p; io::put(ans); return 0; }