神奇的思路,還是要學習一個。
題意:給你一個字符串,定義兩個前綴的lcs(最長公共後綴)、兩個後綴的lcp(最長公共前綴),求下式模 \(2^{64}\) 的值。
\[ \sum_{1\le i<j\le n} lcp(i,j)lcs(i,j)[lcp(i,j)\le k1][lcs(i,j)\le k2] \]
分析:
對於一對有貢獻的 \(\langle i,j\rangle\),我們將它們的lcs、lcp拼起來,可知
\[ \begin{gathered} s[i-lcs(i,j)+1,\,i+lcp(i,j)-1]=s[j-lcs(i,j)+1,\,j+lcp(i,j)-1]\\ s[i-lcs(i,j)]\ne s[j-lcs(i,j)]\\ s[i+lcp(i,j)]\ne s[j+lcp(i,j)] \end{gathered} \]
這啓發我們找出所有滿足下列條件的子串對 \(\langle i,j,len\rangle\)
\[ s[i,i+len-1]=s[j,j+len-1],s[i-1]\not=s[j-1],s[i+len]\not=s[j+len] \]
可以知道它的貢獻爲(令 \(k\) 為其中某一對位置的 lcp,則該對的 lcs 為 \(len-k+1\)):
\[ \sum_{k=\max(1,\,len-k2+1)}^{\min(len,\,k1)} k(len-k+1)=\sum_{k=1}^{\min(len,\,k1)}k(len-k+1)-\sum_{k=1}^{\max(0,\,len-k2)}k(len-k+1) \]
當 \(len\ge k1+k2\) 時求和範圍為空,貢獻為 \(0\)。
於是考慮建立後綴數組(SA),並記錄每個後綴的前一個字符。
在height數組上按從大到小的順序啓發式合併相鄰的後綴區間,一邊合併一邊統計答案:合併高度為 \(len\) 的兩個區間時,跨區間的後綴對的LCP恰為 \(len\),再減去前一個字符相同的對數即可。
// Sum of lcp(i,j) * lcs(i,j) over all position pairs i < j with lcp <= k1 and
// lcs <= k2, modulo 2^64 (the modulus comes for free from unsigned wrap-around).
//
// Approach: every contributing pair extends to a unique "maximal" repeat
// <i, j, len> (cannot be extended left or right). In the suffix array, such
// a repeat is a pair of suffixes whose LCP is exactly len and whose preceding
// characters differ. Adjacent rank groups are merged in order of decreasing
// height (small-to-large relabelling); each merge at height len adds F(len)
// times the number of cross pairs whose preceding characters differ.
//
// Assumes the input string consists of lowercase 'a'..'z' (see cnt[][26]).
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <utility>

using ull = unsigned long long;

const int N = 1e5 + 10;
// The "%100000s" below relies on s holding 1e5 chars plus s[0] and the terminator.
static_assert(N >= 100000 + 2, "N too small for the scanf field width");

int n, k1, k2;
char s[N];  // 1-indexed; s[0] and s[n+1] stay '\0' (global zero-init)
// sa[r]: start of the r-th smallest suffix; rc[p]: rank of suffix p;
// ht[r]: LCP of suffixes sa[r-1] and sa[r]; c: counting-sort buckets.
int sa[N], ht[N], rc[N], c[N];
// Merge groups over ranks: group id g covers ranks [lp[g], rp[g]], bl[r] is the
// group of rank r, siz[g] its size, cnt[g][ch] how many of its suffixes are
// preceded by letter ch.
int lp[N], rp[N], bl[N], siz[N], cnt[N][26];

