題意:一棵n個點的樹,點有點權。定義$G(a,b)$表示:咱們將樹上從a走到b通過的點都拿出來,設這些點的點權分別爲$z_0,z_1...z_{l-1}$,則$G(a,b)=z_0+z_1k^1+z_2k^2+...+z_{l-1}k^{l-1}$。若是$G(a,b)=X \mod Y$(保證Y是質數),則咱們稱(a,b)是好的,不然是壞的。如今想知道,有多少個三元組(a,b,c),知足(a,b),(b,c),(a,c)都是好的或者都是壞的?node
$n\le 10^5,Y\le 10^9$ios
題解:因爲一個點對要麼是好的要麼是壞的,因此咱們能夠枚舉一下全部符合條件的3元組的狀況。不過符合條件須要3條邊都相同,那咱們能夠反過來,統計不合法的3元組的狀況(一共$2^3-2$種狀況)。通過觀察咱們發現,咱們能夠在 同時鏈接兩種顏色的邊 的那個點處統計貢獻,即把三元組的貢獻放到了點上。咱們設$in_0(),in_1(i),out_0(i),out_1(i)$表示i有多少個好(壞)邊連入(出),則一個點對答案的貢獻就變成:spa
$2in_0(i)in_1(i)+2out_0(i)out_1(i)+in_0(i)out_1(i)+in_1(i)out_0(i)$blog
最後將答案/2便可。get
因此如今咱們只須要求:對於每一個點,有多少好邊連入(連出)。這個用點分治能夠搞定,由於咱們容易計算兩個多項式鏈接起來的結果。本題我採用的是容斥式的點分治。string
#include <cstdio> #include <cstring> #include <iostream> #include <algorithm> using namespace std; const int maxn=100010; typedef long long ll; int n,cnt,tot,mn,rt; ll X,Y,K,Ki,ans; ll pw[maxn],pi[maxn],v[maxn],in1[maxn],in0[maxn],out1[maxn],out0[maxn]; int to[maxn<<1],nxt[maxn<<1],head[maxn],vis[maxn],siz[maxn]; struct node { ll x; int y; node() {} node(ll a,int b) {x=a,y=b;} bool operator < (const node &a) const {return x<a.x;} }p[maxn],q[maxn]; inline int rd() { char gc=getchar(); int ret=0; while(gc<'0'||gc>'9') gc=getchar(); while(gc>='0'&&gc<='9') ret=ret*10+gc-'0',gc=getchar(); return ret; } inline void add(int a,int b) { to[cnt]=b,nxt[cnt]=head[a],head[a]=cnt++; } inline ll pm(ll x,ll y) { ll z=1; while(y) { if(y&1) z=z*x%Y; x=x*x%Y,y>>=1; } return z; } void getrt(int x,int fa) { int i,tmp=0; siz[x]=1; for(i=head[x];i!=-1;i=nxt[i]) if(!vis[to[i]]&&to[i]!=fa) getrt(to[i],x),siz[x]+=siz[to[i]],tmp=max(tmp,siz[to[i]]); tmp=max(tmp,n-siz[x]); if(tmp<mn) mn=tmp,rt=x; } void getp(int x,int fa,int dep,ll s1,ll s2) { s1=(s1*K+v[x])%Y,s2=(s2+v[x]*((!dep)?0:pw[dep-1]))%Y,dep++; p[++tot]=node((X-s1+Y)*pi[dep]%Y,x),q[tot]=node(s2,x); for(int i=head[x];i!=-1;i=nxt[i]) if(!vis[to[i]]&&to[i]!=fa) getp(to[i],x,dep,s1,s2); } void calc(int x,int flag,int dep,ll s1,ll s2) { int i,j,cnt; tot=0; s1=(s1*K+v[x])%Y,s2=(s2+v[x]*((!dep)?0:pw[dep-1]))%Y,dep++; p[++tot]=node((X-s1+Y)*pi[dep]%Y,x),q[tot]=node(s2,x); for(i=head[x];i!=-1;i=nxt[i]) if(!vis[to[i]]) getp(to[i],x,dep,s1,s2); sort(p+1,p+tot+1),sort(q+1,q+tot+1); for(cnt=0,i=j=1;i<=tot;i++) { for(;j<=tot&&q[j].x<=p[i].x;j++) { if(j==1||q[j].x!=q[j-1].x) cnt=0; cnt++; } if(j!=1&&q[j-1].x==p[i].x) out1[p[i].y]+=cnt*flag; } for(cnt=0,i=j=1;i<=tot;i++) { for(;j<=tot&&p[j].x<=q[i].x;j++) { if(j==1||p[j].x!=p[j-1].x) cnt=0; cnt++; } if(j!=1&&p[j-1].x==q[i].x) in1[q[i].y]+=cnt*flag; } } void dfs(int x) { vis[x]=1; int i; calc(x,1,0,0,0); for(i=head[x];i!=-1;i=nxt[i]) if(!vis[to[i]]) { calc(to[i],-1,1,v[x],0); tot=siz[to[i]],mn=1<<30,getrt(to[i],x),dfs(rt); } } int main() { //freopen("cf434E.in","r",stdin); n=rd(),Y=rd(),K=rd(),X=rd(),Ki=pm(K,Y-2); int i,a,b; memset(head,-1,sizeof(head)); for(i=1;i<=n;i++) v[i]=rd(); for(i=pw[0]=pi[0]=1;i<=n;i++) pw[i]=pw[i-1]*K%Y,pi[i]=pi[i-1]*Ki%Y; for(i=1;i<n;i++) a=rd(),b=rd(),add(a,b),add(b,a); tot=n,mn=1<<30,getrt(1,0),dfs(rt); for(i=1;i<=n;i++) { in0[i]=n-in1[i],out0[i]=n-out1[i]; ans+=2*in1[i]*in0[i]+2*out1[i]*out0[i]+in0[i]*out1[i]+in1[i]*out0[i]; } printf("%lld",1ll*n*n*n-ans/2); return 0; }