【2019北京集訓測試賽(十三)】函樹 虛樹

題目大意:給你一顆$n$個節點的樹,定義$d(x,y)=$點$x$到點$y$最短路上通過的邊數。c++

求$\sum\limits_{i=1}^{n} \sum\limits_{j=1}^{n} \varphi(i\times j)\times d(i,j)$ui

答案對998244353$取模。spa

 

咱們對這個式子作一些細微的處理,設最終的答案爲$ans$:code

$ans=\sum\limits_{i=1}^{n} \sum\limits_{j=1}^{n} \varphi(i\times j)\times d(i,j)$blog

$=\sum\limits_{i=1}^{n} \sum\limits_{j=1}^{n} \varphi(i)\varphi(j)\frac{gcd(i,j)}{\varphi(gcd(i,j))}\times d(i,j)$get

咱們設$F(d)=\sum\limits_{i=1}^{n} \sum\limits_{j=1,d|gcd(i,j)}^{n} \varphi(i)\varphi(j)\times d(i,j)$it

那麼,$ans=\sum\limits_{d=1}^{n} \frac{d}{\varphi(d)} \sum\limits_{p|d} F(p)\times G(\frac{d}{p})$ast

對於$G(x)$,設$x=\prod\limits_{i=1}^{k} p_i$,$p_i$是質數,$G(x)=(-1)^k$class

 

咱們考慮如何求$F(d)$。gc

顯然,咱們只須要把全部點權能被$d$整除的點找出來,建一棵虛樹,統計每條虛樹邊兩端的\$sum \varphi(i)$,把它們乘起來,再乘上虛樹邊邊長便可。

因爲點權等於編號,因此n棵虛樹的總點數是$O(n\ln\ n)$級別的,單次構建虛樹的複雜度是$O(size\log\ size)$的,因此並不會$T$掉。

