給定數列前k項\(h_0...h_{k-1}\),其後的項知足:\(h_i=\sum_{i=1}^kh_{i-j}a_i\),其中\(a_1...a_k\)是給定的係數,求\(h_n\)
數據範圍小的時候:函數
作法一:暴力\(O(nk)\)的DP優化
作法二:矩陣快速冪.ui
記\(H_i=\begin{bmatrix}h_i&h_{i+1}&...&h_{i+k-1}\end{bmatrix}\). 則\(h_n\)是\(H_{n-k+1}\)的最後一項。spa
\(H_{n-k+1}=H_0M^{n-k+1}\).net
其中\(M\)是轉移矩陣,如當\(k=4\)時是這麼填的:
\[ M=\begin{bmatrix} 0&0&0&a_4\\ 1&0&0&a_3\\ 0&1&0&a_2\\ 0&0&1&a_1 \end{bmatrix} \]
時間複雜度\(O(k^3lg n)\)
數據範圍大一些的時候:
\(k\leq2000,n\leq10^9\). 這時候矩陣快速冪也作不了了
仍是拿\(k=4\)時舉例,\(M\)的特徵多項式\(f(\lambda)\)爲:
\[ f(\lambda)=det(\lambda I-M)=\begin{bmatrix} \lambda&0&0&0\\ 0&\lambda&0&0\\ 0&0&\lambda&0\\ 0&0&0&\lambda \end{bmatrix} -\begin{bmatrix} 0&0&0&a_4\\ 1&0&0&a_3\\ 0&1&0&a_2\\ 0&0&1&a_1 \end{bmatrix}=\begin{bmatrix} \lambda&0&0&-a_4\\ -1&\lambda&0&-a_3\\ 0&-1&\lambda&-a_2\\ 0&0&-1&\lambda-a_1 \end{bmatrix} \]
用行列式的性質,將\(f(\lambda)\)按最後一列拉普拉斯展開,獲得以下,其中\((-1)^{i+j}f(x)_{i,j}\)即行列式定義裏的代數餘子式:
\[ \begin{aligned} f(\lambda)&=\sum_{i=1}^ka_{k-i+1}(-1)^{i+j}f(\lambda)_{i,j} &取j=k(按最後一列展開)\\ &=\sum_{i=1}^ka_{k-i+1}(-1)^{i+k}f(\lambda)_{i,k} \end{aligned} \]
化簡獲得以下式子(也能夠按\(k=4\)帶進去看看規律)
\[ f(\lambda)=\lambda^k-\sum_{i=1}^ka_i\lambda^{k-i} \]
如今明確一個定義,\(f(x)\)這個函數的自變量\(x\)能夠是實數,也能夠是矩陣等等。這個函數僅僅是表示如何將自變量組合起來。表達的意思也會多樣化,好比多項式、矩陣的多項式...下文會隨時切換自變量的種類,可是函數的本質不變。
\(\lambda\)是\(M\)的特徵值,是一個數。可是根據Cayley-Hamilton定理,若是把\(\lambda\)替換成\(M\)代入獲得\(f(M)=M^k-\sum_{i=1}^ka_iM^{k-i}\),結果爲一個零矩陣,即\(M^k-\sum_{i=1}^ka_iM^{k-i}=0\)
code
咱們想要求\(M\)的\(n\)次方(這裏的\(n\)只是表明\(M\)的\(n\)次方,題目中\(n\)應該用\(n-k+1\)替代),然而\(M^n\)直接快速冪求不現實,複雜度爲\(O(k^3lg n)\).blog
首先退一步考慮,要求一個數字的n次方\(x^n\),若是咱們把\(x^n\)對\(f(x)\)取模會發生什麼?get
根據多項式取模的定義,\(x^n \;\text{mod}\; f(x)=f(x)g(x)+r(x)\),其中\(g(x)\)和\(r(x)\)是兩個多項式.input
將\(x\)當作\(M\),那麼\(f(M)\)爲0.it
故\(M^n \;\text{mod}\; f(M)=r(M)\),且\(M^n=M^n \;\text{mod}\; f(M)\),那麼\(M_n=r(M)\)這個多項式
根據多項式取模的特性,\(r(x)\)的次數嚴格小於模數\(f(x)\)的次數\(k\). 那麼\(r(x)\)所包含的\(M\)的指數必定小於\(k\),到達了能夠計算的範圍。
要求\(M^n\),就只須要求\(M^n \;\text{mod}\; f(M)\)的多項式\(r(M)\)。若是兩個多項式\(A(x)\)和\(B(x)\)對模數取模分別獲得\(C(x)\)和\(D(x)\),那麼多項式\(A(x)B(x)\)對模數取模結果就是\(C(x)D(x)\)。
那麼就能夠用快速冪來求解\(M^n \;\text{mod}\; f(M)\)的結果了,也就是求出了\(r(x)\)的各項係數(記爲\(c_i\))。實際計算中,表面上是在計算\(M^n\),實際上計算的是\(M^n \;\text{mod}\; f(M)\)的結果。
至此求出\(r(x)=\sum\limits_{i=0}^{k-1}c_ix^i\). 將它當作矩陣的多項式代入\(M\),得\(r(M)=\sum\limits_{i=0}^{k-1}c_iM^i\)
因此\(M^n=\sum\limits_{i=0}^{k-1}c_iM^i\)
把\(n\)替換成題目所須要的\(n-k+1\),最終答案\(h_n\)爲\(H_0M^{n-k+1}\)的最後一項。
\[ H_0M^{n-k+1}=H_0\sum_{i=0}^{k-1}c_iM^i=\sum_{i=0}^{k-1}c_iH_0M_i=\sum_{i=0}^{k-1}c_iH_i \]
那麼要求的是\(H_0M^{n-k+1}\)的最後一項。記\(last(H_i)=h_{k+i}\) ,那麼
\[ h_n=last(H_0M^{n-k+1})=\sum_{i=0}^{k-1}c_ilast(H_i)=\sum_{i=0}^{k-1}c_ih_{i+k} \]
發現\(i+k\in[k,2k-1]\),因此暴力算出\(h_k...h_{2k-1}\),代入求解獲得\(h_n\),至此所有求完。
分析複雜度:多項式乘法此處用暴力算會比FFT快,耗時最多的集快速冪求\(r(x)\) ,複雜度爲\(O(k^2lgn)\)。
#include <cstdio> using namespace std; const int K=4005,mod=1e9+7; int n,k; int a[K],h[K]; int b[K],c[K],t[K],mo[K]; inline void add(int &x,int y){ x+=y; if(x>=mod) x-=mod; } void mul(int *x,int *y,int *z){ for(int i=0;i<=2*k-2;i++) t[i]=0; for(int i=0;i<k;i++) for(int j=0;j<k;j++) add(t[i+j],1LL*x[i]*y[j]%mod); for(int i=2*k-2;i>=k;i--){ for(int j=k-1;j>=0;j--) add(t[i-k+j],mod-1LL*t[i]*mo[j]%mod); t[i]=0; } for(int i=0;i<k;i++) z[i]=t[i]; } void ksm(int y){ for(;y;mul(b,b,b),y>>=1) if(y&1) mul(c,b,c); } int main(){ freopen("input.in","r",stdin); scanf("%d%d",&n,&k); n++; for(int i=1;i<=k;i++){ scanf("%d",&a[i]); if(a[i]<0) a[i]+=mod; } for(int i=1;i<=k;i++){ scanf("%d",&h[i]); if(h[i]<0) h[i]+=mod; } mo[k]=1; for(int i=1;i<=k;i++) mo[k-i]=mod-a[i]; if(n<=k){printf("%d\n",h[n]);return 0;} b[1]=1; c[0]=1; ksm(n-k); for(int i=k+1;i<=2*k;i++) for(int j=1;j<=k;j++) add(h[i],1LL*a[j]*h[i-j]%mod); int ans=0; for(int i=0;i<k;i++) add(ans,1LL*c[i]*h[i+k]%mod); printf("%d\n",ans); return 0; }
若是\(k\)也比較大,那麼要上多項式全家桶來優化多項式計算了!複雜度\(O(k\log k\log n)\)
來啊
#include <cstdio> #include <vector> #include <algorithm> using namespace std; typedef long long ll; typedef vector<int> vi; const int K=200005,mod=998244353,G=3; int n,k,a[K],h[K]; inline void swap(int &x,int &y){int t=x;x=y;y=t;} inline int max(int x,int y){return x>y?x:y;} inline int min(int x,int y){return x<y?x:y;} inline void add(int &x,int y){ y=(y%mod+mod)%mod; (x+=y)%=mod; } inline int pow(int x,int y){ int ret=1; for(;y;x=1LL*x*x%mod,y>>=1) if(y&1) ret=1LL*ret*x%mod; return ret; } namespace NTT{/*{{{*/ int n,invn,bit,rev[K*4],A[K*4],B[K*4],W[K*4][2]; void build(){ int bas=pow(G,mod-2); for(int i=0;i<=18;i++){ W[1<<i][0]=pow(G,(mod-1)/(1<<i)); W[1<<i][1]=pow(bas,(mod-1)/(1<<i)); } } void init(int na,int nb,vi &a,vi &b,int fn=0){ if(!fn) fn=na+nb; for(n=1,bit=0;n<fn;n<<=1,bit++); invn=pow(n,mod-2); for(int i=0;i<n;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1)); for(int i=0;i<n;i++) A[i]=B[i]=0; for(int i=0;i<na;i++) A[i]=a[i]; for(int i=0;i<nb;i++) B[i]=b[i]; } void ntt(int *a,int f){ for(int i=0;i<n;i++) if(i<rev[i]) swap(a[i],a[rev[i]]); int w_n,w,u,v; for(int i=2;i<=n;i<<=1){ w_n=W[i][f==-1]; for(int j=0;j<n;j+=i){ w=1; for(int k=0;k<i/2;k++){ u=a[j+k]; v=1LL*a[j+i/2+k]*w%mod; a[j+k]=(u+v)%mod; a[j+i/2+k]=(u+mod-v)%mod; w=1LL*w*w_n%mod; } } } if(f==1) return; for(int i=0;i<n;i++) a[i]=1LL*a[i]*invn%mod; } void calc(){ ntt(A,1); ntt(B,1); for(int i=0;i<n;i++) A[i]=1LL*A[i]*B[i]%mod; ntt(A,-1); } void calchh(){ ntt(A,1); ntt(B,1); for(int i=0;i<n;i++) A[i]=(2LL*B[i]%mod+mod-1LL*A[i]*B[i]%mod*B[i]%mod)%mod; ntt(A,-1); } }/*}}}*/ vi mop,b,c,T; vi operator - (vi A,vi B){ int n=A.size(),m=B.size(),fn=max(n,m); A.resize(fn); for(int i=0;i<m;i++) add(A[i],-B[i]); return A; } vi operator * (int a,vi A){ int n=A.size(); a=(a+mod)%mod; for(int i=0;i<n;i++) A[i]=1LL*a*A[i]%mod; return A; } vi operator * (vi &A,vi B){ int n=A.size(),m=B.size(); NTT::init(n,m,A,B); NTT::calc(); A.resize(n+m-1); for(int i=0;i<n+m-1;i++) A[i]=NTT::A[i]; return A; } vi inverse(vi A){ int n=A.size(); if(n==1){ A[0]=pow(A[0],mod-2); return A; } vi B=A; B.resize((n+1)/2); B=inverse(B); int m=B.size(); NTT::init(n,m,A,B,n+m-1+m-1); NTT::calchh(); B.resize(NTT::n); for(int i=0;i<NTT::n;i++) B[i]=NTT::A[i]; //B=(2*B)-((A*B)*B); B.resize(n); return B; } vi operator / (vi A,vi B){ int n=A.size()-1,m=B.size()-1; vi C; if(n<m){ C.resize(1); C[0]=0; return C; } reverse(A.begin(),A.end()); reverse(B.begin(),B.end()); B.resize(n-m+1); C=A*inverse(B); C.resize(n-m+1); reverse(C.begin(),C.end()); return C; } void module(vi &A,vi B){ int n=A.size()-1,m=B.size()-1; if(n<m) return; vi D=A/B; A=A-(B*D); A.resize(m); } void ksm(int y){ for(;y;y>>=1){ if(y&1){ c=c*b; module(c,mop); } b=b*b; module(b,mop); } } int main(){ freopen("input.in","r",stdin); NTT::build(); scanf("%d%d",&n,&k); n++; for(int i=1;i<=k;i++) scanf("%d",&h[i]),h[i]%=mod; for(int i=1;i<=k;i++) scanf("%d",&a[i]),a[i]%=mod; if(n<=k){printf("%d\n",h[n]);return 0;} mop.resize(k+1); mop[k]=1; for(int i=1;i<=k;i++) mop[k-i]=(mod-a[i])%mod; b.resize(2); b[1]=1; c.resize(1); c[0]=1; ksm(n-1); int ans=0; c.resize(k); for(int i=0;i<k;i++) add(ans,1LL*c[i]*h[i+1]%mod); printf("%d\n",ans); return 0; }
http://blog.csdn.net/qq_33229466/article/details/78933309 "ORZ"