【BZOJ3451】Tyvj1953 Normal - 點分治+FFT

題目來源:NOI2019模擬測試賽(七)ios

非原題面,題意有略微區別ide

題意:

吐槽:

心態崩了。測試

好不容易場上想出一題正解,寫了三個小時結果寫了個假的點分治,卡成$O(n^2)$spa

我退役吧。.net

題解:

原題是求隨機樹分治的指望深度和,題意相同。code

對於一個點$x$,考慮點$y$是否能做爲它在點分樹上的祖先節點,顯然當且僅當$y$在$x$到$y$的路徑中第一個被選爲分治中心時會對$x$產生1的貢獻;blog

因爲路徑上全部點被選到的機率都是相等的,因此此時的指望就是$\frac{1}{dis(x,y)}$;get

那麼總的指望就是$\sum\limits_{x=1}^{n}\sum\limits_{y=1}^{n}\frac{1}{dis(x,y)}$;博客

在這裏寫個暴力便可爆踩個人假點分治;string

考慮統計每種長度的路徑條數,能夠用點分治作,而且在點分樹裏合併時子樹的指望是一個卷積的形式,所以能夠用FFT來加速;

因而我就快樂的寫了個點分治+FFT,得到了60分的好成績;

爲何?參考這篇博客的證實,我最初的寫法就是其中的第一種寫法,搜完一個子樹就和已經搜過的合併,這樣作的話FFT的長度會是$子樹中最大深度\times 根節點兒子個數=O(n^2)$的,正確的寫法應該搜完再一塊兒合併,或者像裏面說的第二種方法同樣直接搜當前子樹,更新答案而後搜重心的每一個兒子的子樹,減去不合法的路徑,這樣子FFT的長度纔是$O(n)$的。

代碼:

假點分治(60pts):

 1 #include<algorithm>
 2 #include<iostream>
 3 #include<cstring>
 4 #include<cstdio>
 5 #include<cmath>
 6 #include<queue>
 7 #define inf 2147483647
 8 #define eps 1e-9
 9 #define mod 1000000007
 10 using namespace std;  11 typedef long long ll;  12 typedef double db;  13 const db pi=acos(-1.0);  14 
 15 struct edge{  16     int v,next;  17 }a[200001];  18 int n,u,v,S,rt,mxd,bit,bitnum,tot=0,cnt=0,ans=0,jc[100001],inv[100001],anss[200001],tp[200001],num[200001],s[200001],rev[200001],head[100001],mx[100001],siz[100001],dep[100001];  19 bool used[100001];  20 struct cp{  21  db a,b;  22  cp(){}  23  cp(db _a,db _b){  24         a=_a,b=_b;  25  }  26     friend cp operator +(cp a,cp b){return cp(a.a+b.a,a.b+b.b);}  27     friend cp operator -(cp a,cp b){return cp(a.a-b.a,a.b-b.b);}  28     friend cp operator *(cp a,cp b){return cp(a.a*b.a-a.b*b.b,a.a*b.b+a.b*b.a);}  29     friend cp operator *(cp a,db b){return cp(a.a*b,a.b*b);}  30     friend cp operator /(cp a,db b){return cp(a.a/b,a.b/b);}  31 }A[200001],B[200001],W[200001][2];  32 void _(){  33     for(int i=1;i<=(1<<17);i<<=1){  34         W[i][0]=cp(cos(pi/i),sin(pi/i));  35         W[i][1]=cp(cos(pi/i),-sin(pi/i));  36  }  37 }  38 void fft(cp *s,int op){  39     for(int i=0;i<bit;i++)if(i<rev[i])swap(s[i],s[rev[i]]);  40     for(int i=1;i<bit;i<<=1){  41         //cp w(cos(pi/i),op*sin(pi/i));
 42         cp w=W[i][op==-1];  43         for(int p=i<<1,j=0;j<bit;j+=p){  44             cp wk(1,0);  45             for(int k=j;k<i+j;k++,wk=wk*w){  46                 cp x=s[k],y=wk*s[k+i];  47                 s[k]=x+y;  48                 s[k+i]=x-y;  49  }  50  }  51  }  52     if(op==-1){  53         for(int i=0;i<bit;i++){  54             s[i]=s[i]/(db)bit;  55  }  56  }  57 }  58 void add(int u,int v){  59     a[++tot].v=v;  60     a[tot].next=head[u];  61     head[u]=tot;  62 }  63 void mul(int *ret,int *a,int *b,int n){  64     for(bit=1,bitnum=0;bit<=n*2;bit<<=1)bitnum++;  65     for(int i=1;i<=bit;i++){  66         rev[i]=(rev[i>>1]>>1)|((i&1)<<(bitnum-1));  67  }  68     for(int i=0;i<bit;i++){  69         A[i]=cp((db)a[i],0);  70         B[i]=cp(0,0);  71  }  72     for(int i=1;i<=cnt;i++){  73         a[b[i]]++;  74         B[b[i]].a+=1;  75  }  76     fft(A,1);  77     fft(B,1);  78     for(int i=0;i<bit;i++)A[i]=A[i]*B[i];  79     fft(A,-1);  80     for(int i=0;i<bit;i++)ret[i]=(int)(A[i].a+0.5);  81 }  82 void getrt(int u,int fa){  83     mx[u]=0;  84     siz[u]=1;  85     for(int tmp=head[u];tmp!=-1;tmp=a[tmp].next){  86         int v=a[tmp].v;  87         if(!used[v]&&v!=fa){  88  getrt(v,u);  89             siz[u]+=siz[v];  90             mx[u]=max(mx[u],siz[v]);  91  }  92  }  93     mx[u]=max(mx[u],S-mx[u]);  94     if(mx[u]<mx[rt])rt=u;  95 }  96 void getdep(int u,int fa,int dpt){  97     mxd=max(mxd,dpt);  98     s[++cnt]=dpt;  99     for(int tmp=head[u];tmp!=-1;tmp=a[tmp].next){ 100         int v=a[tmp].v; 101         if(!used[v]&&v!=fa){ 102             getdep(v,u,dpt+1); 103  } 104  } 105 } 106 void divide(int u){ 107     used[u]=true; 108     mxd=0; 109     for(int tmp=head[u];tmp!=-1;tmp=a[tmp].next){ 110         int v=a[tmp].v; 111         if(!used[v]){ 112             cnt=0; 113             getdep(v,u,1); 114  mul(tp,num,s,mxd); 115             for(int i=0;i<bit;i++)anss[i]+=tp[i]; 116  } 117  } 118     for(int i=1;i<=mxd;i++){ 119         anss[i]+=num[i]; 120         num[i]=0; 121  } 122     for(int tmp=head[u];tmp!=-1;tmp=a[tmp].next){ 123         int v=a[tmp].v; 124         if(!used[v]){ 125             S=siz[v]; 126             rt=0; 127             getrt(v,0); 128  divide(rt); 129  } 130  } 131 } 132 int main(){ 133     memset(head,-1,sizeof(head)); 134  _(); 135     scanf("%d",&n); 136     jc[0]=inv[0]=inv[1]=1; 137     for(int i=2;i<=n+1;i++)inv[i]=(ll)(mod-mod/i)*inv[mod%i]%mod; 138     for(int i=1;i<=n+1;i++)jc[i]=(ll)jc[i-1]*i%mod; 139     for(int i=1;i<n;i++){ 140         scanf("%d%d",&u,&v); 141  add(u,v); 142  add(v,u); 143  } 144     S=n; 145     mx[rt=0]=6666666; 146     getrt(1,-1); 147  divide(rt); 148     ans=n; 149     for(int i=1;i<=n;i++){ 150         ans=(ans+(ll)anss[i]*inv[i+1]*2%mod)%mod; 151  } 152     printf("%lld",(ll)ans*jc[n]%mod); 153     return 0; 154 }

