#include<bits/stdc++.h> using namespace std; #define ll long long const int MAXN=4e6+10; inline int read() { int x=0; bool pos=1; char ch=getchar(); for(; !isdigit(ch); ch=getchar()) if(ch=='-') pos=0; for(; isdigit(ch); ch=getchar()) x=x*10+ch-'0'; return pos?x:-x; } struct cp { double x,y; cp(double xx=0,double yy=0) { x=xx; y=yy; } cp operator +(const cp &b) const { return cp(x+b.x,y+b.y); } cp operator -(const cp &b) const { return cp(x-b.x,y-b.y); } cp operator *(const cp &b) const { return cp(x*b.x-y*b.y,x*b.y+y*b.x); } }; const double PI=acos(-1.0); int rev[MAXN]; void init(int n,int lim) { for(int i=0; i<n; ++i) { for(int j=0; j<lim; ++j) if((i>>j)&1) rev[i]|=1<<(lim-j-1); } } void FFT(cp *a,int n,bool invflag) { for(int i=0; i<n; ++i) { if(i<rev[i]) swap(a[i],a[rev[i]]); } for(int l=2; l<=n; l<<=1) { int m=l>>1; cp wi=cp(cos(2*PI/l),sin(2*PI/l)); if(invflag) wi=cp(cos(2*PI/l),-sin(2*PI/l)); for(cp *p=a; p!=a+n; p+=l) { cp w=cp(1,0); for(int i=0; i<m; ++i) { cp t=w*p[i+m]; p[i+m]=p[i]-t; p[i]=p[i]+t; w=w*wi; } } } if(invflag) { for(int i=0; i<n; ++i) a[i].x/=n,a[i].y/=n; } } int n,m; cp a[MAXN],b[MAXN]; int main() { n=read(),m=read(); ++n,++m; for(int i=0; i<n; ++i) a[i]=cp((double)read(),0); for(int i=0; i<m; ++i) b[i]=cp((double)read(),0); int N=1,lim=0; while(N<n+m-1) N<<=1,++lim; init(N,lim); FFT(a,N,false); FFT(b,N,false); for(int i=0; i<N; ++i) a[i]=a[i]*b[i]; FFT(a,N,true); for(int i=0; i<n+m-1; ++i) printf("%d ",(int)(a[i].x+0.5)); return 0; }
#include<bits/stdc++.h> using namespace std; #define ll long long const int MAXN=4e6+10; inline int read() { int x=0; bool pos=1; char ch=getchar(); for(; !isdigit(ch); ch=getchar()) if(ch=='-') pos=0; for(; isdigit(ch); ch=getchar()) x=x*10+ch-'0'; return pos?x:-x; } const int P=998244353,G=3; int add(int a,int b) { return (a + b) % P; } int mul(int a,int b) { return 1LL * a * b % P; } int fpow(int a,int b) { int res=1; while(b) { if(b&1) res=mul(res,a); a=mul(a,a); b>>=1; } return res; } int inv(int x) { return fpow(x,P-2); } int rev[MAXN]; void init(int n,int lim) { for(int i=0; i<n; ++i) { for(int j=0; j<lim; ++j) if((i>>j)&1) rev[i]|=1<<(lim-j-1); } } void NTT(int *a,int n,bool invflag) { for(int i=0; i<n; ++i) { if(i<rev[i]) swap(a[i],a[rev[i]]); } for(int l=2; l<=n; l<<=1) { int gi=fpow(G,(P-1)/l); if(invflag) gi=inv(gi); int m=l>>1; for(int *p=a;p!=a+n;p+=l) { int g=1; for(int i=0; i<m; ++i) { int t=mul(g,p[i+m]); p[i+m]=add(p[i],P-t); p[i]=add(p[i],t); g=mul(g,gi); } } } if(invflag) { int Invn=inv(n); for(int i=0; i<n; ++i) a[i]=mul(a[i],Invn); } } int n,m; int a[MAXN],b[MAXN]; int main() { n=read(),m=read(); ++n,++m; for(int i=0; i<n; ++i) a[i]=read(); for(int i=0; i<m; ++i) b[i]=read(); int N=1,lim=0; while(N<n+m-1) N<<=1,++lim; init(N,lim); NTT(a,N,false); NTT(b,N,false); for(int i=0; i<N; ++i) a[i]=mul(a[i],b[i]); NTT(a,N,true); for(int i=0; i<n+m-1; ++i) printf("%d ",a[i]); return 0; }
#include<bits/stdc++.h> using namespace std; typedef long long ll; inline int read() { int out=0,fh=1; char jp=getchar(); while ((jp>'9'||jp<'0')&&jp!='-') jp=getchar(); if (jp=='-') fh=-1,jp=getchar(); while (jp>='0'&&jp<='9') out=out*10+jp-'0',jp=getchar(); return out*fh; } const int P=998244353,inv2=(P+1)>>1; inline int add(int a,int b) { return (a + b) % P; } inline int mul(int a,int b) { return 1LL * a * b % P; } void FWT(int *a,int n,int op) { for(int l=2;l<=n;l<<=1) { int m=l>>1; for(int *p=a;p!=a+n;p+=l) for(int i=0;i<m;++i) { int x0=p[i],x1=p[i+m]; if(op==1)//or p[i]=x0,p[i+m]=add(x0,x1); else if(op==2)//and p[i]=add(x0,x1),p[i+m]=x1; else//xor p[i]=add(x0,x1),p[i+m]=add(x0,P-x1); } } } void IFWT(int *a,int n,int op) { for(int l=2;l<=n;l<<=1) { int m=l>>1; for(int *p=a;p!=a+n;p+=l) for(int i=0;i<m;++i) { int x0=p[i],x1=p[i+m]; if(op==1) p[i]=x0,p[i+m]=add(x1,P-x0); else if(op==2) p[i]=add(x0,P-x1),p[i+m]=x1; else p[i]=mul(add(x0,x1),inv2),p[i+m]=mul(add(x0,P-x1),inv2); } } } const int MAXN=(1<<17)+10; int a[MAXN],b[MAXN],c[MAXN]; int main() { int n=read(); n=1<<n; for(int i=0;i<n;++i) a[i]=read(); for(int i=0;i<n;++i) b[i]=read(); for(int op=1;op<=3;++op) { FWT(a,n,op); FWT(b,n,op); for(int i=0;i<n;++i) c[i]=mul(a[i],b[i]); IFWT(a,n,op); IFWT(b,n,op); IFWT(c,n,op); for(int i=0;i<n;++i) printf("%d ",c[i]); puts(""); } return 0; }
如今要算這樣一個形式的卷積:
\[ f_i=\sum_{j=1}^{i}f_{i-j}\cdot g_j \]c++
這樣作後半段必定是對的,因此就不用關注長度的問題了.時間複雜度爲 \(O(nlog^2n)\) .git
#include<bits/stdc++.h> using namespace std; #define ll long long inline ll read() { ll out=0,fh=1; char jp=getchar(); while ((jp>'9'||jp<'0')&&jp!='-') jp=getchar(); if (jp=='-') fh=-1,jp=getchar(); while (jp>='0'&&jp<='9') out=out*10+jp-'0',jp=getchar(); return out*fh; } const int P=998244353,G=3; inline int add(int a,int b) { return (a + b) % P; } inline int mul(int a,int b) { return 1LL * a * b % P; } int fpow(int a,int b) { int res=1; while(b) { if(b&1) res=mul(res,a); a=mul(a,a); b>>=1; } return res; } const int MAXN=2e5+10; int rev[MAXN],omega[MAXN],inv[MAXN]; void rev_init(int n,int lim) { for(int i=0;i<n;++i) { int t=0; for(int j=0;j<lim;++j) if((i>>j)&1) t|=1<<(lim-j-1); rev[i]=t; } } void omega_init(int n) { for(int l=2;l<=n;l<<=1) { omega[l]=fpow(G,(P-1)/l); inv[l]=fpow(omega[l],P-2); } } void NTT(int *a,int n,int lim,int invflag) { for(int i=0;i<n;++i) { if(i<rev[i]) swap(a[i],a[rev[i]]); } for(int l=2;l<=n;l<<=1) { int m=l>>1; int gi=omega[l]; if(invflag) gi=inv[l]; for(int *p=a;p!=a+n;p+=l) { int g=1; for(int i=0;i<m;++i) { int t=mul(g,p[i+m]); p[i+m]=add(p[i],P-t); p[i]=add(p[i],t); g=mul(g,gi); } } } if (invflag) { int invn=fpow(n,P-2); for(int i=0;i<n;++i) a[i]=mul(a[i],invn); } } int f[MAXN],g[MAXN],a[MAXN],b[MAXN]; void solve(int l,int r,int lim)//[l,r) { if(lim<=0) return; int mid=(l+r)>>1,len=r-l; solve(l,mid,lim-1); rev_init(len,lim); for(int i=0;i<len/2;++i) a[i]=f[l+i]; for(int i=len/2;i<len;++i) a[i]=0; for(int i=0;i<len;++i) b[i]=g[i]; NTT(a,len,lim,false); NTT(b,len,lim,false); for(int i=0;i<len;++i) a[i]=mul(a[i],b[i]); NTT(a,len,lim,true); for(int i=len/2;i<len;++i) f[i+l]=add(f[i+l],a[i]); solve(mid,r,lim-1); } int main() { int n=read(); f[0]=1; for(int i=1;i<n;++i) g[i]=read(); int N=1,lim=0; while(N<n) N<<=1,++lim; omega_init(N); solve(0,N,lim); for(int i=0;i<n;++i) printf("%d ",f[i]); puts(""); return 0; }