題目來源:NOI2019模擬測試賽(七)ios
非原題面,題意有略微區別ide
心態崩了。測試
好不容易場上想出一題正解,寫了三個小時結果寫了個假的點分治,卡成$O(n^2)$spa
我退役吧。.net
原題是求隨機樹分治的指望深度和,題意相同。code
對於一個點$x$,考慮點$y$是否能做爲它在點分樹上的祖先節點,顯然當且僅當$y$在$x$到$y$的路徑中第一個被選爲分治中心時會對$x$產生1的貢獻;blog
因爲路徑上全部點被選到的機率都是相等的,因此此時的指望就是$\frac{1}{dis(x,y)}$;get
那麼總的指望就是$\sum\limits_{x=1}^{n}\sum\limits_{y=1}^{n}\frac{1}{dis(x,y)}$;博客
在這裏寫個暴力便可爆踩個人假點分治;string
考慮統計每種長度的路徑條數,能夠用點分治作,而且在點分樹裏合併時子樹的指望是一個卷積的形式,所以能夠用FFT來加速;
因而我就快樂的寫了個點分治+FFT,得到了60分的好成績;
爲何?參考這篇博客的證實,我最初的寫法就是其中的第一種寫法,搜完一個子樹就和已經搜過的合併,這樣作的話FFT的長度會是$子樹中最大深度\times 根節點兒子個數=O(n^2)$的,正確的寫法應該搜完再一塊兒合併,或者像裏面說的第二種方法同樣直接搜當前子樹,更新答案而後搜重心的每一個兒子的子樹,減去不合法的路徑,這樣子FFT的長度纔是$O(n)$的。
1 #include<algorithm>
2 #include<iostream>
3 #include<cstring>
4 #include<cstdio>
5 #include<cmath>
6 #include<queue>
7 #define inf 2147483647
8 #define eps 1e-9
9 #define mod 1000000007
10 using namespace std; 11 typedef long long ll; 12 typedef double db; 13 const db pi=acos(-1.0); 14
15 struct edge{ 16 int v,next; 17 }a[200001]; 18 int n,u,v,S,rt,mxd,bit,bitnum,tot=0,cnt=0,ans=0,jc[100001],inv[100001],anss[200001],tp[200001],num[200001],s[200001],rev[200001],head[100001],mx[100001],siz[100001],dep[100001]; 19 bool used[100001]; 20 struct cp{ 21 db a,b; 22 cp(){} 23 cp(db _a,db _b){ 24 a=_a,b=_b; 25 } 26 friend cp operator +(cp a,cp b){return cp(a.a+b.a,a.b+b.b);} 27 friend cp operator -(cp a,cp b){return cp(a.a-b.a,a.b-b.b);} 28 friend cp operator *(cp a,cp b){return cp(a.a*b.a-a.b*b.b,a.a*b.b+a.b*b.a);} 29 friend cp operator *(cp a,db b){return cp(a.a*b,a.b*b);} 30 friend cp operator /(cp a,db b){return cp(a.a/b,a.b/b);} 31 }A[200001],B[200001],W[200001][2]; 32 void _(){ 33 for(int i=1;i<=(1<<17);i<<=1){ 34 W[i][0]=cp(cos(pi/i),sin(pi/i)); 35 W[i][1]=cp(cos(pi/i),-sin(pi/i)); 36 } 37 } 38 void fft(cp *s,int op){ 39 for(int i=0;i<bit;i++)if(i<rev[i])swap(s[i],s[rev[i]]); 40 for(int i=1;i<bit;i<<=1){ 41 //cp w(cos(pi/i),op*sin(pi/i));
42 cp w=W[i][op==-1]; 43 for(int p=i<<1,j=0;j<bit;j+=p){ 44 cp wk(1,0); 45 for(int k=j;k<i+j;k++,wk=wk*w){ 46 cp x=s[k],y=wk*s[k+i]; 47 s[k]=x+y; 48 s[k+i]=x-y; 49 } 50 } 51 } 52 if(op==-1){ 53 for(int i=0;i<bit;i++){ 54 s[i]=s[i]/(db)bit; 55 } 56 } 57 } 58 void add(int u,int v){ 59 a[++tot].v=v; 60 a[tot].next=head[u]; 61 head[u]=tot; 62 } 63 void mul(int *ret,int *a,int *b,int n){ 64 for(bit=1,bitnum=0;bit<=n*2;bit<<=1)bitnum++; 65 for(int i=1;i<=bit;i++){ 66 rev[i]=(rev[i>>1]>>1)|((i&1)<<(bitnum-1)); 67 } 68 for(int i=0;i<bit;i++){ 69 A[i]=cp((db)a[i],0); 70 B[i]=cp(0,0); 71 } 72 for(int i=1;i<=cnt;i++){ 73 a[b[i]]++; 74 B[b[i]].a+=1; 75 } 76 fft(A,1); 77 fft(B,1); 78 for(int i=0;i<bit;i++)A[i]=A[i]*B[i]; 79 fft(A,-1); 80 for(int i=0;i<bit;i++)ret[i]=(int)(A[i].a+0.5); 81 } 82 void getrt(int u,int fa){ 83 mx[u]=0; 84 siz[u]=1; 85 for(int tmp=head[u];tmp!=-1;tmp=a[tmp].next){ 86 int v=a[tmp].v; 87 if(!used[v]&&v!=fa){ 88 getrt(v,u); 89 siz[u]+=siz[v]; 90 mx[u]=max(mx[u],siz[v]); 91 } 92 } 93 mx[u]=max(mx[u],S-mx[u]); 94 if(mx[u]<mx[rt])rt=u; 95 } 96 void getdep(int u,int fa,int dpt){ 97 mxd=max(mxd,dpt); 98 s[++cnt]=dpt; 99 for(int tmp=head[u];tmp!=-1;tmp=a[tmp].next){ 100 int v=a[tmp].v; 101 if(!used[v]&&v!=fa){ 102 getdep(v,u,dpt+1); 103 } 104 } 105 } 106 void divide(int u){ 107 used[u]=true; 108 mxd=0; 109 for(int tmp=head[u];tmp!=-1;tmp=a[tmp].next){ 110 int v=a[tmp].v; 111 if(!used[v]){ 112 cnt=0; 113 getdep(v,u,1); 114 mul(tp,num,s,mxd); 115 for(int i=0;i<bit;i++)anss[i]+=tp[i]; 116 } 117 } 118 for(int i=1;i<=mxd;i++){ 119 anss[i]+=num[i]; 120 num[i]=0; 121 } 122 for(int tmp=head[u];tmp!=-1;tmp=a[tmp].next){ 123 int v=a[tmp].v; 124 if(!used[v]){ 125 S=siz[v]; 126 rt=0; 127 getrt(v,0); 128 divide(rt); 129 } 130 } 131 } 132 int main(){ 133 memset(head,-1,sizeof(head)); 134 _(); 135 scanf("%d",&n); 136 jc[0]=inv[0]=inv[1]=1; 137 for(int i=2;i<=n+1;i++)inv[i]=(ll)(mod-mod/i)*inv[mod%i]%mod; 138 for(int i=1;i<=n+1;i++)jc[i]=(ll)jc[i-1]*i%mod; 139 for(int i=1;i<n;i++){ 140 scanf("%d%d",&u,&v); 141 add(u,v); 142 add(v,u); 143 } 144 S=n; 145 mx[rt=0]=6666666; 146 getrt(1,-1); 147 divide(rt); 148 ans=n; 149 for(int i=1;i<=n;i++){ 150 ans=(ans+(ll)anss[i]*inv[i+1]*2%mod)%mod; 151 } 152 printf("%lld",(ll)ans*jc[n]%mod); 153 return 0; 154 }
1 #include<algorithm>
2 #include<iostream>
3 #include<cstring>
4 #include<cstdio>
5 #include<cmath>
6 #include<queue>
7 #define inf 2147483647
8 #define eps 1e-9
9 #define mod 1000000007
10 using namespace std; 11 typedef long long ll; 12 typedef double db; 13 const db pi=acos(-1.0); 14
15 struct edge{ 16 int v,next; 17 }a[200001]; 18 int n,u,v,S,rt,mxd,bit,bitnum,tot=0,cnt=0,ans=0,jc[100001],inv[100001],anss[200001],tp[200001],num[200001],rev[200001],head[100001],mx[100001],siz[100001],dep[100001],dps[100001]; 19 bool used[100001]; 20 struct cp{ 21 db a,b; 22 cp(){} 23 cp(db _a,db _b){ 24 a=_a,b=_b; 25 } 26 friend cp operator +(cp a,cp b){return cp(a.a+b.a,a.b+b.b);} 27 friend cp operator -(cp a,cp b){return cp(a.a-b.a,a.b-b.b);} 28 friend cp operator *(cp a,cp b){return cp(a.a*b.a-a.b*b.b,a.a*b.b+a.b*b.a);} 29 friend cp operator *(cp a,db b){return cp(a.a*b,a.b*b);} 30 friend cp operator /(cp a,db b){return cp(a.a/b,a.b/b);} 31 }A[200001],B[200001],W[200001][2]; 32 void _(){ 33 for(int i=1;i<=(1<<17);i<<=1){ 34 W[i][0]=cp(cos(pi/i),sin(pi/i)); 35 W[i][1]=cp(cos(pi/i),-sin(pi/i)); 36 } 37 } 38 void fft(cp *s,int op){ 39 for(int i=0;i<bit;i++)if(i<rev[i])swap(s[i],s[rev[i]]); 40 for(int i=1;i<bit;i<<=1){ 41 //cp w(cos(pi/i),op*sin(pi/i));
42 cp w=W[i][op==-1]; 43 for(int p=i<<1,j=0;j<bit;j+=p){ 44 cp wk(1,0); 45 for(int k=j;k<i+j;k++,wk=wk*w){ 46 cp x=s[k],y=wk*s[k+i]; 47 s[k]=x+y; 48 s[k+i]=x-y; 49 } 50 } 51 } 52 if(op==-1){ 53 for(int i=0;i<bit;i++){ 54 s[i]=s[i]/(db)bit; 55 } 56 } 57 } 58 void add(int u,int v){ 59 a[++tot].v=v; 60 a[tot].next=head[u]; 61 head[u]=tot; 62 } 63 void mul(int *ret,int *a,int *b,int n){ 64 for(bit=1,bitnum=0;bit<=n*2;bit<<=1)bitnum++; 65 for(int i=1;i<bit;i++){ 66 rev[i]=(rev[i>>1]>>1)|((i&1)<<(bitnum-1)); 67 } 68 for(int i=0;i<bit;i++){ 69 A[i]=cp((db)a[i],0); 70 B[i]=cp((db)b[i],0); 71 } 72 fft(A,1); 73 fft(B,1); 74 for(int i=0;i<bit;i++)A[i]=A[i]*B[i]; 75 fft(A,-1); 76 for(int i=0;i<bit;i++)ret[i]=(int)(A[i].a+0.5); 77 } 78 void getrt(int u,int fa){ 79 mx[u]=0; 80 siz[u]=1; 81 for(int tmp=head[u];tmp!=-1;tmp=a[tmp].next){ 82 int v=a[tmp].v; 83 if(!used[v]&&v!=fa){ 84 getrt(v,u); 85 siz[u]+=siz[v]; 86 mx[u]=max(mx[u],siz[v]); 87 } 88 } 89 mx[u]=max(mx[u],S-mx[u]); 90 if(mx[u]<mx[rt])rt=u; 91 } 92 void getdep(int u,int fa,int dpt){ 93 mxd=max(mxd,dpt); 94 dps[dpt]++; 95 for(int tmp=head[u];tmp!=-1;tmp=a[tmp].next){ 96 int v=a[tmp].v; 97 if(!used[v]&&v!=fa){ 98 getdep(v,u,dpt+1); 99 } 100 } 101 } 102 void divide(int u){ 103 used[u]=true; 104 num[0]=1; 105 for(int tmp=head[u];tmp!=-1;tmp=a[tmp].next){ 106 int v=a[tmp].v; 107 if(!used[v]){ 108 getdep(v,u,1); 109 for(int i=1;i<=mxd;i++){ 110 num[i]+=dps[i]; 111 tp[i]=dps[i]; 112 dps[i]=0; 113 } 114 cnt=max(cnt,mxd); 115 mul(tp,tp,tp,mxd); 116 for(int i=1;i<=mxd*2;i++){ 117 anss[i]-=tp[i]; 118 tp[i]=0; 119 } 120 mxd=0; 121 } 122 } 123 for(int i=0;i<=cnt;i++){ 124 tp[i]=num[i]; 125 num[i]=0; 126 } 127 mul(tp,tp,tp,cnt); 128 for(int i=0;i<=cnt*2;i++){ 129 anss[i]+=tp[i]; 130 tp[i]=0; 131 } 132 cnt=0; 133 for(int tmp=head[u];tmp!=-1;tmp=a[tmp].next){ 134 int v=a[tmp].v; 135 if(!used[v]){ 136 S=siz[v]; 137 rt=0; 138 getrt(v,0); 139 divide(rt); 140 } 141 } 142 } 143 int main(){ 144 memset(head,-1,sizeof(head)); 145 _(); 146 scanf("%d",&n); 147 jc[0]=inv[0]=inv[1]=1; 148 for(int i=2;i<=n+1;i++)inv[i]=(ll)(mod-mod/i)*inv[mod%i]%mod; 149 for(int i=1;i<=n+1;i++)jc[i]=(ll)jc[i-1]*i%mod; 150 for(int i=1;i<n;i++){ 151 scanf("%d%d",&u,&v); 152 add(u,v); 153 add(v,u); 154 } 155 S=n; 156 mx[rt=0]=6666666; 157 getrt(1,-1); 158 divide(rt); 159 ans=n; 160 for(int i=1;i<=n;i++){ 161 ans=(ans+(ll)anss[i]*inv[i+1]%mod)%mod; 162 } 163 printf("%lld",(ll)ans*jc[n]%mod); 164 return 0; 165 }