而後就沒有了,總複雜度是$O(n\log^2\ n)$的。

  1 #include<bits/stdc++.h>
  2 #define M 100005
  3 #define MOD 998244353
  4 #define L long long
  5 using namespace std;
  6 
  7 int pri[M]={0},b[M]={0},phi[M]={0},zf[M]={0},Use=0;
  8 void init(){
  9     phi[1]=1; zf[1]=1;
 10     for(int i=2;i<M;i++){
 11         if(!b[i]) pri[++Use]=i,phi[i]=i-1,zf[i]=-1;
 12         for(int j=1;j<=Use&&i*pri[j]<M;j++){
 13             b[i*pri[j]]=1; zf[i*pri[j]]=-zf[i];
 14             if(i%pri[j]==0) {phi[i*pri[j]]=phi[i]*pri[j]; break;}
 15             phi[i*pri[j]]=phi[i]*(pri[j]-1);
 16         }
 17     }
 18 }
 19 
 20 L pow_mod(L x,L k){L ans=1; for(;k;k>>=1,x=x*x%MOD) if(k&1) ans=ans*x%MOD; return ans;}
 21 vector<int> G[M];
 22 
 23 struct edge{int u,v,next;}e[M*2]={0}; int head[M]={0},use=0;
 24 void add(int x,int y,int z){use++;e[use].u=y;e[use].v=z;e[use].next=head[x];head[x]=use;}
 25 int n,a[M]={0};
 26 
 27 int dep[M]={0},dfn[M]={0},low[M]={0},f[M][20]={0},t=0;
 28 void dfs(int x,int fa){
 29     dep[x]=dep[fa]+1; dfn[x]=++t; f[x][0]=fa;
 30     for(int i=1;i<20;i++) f[x][i]=f[f[x][i-1]][i-1];
 31     for(int i=0;i<G[x].size();i++) if(G[x][i]!=fa) dfs(G[x][i],x);
 32     low[x]=t;
 33 }
 34 int getlca(int x,int y){
 35     if(dep[x]<dep[y]) swap(x,y); int cha=dep[x]-dep[y];
 36     for(int i=19;~i;i--) if((1<<i)&cha) x=f[x][i];
 37     if(x==y) return x;
 38     for(int i=19;~i;i--) if(f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];
 39     return f[x][0];
 40 }
 41 
 42 vector<int> D[M]; L F[M]={0};
 43 
 44 bool cmp(int x,int y){return dfn[x]<dfn[y];}
 45 int point[M]={0},stk[M]={0},is[M]={0},pcnt=0,cnt=0,nowt=0;
 46 void build(){
 47     pcnt=cnt; int siz=0; nowt=0;
 48     sort(point+1,point+cnt+1,cmp);
 49     for(int i=1;i<=cnt;i++){
 50         int last=0;
 51         while(siz&&getlca(stk[siz],point[i])!=stk[siz]) last=stk[siz],stk[siz--]=0;
 52         if(last){
 53             int lca=getlca(last,point[i]);
 54             if(lca!=stk[siz]){
 55                 stk[++siz]=lca;
 56                 point[++pcnt]=lca;
 57                 is[lca]=0;
 58             }
 59         }
 60         stk[++siz]=point[i]; is[point[i]]=1;
 61     }
 62     sort(point+1,point+pcnt+1,cmp);
 63     while(siz) stk[siz--]=0;
 64 }
 65 L sumphi[M]={0},sum=0;
 66 int dfs(int x){
 67     if(is[x]) sumphi[x]=phi[a[x]]; else sumphi[x]=0; 
 68     int v; nowt++;
 69     while(dfn[v=point[nowt]]<=low[x]&&nowt<=pcnt){
 70         dfs(v);
 71         sumphi[x]=(sumphi[x]+sumphi[v])%MOD;
 72     }
 73 }
 74 void getans(int x,L fsum){
 75     int v; nowt++; 
 76     while(dfn[v=point[nowt]]<=low[x]&&nowt<=pcnt){
 77         sum=(sum+1LL*(dep[v]-dep[x])*sumphi[v]%MOD*(fsum+sumphi[x]-sumphi[v]+MOD))%MOD;
 78         getans(v,(fsum+sumphi[x]-sumphi[v])%MOD);
 79     }
 80 }
 81 void solve(int x){
 82     while(pcnt)point[pcnt--]=0; cnt=sum=0;
 83     for(int i=0;i<D[x].size();i++){
 84         point[++cnt]=D[x][i];
 85     }
 86     build();
 87     nowt=1; dfs(point[1]);
 88     nowt=1; getans(point[1],0);
 89     F[x]=sum;
 90 }
 91 
 92 int main(){
 93 //    freopen("in.txt","r",stdin);
 94 //    freopen("out.txt","w",stdout);
 95     init();
 96     scanf("%d",&n);
 97     for(int i=1;i<=n;i++){
 98         a[i]=i; //scanf("%d",a+i);
 99         for(int j=1;j*j<=a[i];j++) if(a[i]%j==0){
100             D[j].push_back(i);
101             if(j*j!=a[i]) D[a[i]/j].push_back(i);
102         }
103     }
104     for(int i=1;i<n;i++){
105         int x,y; scanf("%d%d",&x,&y);
106         G[x].push_back(y); G[y].push_back(x);
107     }
108     dfs(1,0);
109     for(int i=1;i<=n;i++) solve(i);
110     for(int i=n;i;i--){
111         for(int j=i*2;j<=n;j+=i)
112         F[i]=(F[i]-F[j]+MOD)%MOD;
113     }
114     L ans=0;
115     for(int d=1;d<=n;d++)
116     ans=(ans+F[d]*d%MOD*pow_mod(phi[d],MOD-2))%MOD;
117     cout<<ans*2%MOD<<endl;
118     //cout<<ans*2*pow_mod(1LL*n*(n-1)%MOD,MOD-2)%MOD<<endl;
119 }
相關文章
相關標籤/搜索