// Builds sa[], rc[] and ht[] for s[1..n] with prefix doubling plus the
// standard linear LCP pass. ht and rc double as scratch rank buffers.
void buildSa() {
    int *x = ht, *y = rc, i, p, k, m = 128;
    for (i = 0; i <= m; ++i) c[i] = 0;
    for (i = 1; i <= n; ++i) c[x[i] = s[i]]++;
    for (i = 1; i <= m; ++i) c[i] += c[i - 1];
    for (i = n; i >= 1; --i) sa[c[x[i]]--] = i;
    for (k = 1; k < n; k <<= 1) {
        // y lists suffixes ordered by their second key (rank of suffix +k);
        // suffixes with no second key come first.
        for (i = n - k + 1, p = 0; i <= n; ++i) y[++p] = i;
        for (i = 1; i <= n; ++i)
            if (sa[i] > k) y[++p] = sa[i] - k;
        // Stable counting sort by the first key.
        for (i = 0; i <= m; ++i) c[i] = 0;
        for (i = 1; i <= n; ++i) c[x[y[i]]]++;
        for (i = 1; i <= m; ++i) c[i] += c[i - 1];
        for (i = n; i >= 1; --i) sa[c[x[y[i]]]--] = y[i];
        std::swap(x, y);
        x[sa[1]] = p = 1;
        // y[sa[i] + k] is only read when the first keys match, which implies
        // sa[i] + k <= n + 1, so the index stays in bounds.
        for (i = 2; i <= n; ++i)
            x[sa[i]] = (y[sa[i]] == y[sa[i - 1]] && y[sa[i] + k] == y[sa[i - 1] + k]) ? p : ++p;
        if ((m = p) >= n) break;  // all ranks distinct
    }
    for (i = 1; i <= n; ++i) rc[sa[i]] = i;
    for (i = 1, k = 0; i <= n; ++i) {
        p = sa[rc[i] - 1];  // previous suffix in SA order (0 for rank 1, i.e. s[0] == '\0')
        if (k) k--;
        while (s[i + k] == s[p + k]) ++k;
        ht[rc[i]] = k;
    }
}

// h[t] = (-height, rank): sorting ascending processes heights from high to low.
std::pair<int, int> h[N];

// 1 + 2 + ... + x
ull sm(int x) { return (ull)x * (x + 1) / 2; }

// 1^2 + 2^2 + ... + x^2; the product fits in 64 bits for x <= 1e5,
// so the division is exact.
ull ssm(int x) { return (ull)x * (2 * x + 1) * (x + 1) / 6; }

// Weight of one maximal repeat of length x:
//   sum_{k=max(1,x-k2+1)}^{min(x,k1)} k*(x-k+1),
// written as a difference of prefix sums sum_{k=1}^{m} k*(x+1-k) = (x+1)*sm(m) - ssm(m).
// Requires 0 <= k1, k2 (main clamps them).
ull F(int x) {
    if (x >= k1 + k2) return 0;  // summation range is empty
    ull s1 = (ull)(x + 1) * sm(std::min(x, k1)) - ssm(std::min(x, k1));
    ull s2 = (ull)(x + 1) * sm(std::max(0, x - k2)) - ssm(std::max(0, x - k2));
    return s1 - s2;
}

ull f[N];  // f[len] = F(len); f[0] stays 0

// Number of cross pairs between groups x and y whose suffixes are preceded by
// different characters (a suffix starting at position 1 has no predecessor and
// therefore differs from everything).
ull calc(int x, int y) {
    ull res = (ull)siz[x] * siz[y];
    for (int i = 0; i < 26; ++i) res -= (ull)cnt[x][i] * cnt[y][i];
    return res;
}

// Absorbs group x into group y by relabelling every rank of x.
// The caller makes x the smaller group, so each rank is relabelled O(log n) times.
void merge(int x, int y) {
    for (int i = 0; i < 26; ++i) cnt[y][i] += cnt[x][i];
    for (int i = lp[x]; i <= rp[x]; ++i) bl[i] = y;
    lp[y] = std::min(lp[y], lp[x]);
    rp[y] = std::max(rp[y], rp[x]);
    siz[y] += siz[x];
}

int main() {
    // Field width keeps the read inside s[]; a failed read would leave n == 0
    // and make sort(h + 1, h + n) an invalid range.
    if (scanf("%100000s%d%d", s + 1, &k1, &k2) != 3) return 1;
    n = (int)strlen(s + 1);
    // Lengths never exceed n; negative limits would make sm/ssm wrap.
    k1 = std::max(0, std::min(k1, n));
    k2 = std::max(0, std::min(k2, n));
    for (int i = 1; i <= n; ++i) f[i] = F(i);
    buildSa();
    // Every rank starts as its own group.
    for (int i = 1; i <= n; ++i) {
        lp[i] = rp[i] = bl[i] = i;
        siz[i] = 1;
        if (sa[i] > 1) cnt[i][s[sa[i] - 1] - 'a']++;
    }
    for (int i = 2; i <= n; ++i) h[i - 1] = std::make_pair(-ht[i], i);
    std::sort(h + 1, h + n);
    ull ans = 0;
    for (int i = 1; i < n; ++i) {
        int len = -h[i].first;
        int x = bl[h[i].second];
        int y = bl[h[i].second - 1];
        if (siz[x] > siz[y]) std::swap(x, y);  // relabel the smaller side
        // All cross pairs have LCP exactly len because larger heights were merged first.
        ans += f[len] * calc(x, y);
        merge(x, y);
    }
    printf("%llu\n", ans);
    return 0;
}