AC代碼(100pts):

 1 #include<algorithm>
 2 #include<iostream>
 3 #include<cstring>
 4 #include<cstdio>
 5 #include<cmath>
 6 #include<queue>
 7 #define inf 2147483647
 8 #define eps 1e-9
 9 #define mod 1000000007
 10 using namespace std;  11 typedef long long ll;  12 typedef double db;  13 const db pi=acos(-1.0);  14 
 15 struct edge{  16     int v,next;  17 }a[200001];  18 int n,u,v,S,rt,mxd,bit,bitnum,tot=0,cnt=0,ans=0,jc[100001],inv[100001],anss[200001],tp[200001],num[200001],rev[200001],head[100001],mx[100001],siz[100001],dep[100001],dps[100001];  19 bool used[100001];  20 struct cp{  21  db a,b;  22  cp(){}  23  cp(db _a,db _b){  24         a=_a,b=_b;  25  }  26     friend cp operator +(cp a,cp b){return cp(a.a+b.a,a.b+b.b);}  27     friend cp operator -(cp a,cp b){return cp(a.a-b.a,a.b-b.b);}  28     friend cp operator *(cp a,cp b){return cp(a.a*b.a-a.b*b.b,a.a*b.b+a.b*b.a);}  29     friend cp operator *(cp a,db b){return cp(a.a*b,a.b*b);}  30     friend cp operator /(cp a,db b){return cp(a.a/b,a.b/b);}  31 }A[200001],B[200001],W[200001][2];  32 void _(){  33     for(int i=1;i<=(1<<17);i<<=1){  34         W[i][0]=cp(cos(pi/i),sin(pi/i));  35         W[i][1]=cp(cos(pi/i),-sin(pi/i));  36  }  37 }  38 void fft(cp *s,int op){  39     for(int i=0;i<bit;i++)if(i<rev[i])swap(s[i],s[rev[i]]);  40     for(int i=1;i<bit;i<<=1){  41         //cp w(cos(pi/i),op*sin(pi/i));
 42         cp w=W[i][op==-1];  43         for(int p=i<<1,j=0;j<bit;j+=p){  44             cp wk(1,0);  45             for(int k=j;k<i+j;k++,wk=wk*w){  46                 cp x=s[k],y=wk*s[k+i];  47                 s[k]=x+y;  48                 s[k+i]=x-y;  49  }  50  }  51  }  52     if(op==-1){  53         for(int i=0;i<bit;i++){  54             s[i]=s[i]/(db)bit;  55  }  56  }  57 }  58 void add(int u,int v){  59     a[++tot].v=v;  60     a[tot].next=head[u];  61     head[u]=tot;  62 }  63 void mul(int *ret,int *a,int *b,int n){  64     for(bit=1,bitnum=0;bit<=n*2;bit<<=1)bitnum++;  65     for(int i=1;i<bit;i++){  66         rev[i]=(rev[i>>1]>>1)|((i&1)<<(bitnum-1));  67  }  68     for(int i=0;i<bit;i++){  69         A[i]=cp((db)a[i],0);  70         B[i]=cp((db)b[i],0);  71  }  72     fft(A,1);  73     fft(B,1);  74     for(int i=0;i<bit;i++)A[i]=A[i]*B[i];  75     fft(A,-1);  76     for(int i=0;i<bit;i++)ret[i]=(int)(A[i].a+0.5);  77 }  78 void getrt(int u,int fa){  79     mx[u]=0;  80     siz[u]=1;  81     for(int tmp=head[u];tmp!=-1;tmp=a[tmp].next){  82         int v=a[tmp].v;  83         if(!used[v]&&v!=fa){  84  getrt(v,u);  85             siz[u]+=siz[v];  86             mx[u]=max(mx[u],siz[v]);  87  }  88  }  89     mx[u]=max(mx[u],S-mx[u]);  90     if(mx[u]<mx[rt])rt=u;  91 }  92 void getdep(int u,int fa,int dpt){  93     mxd=max(mxd,dpt);  94     dps[dpt]++;  95     for(int tmp=head[u];tmp!=-1;tmp=a[tmp].next){  96         int v=a[tmp].v;  97         if(!used[v]&&v!=fa){  98             getdep(v,u,dpt+1);  99  } 100  } 101 } 102 void divide(int u){ 103     used[u]=true; 104     num[0]=1; 105     for(int tmp=head[u];tmp!=-1;tmp=a[tmp].next){ 106         int v=a[tmp].v; 107         if(!used[v]){ 108             getdep(v,u,1); 109             for(int i=1;i<=mxd;i++){ 110                 num[i]+=dps[i]; 111                 tp[i]=dps[i]; 112                 dps[i]=0; 113  } 114             cnt=max(cnt,mxd); 115  mul(tp,tp,tp,mxd); 116             for(int i=1;i<=mxd*2;i++){ 117                 anss[i]-=tp[i]; 118                 tp[i]=0; 119  } 120             mxd=0; 121  } 122  } 123     for(int i=0;i<=cnt;i++){ 124         tp[i]=num[i]; 125         num[i]=0; 126  } 127  mul(tp,tp,tp,cnt); 128     for(int i=0;i<=cnt*2;i++){ 129         anss[i]+=tp[i]; 130         tp[i]=0; 131  } 132     cnt=0; 133     for(int tmp=head[u];tmp!=-1;tmp=a[tmp].next){ 134         int v=a[tmp].v; 135         if(!used[v]){ 136             S=siz[v]; 137             rt=0; 138             getrt(v,0); 139  divide(rt); 140  } 141  } 142 } 143 int main(){ 144     memset(head,-1,sizeof(head)); 145  _(); 146     scanf("%d",&n); 147     jc[0]=inv[0]=inv[1]=1; 148     for(int i=2;i<=n+1;i++)inv[i]=(ll)(mod-mod/i)*inv[mod%i]%mod; 149     for(int i=1;i<=n+1;i++)jc[i]=(ll)jc[i-1]*i%mod; 150     for(int i=1;i<n;i++){ 151         scanf("%d%d",&u,&v); 152  add(u,v); 153  add(v,u); 154  } 155     S=n; 156     mx[rt=0]=6666666; 157     getrt(1,-1); 158  divide(rt); 159     ans=n; 160     for(int i=1;i<=n;i++){ 161         ans=(ans+(ll)anss[i]*inv[i+1]%mod)%mod; 162  } 163     printf("%lld",(ll)ans*jc[n]%mod); 164     return 0; 165 }
相關文章
相關標籤/搜索