給一棵樹,每條邊上有一個字符,求有多少對 \((x,y)(x<y)\),知足 \(x\) 到 \(y\) 路徑上的邊上的字符按順序組成的字符串爲迴文串。數組
\(1\leq n\leq 50000,1\leq x_i,y_i\leq n,z_i\in\{0,1\}\)數據結構
觀察一條通過重心的迴文串是長什麼樣的dom
\(S\) 是一個任意的字符串,\(T\) 是一個迴文串。spa
建出根到每一個節點對應的串的AC自動機。code
那麼 \(x\) 這邊的 \(S\) 串就是 \(x\) 對應的AC自動機節點的一個後綴, \(T\) 串是一個前綴。blog
dfs 整棵樹的 fail 樹,先統計每一個點做爲 \(x\) 點的貢獻,再把做爲 \(y\) 點的貢獻加到數據結構中。字符串
開 \(\sqrt n\) 個長度爲 \(\sqrt n\) 的數組 \(c_{1,\sqrt n}\)。\(c_{i,j}\) 表示當前節點有多少個長度 \(\bmod i=j\) 的祖先。get
當一個點是 \(y\) 點的時候,令對應長度的字符串的出現次數 \(+1\),還要對於 \(\leq \sqrt n\) 的全部數 \(i\),令 \(c_{i,\lvert S \rvert \bmod i}++\)。string
當一個點是 \(x\) 點的時候,一個迴文串的全部迴文前綴能夠被表示爲 \(O(\log n)\) 個等差數列,公差 \(\leq \sqrt n\) 的那部分在 \(c\) 裏面查,剩下的暴力查就行了。it
記一個等差數列的首項爲 \(a_1\),公差爲 \(d\),末項爲 \(a_n\),那麼貢獻就是 dfs 到深度爲 \(a_n\) 的點時 \(c_{d,a_1\bmod d}\) 的值減掉 dfs 到深度爲 \(a_1-d\) 的點時 \(c_{d,a_1\bmod d}\) 的值。
先 dfs 一遍把全部詢問的信息插到 vector 中,再 dfs 一遍計算答案。
求一個串的全部迴文前綴能夠直接哈希。
時間複雜度:\(f(n)=O(n^\frac{3}{2})+O(n\log^2 n)=O(n^\frac{3}{2})\)
\(T(n)=2T(\frac{n}{2})+f(n)=2T(\frac{n}{2})+O(n^\frac{3}{2})=O(n^\frac{3}{2})\)
把這份代碼中的後綴自動機換成 AC自動機,迴文自動機換成哈希就行了。
#include<cstdio> #include<cstring> #include<algorithm> #include<cstdlib> #include<ctime> #include<utility> #include<functional> #include<cmath> #include<vector> #include<queue> #include<assert.h> //using namespace std; using std::min; using std::max; using std::swap; using std::sort; using std::reverse; using std::random_shuffle; using std::lower_bound; using std::upper_bound; using std::unique; using std::vector; using std::queue; typedef long long ll; typedef unsigned long long ull; typedef double db; typedef std::pair<int,int> pii; typedef std::pair<ll,ll> pll; void open(const char *s){ #ifndef ONLINE_JUDGE char str[100];sprintf(str,"%s.in",s);freopen(str,"r",stdin);sprintf(str,"%s.out",s);freopen(str,"w",stdout); #endif } void open2(const char *s){ #ifdef DEBUG char str[100];sprintf(str,"%s.in",s);freopen(str,"r",stdin);sprintf(str,"%s.out",s);freopen(str,"w",stdout); #endif } int rd(){int s=0,c,b=0;while(((c=getchar())<'0'||c>'9')&&c!='-');if(c=='-'){c=getchar();b=1;}do{s=s*10+c-'0';}while((c=getchar())>='0'&&c<='9');return b?-s:s;} void put(int x){if(!x){putchar('0');return;}static int c[20];int t=0;while(x){c[++t]=x%10;x/=10;}while(t)putchar(c[t--]+'0');} int upmin(int &a,int b){if(b<a){a=b;return 1;}return 0;} int upmax(int &a,int b){if(b>a){a=b;return 1;}return 0;} const int N=50010; vector<pii> g[N]; int sz[N]; int totsz,rt,rtsz; int b[N]; int n; int f[N]; ll* ss[N]; ll ss2[N]; ll ans=0; int _log[N]; struct info { int x; int y; int z; info(int a=0,int b=0,int c=0):x(a),y(b),z(c){} }; int cmp(info a,info b) { if(a.x!=b.x) return a.x<b.x; return a.z<b.z; } void dfs1(int x,int fa) { sz[x]=1; for(auto v:g[x]) if(v.first!=fa&&!b[v.first]) { dfs1(v.first,x); sz[x]+=sz[v.first]; } } void dfs2(int x,int fa) { int mx=totsz-sz[x]; for(auto v:g[x]) if(v.first!=fa&&!b[v.first]) { dfs2(v.first,x); mx=max(mx,sz[v.first]); } if(mx<rtsz) { rtsz=mx; rt=x; } } void dfs3(int x,int fa) { f[x]=fa; for(auto v:g[x]) if(v.first!=fa&&!b[v.first]) dfs3(v.first,x); } int tot; int str[N]; namespace sam { int next[2*N][2]; int fail[2*N]; int len[2*N]; int last,cnt; int b[2*N]; int a[2*N][2]; int s[2*N]; void init() { while(cnt) { next[cnt][0]=next[cnt][1]=0; a[cnt][0]=a[cnt][1]=0; b[cnt]=0; s[cnt]=0; cnt--; } cnt=1; last=1; } int insert(int p,int c) { if(next[p][c]) { last=next[p][c]; s[last]++; return last; } // int p=last; int np=++cnt; len[np]=len[p]+1; s[np]=1; for(;p&&!next[p][c];p=fail[p]) next[p][c]=np; if(!p) fail[np]=1; else { int q=next[p][c]; if(len[q]==len[p]+1) fail[np]=q; else { int nq=++cnt; len[nq]=len[p]+1; memcpy(next[nq],next[q],sizeof next[q]); fail[nq]=fail[q]; fail[q]=fail[np]=nq; for(;p&&next[p][c]==q;p=fail[p]) next[p][c]=nq; } } return last=np; } } namespace pam { int next[N][2]; int trans[N][2]; int fail[N]; int len[N]; int diff[N]; int link[N]; int top[N]; int last; int cnt; void init() { while(cnt>=0) { next[cnt][0]=next[cnt][1]=0; trans[cnt][0]=trans[cnt][1]=0; cnt--; } cnt=1; str[0]=-1; fail[0]=1; fail[1]=0; len[0]=0; len[1]=-1; last=0; link[0]=0; diff[0]=1; diff[1]=0; top[0]=0; top[1]=1; trans[0][0]=trans[0][1]=trans[1][0]=trans[1][1]=1; } int find(int x,int c) { return str[tot-len[x]-1]==c?x:trans[x][c]; } void insert(int c) { str[++tot]=c; last=find(last,c); int now=last; if(!next[now][c]) { int cur=++cnt; len[cur]=len[now]+2; last=find(fail[last],c); fail[cur]=next[last][c]; diff[cur]=len[cur]-len[fail[cur]]; if(diff[cur]==diff[fail[cur]]) { link[cur]=link[fail[cur]]; top[cur]=top[fail[cur]]; } else { link[cur]=fail[cur]; top[cur]=cur; } if(!link[cur]) link[cur]=cur; memcpy(trans[cur],trans[fail[cur]],sizeof trans[cur]); trans[cur][str[tot-len[fail[cur]]]]=fail[cur]; next[now][c]=cur; } last=next[now][c]; } } namespace trie { int a[N][2]; int s[N]; int cnt; void clear() { while(cnt) { a[cnt][0]=a[cnt][1]=0; s[cnt]=0; cnt--; } cnt=1; } } ll s,s2; int pos[N]; int pos2[N]; int pos3[N]; int pos4[N]; int q[N]; int len[N],id[N],top; int head,tail; vector<int> e[2*N]; int sq; vector<info> h[2*N]; int orzzjt,orzzjt2; void bfs(int x) { sam::init(); // sam::s[1]=1; pos[x]=1; head=1; tail=0; q[++tail]=x; trie::clear(); pos4[x]=1; while(tail>=head) { int y=q[head++]; s+=trie::s[pos4[y]]; trie::s[pos4[y]]++; for(auto v:g[y]) if(!b[v.first]&&v.first!=f[y]) { pos[v.first]=sam::insert(pos[y],v.second); q[++tail]=v.first; if(trie::a[pos4[y]][v.second]) pos4[v.first]=trie::a[pos4[y]][v.second]; else pos4[v.first]=trie::a[pos4[y]][v.second]=++trie::cnt; } } } void dfs(int x,int fa) { for(int y=pos[x];y!=1&&!sam::b[y];y=sam::fail[y]) { sam::a[sam::fail[y]][str[tot-sam::len[sam::fail[y]]]]=y; sam::b[y]=1; } //這樣建出來的後綴樹不是完整的,但已經夠用了 int now=pam::last; pos2[x]=now; if(pam::len[now]==tot) { if(fa) s2++; pos3[x]=now; } else pos3[x]=pos3[fa]; for(auto v:g[x]) if(!b[v.first]&&v.first!=fa) { pam::last=now; pam::insert(v.second); dfs(v.first,x); tot--; } } void dfs4(int x) { len[++top]=sam::len[x]; id[top]=x; for(auto v:e[x]) for(int y=pos3[v];y>1;) if(pam::diff[y]<=sq) { h[id[lower_bound(len+1,len+top+1,sam::len[x]-pam::len[y]-pam::diff[y])-len]].push_back(info(sam::len[x]-pam::len[y]-pam::diff[y],pam::diff[y],-1)); h[id[lower_bound(len+1,len+top+1,sam::len[x]-pam::len[pam::link[y]])-len]].push_back(info(sam::len[x]-pam::len[pam::link[y]],pam::diff[y],1)); //h.push_back(info(sam::len[x]-pam::len[y],id[lower_bound(len+1,len+top+1,sam::len[x]-pam::len[y])-len],1)); // h.push_back(info(sam::len[x]-pam::len[pam::link[y]]+pam::diff[y],id[lower_bound(len+1,len+top+1,sam::len[x]-pam::len[pam::link[y]]+pam::diff[y])-len],-1)); y=pam::fail[pam::link[y]]; orzzjt2+=_log[top]; } else { y=pam::fail[y]; } if(sam::a[x][0]) dfs4(sam::a[x][0]); if(sam::a[x][1]) dfs4(sam::a[x][1]); top--; } void dfs5(int x) { for(auto v:h[x]) if(v.x>=0&&v.x!=sam::len[x]) s+=ss[v.y][v.x%v.y]*v.z; orzzjt+=sq; for(int i=1;i<=sq;i++) ss[i][sam::len[x]%i]+=sam::s[x]; ss2[sam::len[x]]+=sam::s[x]; for(auto v:h[x]) if(v.x>=0&&v.x==sam::len[x]) s+=ss[v.y][v.x%v.y]*v.z; for(auto v:e[x]) for(int y=pos3[v];y>1;) if(pam::diff[y]<=sq) { y=pam::fail[pam::link[y]]; } else { s+=ss2[sam::len[x]-pam::len[y]]; y=pam::fail[y]; } if(sam::a[x][0]) dfs5(sam::a[x][0]); if(sam::a[x][1]) dfs5(sam::a[x][1]); for(int i=1;i<=sq;i++) ss[i][sam::len[x]%i]-=sam::s[x]; ss2[sam::len[x]]-=sam::s[x]; } ll calc(int x) { s=0; s2=0; bfs(x); pam::init(); dfs(x,0); for(int i=1;i<=sam::cnt;i++) { e[i].clear(); h[i].clear(); } for(int i=1;i<=tail;i++) e[pos[q[i]]].push_back(q[i]); dfs4(1); // for(int i=1;i<=sam::cnt;i++) // sort(h[i].begin(),h[i].end()); dfs5(1); return s; } int c[N],c2[N]; int t; vector<pii> g2; void solve(int x) { dfs1(x,0); totsz=sz[x]; rtsz=0x7fffffff; dfs2(x,0); x=rt; dfs3(x,0); int t=0; sq=sqrt(totsz); // sq=0; ans+=calc(x); ans+=s2; for(auto v:g[x]) if(!b[v.first]) { b[v.first]=1; c[++t]=v.first; c2[t]=v.second; } g2=g[x]; g[x].clear(); for(int i=1;i<=t;i++) { b[c[i]]=0; g[x].clear(); g[x].push_back(pii(c[i],c2[i])); ans-=calc(x); b[c[i]]=1; } g[x]=g2; for(int i=1;i<=t;i++) b[c[i]]=0; b[x]=1; for(auto v:g[x]) if(!b[v.first]) solve(v.first); } int main() { open("string"); scanf("%d",&n); for(int i=1;i<=n;i++) for(int j=1,k=0;j<=n;j<<=1,k++) _log[i]=k; int _sqrt=sqrt(n); for(int i=1;i<=_sqrt;i++) { ss[i]=new ll[i]; for(int j=0;j<i;j++) ss[i][j]=0; } int x,y,z; for(int i=1;i<n;i++) { scanf("%d%d%d",&x,&y,&z); g[x].push_back(pii(y,z)); g[y].push_back(pii(x,z)); } solve(1); // assert(ans%2==0); // ans/=2; printf("%lld\n",ans); // printf("%d\n",orzzjt); // printf("%d\n",orzzjt2); return 0; }