多項式

  • 並無講解.只是記錄用.

\(FFT\)

#include<bits/stdc++.h>
using namespace std;
#define ll long long
const int MAXN=4e6+10;
inline int read()
{
    int x=0;
    bool pos=1;
    char ch=getchar();
    for(; !isdigit(ch); ch=getchar())
        if(ch=='-')
            pos=0;
    for(; isdigit(ch); ch=getchar())
        x=x*10+ch-'0';
    return pos?x:-x;
}
struct cp
{
    double x,y;
    cp(double xx=0,double yy=0)
    {
        x=xx;
        y=yy;
    }
    cp operator +(const cp &b) const
    {
        return cp(x+b.x,y+b.y);
    }
    cp operator -(const cp &b) const
    {
        return cp(x-b.x,y-b.y);
    }
    cp operator *(const cp &b) const
    {
        return cp(x*b.x-y*b.y,x*b.y+y*b.x);
    }
};
const double PI=acos(-1.0);
int rev[MAXN];
void init(int n,int lim)
{
    for(int i=0; i<n; ++i)
    {
        for(int j=0; j<lim; ++j)
            if((i>>j)&1)
                rev[i]|=1<<(lim-j-1);
    }
}
void FFT(cp *a,int n,bool invflag)
{
    for(int i=0; i<n; ++i)
    {
        if(i<rev[i])
            swap(a[i],a[rev[i]]);
    }
    for(int l=2; l<=n; l<<=1)
    {
        int m=l>>1;
        cp wi=cp(cos(2*PI/l),sin(2*PI/l));
        if(invflag)
            wi=cp(cos(2*PI/l),-sin(2*PI/l));
        for(cp *p=a; p!=a+n; p+=l)
        {
            cp w=cp(1,0);
            for(int i=0; i<m; ++i)
            {
                cp t=w*p[i+m];
                p[i+m]=p[i]-t;
                p[i]=p[i]+t;
                w=w*wi;
            }
        }
    }
    if(invflag)
    {
        for(int i=0; i<n; ++i)
            a[i].x/=n,a[i].y/=n;
    }
}
int n,m;
cp a[MAXN],b[MAXN];
int main()
{
    n=read(),m=read();
    ++n,++m;
    for(int i=0; i<n; ++i)
        a[i]=cp((double)read(),0);
    for(int i=0; i<m; ++i)
        b[i]=cp((double)read(),0);
    int N=1,lim=0;
    while(N<n+m-1)
        N<<=1,++lim;
    init(N,lim);
    FFT(a,N,false);
    FFT(b,N,false);
    for(int i=0; i<N; ++i)
        a[i]=a[i]*b[i];
    FFT(a,N,true);
    for(int i=0; i<n+m-1; ++i)
        printf("%d ",(int)(a[i].x+0.5));
    return 0;
}

\(NTT\)

#include<bits/stdc++.h>
using namespace std;
#define ll long long
const int MAXN=4e6+10;
inline int read()
{
    int x=0;
    bool pos=1;
    char ch=getchar();
    for(; !isdigit(ch); ch=getchar())
        if(ch=='-')
            pos=0;
    for(; isdigit(ch); ch=getchar())
        x=x*10+ch-'0';
    return pos?x:-x;
}
const int P=998244353,G=3;
int add(int a,int b)
{
    return (a + b) % P;
}
int mul(int a,int b)
{
    return 1LL * a * b % P;
}
int fpow(int a,int b)
{
    int res=1;
    while(b)
    {
        if(b&1)
            res=mul(res,a);
        a=mul(a,a);
        b>>=1;
    }
    return res;
}
int inv(int x)
{
    return fpow(x,P-2);
}
int rev[MAXN];
void init(int n,int lim)
{
    for(int i=0; i<n; ++i)
    {
        for(int j=0; j<lim; ++j)
            if((i>>j)&1)
                rev[i]|=1<<(lim-j-1);
    }
}
void NTT(int *a,int n,bool invflag)
{
    for(int i=0; i<n; ++i)
    {
        if(i<rev[i])
            swap(a[i],a[rev[i]]);
    }
    for(int l=2; l<=n; l<<=1)
    {
        int gi=fpow(G,(P-1)/l);
        if(invflag)
            gi=inv(gi);
        int m=l>>1;
        for(int *p=a;p!=a+n;p+=l)
        {
            int g=1;
            for(int i=0; i<m; ++i)
            {
                int t=mul(g,p[i+m]);
                p[i+m]=add(p[i],P-t);
                p[i]=add(p[i],t);
                g=mul(g,gi);
            }
        }
    }
    if(invflag)
    {
        int Invn=inv(n);
        for(int i=0; i<n; ++i)
            a[i]=mul(a[i],Invn);
    }
}
int n,m;
int a[MAXN],b[MAXN];
int main()
{
    n=read(),m=read();
    ++n,++m;
    for(int i=0; i<n; ++i)
        a[i]=read();
    for(int i=0; i<m; ++i)
        b[i]=read();
    int N=1,lim=0;
    while(N<n+m-1)
        N<<=1,++lim;
    init(N,lim);
    NTT(a,N,false);
    NTT(b,N,false);
    for(int i=0; i<N; ++i)
        a[i]=mul(a[i],b[i]);
    NTT(a,N,true);
    for(int i=0; i<n+m-1; ++i)
        printf("%d ",a[i]);
    return 0;
}

\(FWT\)

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
inline int read()
{
    int out=0,fh=1;
    char jp=getchar();
    while ((jp>'9'||jp<'0')&&jp!='-')
        jp=getchar();
    if (jp=='-')
        fh=-1,jp=getchar();
    while (jp>='0'&&jp<='9')
        out=out*10+jp-'0',jp=getchar();
    return out*fh;
}
const int P=998244353,inv2=(P+1)>>1;
inline int add(int a,int b)
{
    return (a + b) % P;
}
inline int mul(int a,int b)
{
    return 1LL * a * b % P;
}
void FWT(int *a,int n,int op)
{
    for(int l=2;l<=n;l<<=1)
    {
        int m=l>>1;
        for(int *p=a;p!=a+n;p+=l)
            for(int i=0;i<m;++i)
            {
                int x0=p[i],x1=p[i+m];
                if(op==1)//or
                    p[i]=x0,p[i+m]=add(x0,x1);
                else if(op==2)//and
                    p[i]=add(x0,x1),p[i+m]=x1;
                else//xor
                    p[i]=add(x0,x1),p[i+m]=add(x0,P-x1);
            }
    }
}
void IFWT(int *a,int n,int op)
{
    for(int l=2;l<=n;l<<=1)
    {
        int m=l>>1;
        for(int *p=a;p!=a+n;p+=l)
            for(int i=0;i<m;++i)
            {
                int x0=p[i],x1=p[i+m];
                if(op==1)
                    p[i]=x0,p[i+m]=add(x1,P-x0);
                else if(op==2)
                    p[i]=add(x0,P-x1),p[i+m]=x1;
                else
                    p[i]=mul(add(x0,x1),inv2),p[i+m]=mul(add(x0,P-x1),inv2);
            }
    }
}
const int MAXN=(1<<17)+10;
int a[MAXN],b[MAXN],c[MAXN];
int main()
{
    int n=read();
    n=1<<n;
    for(int i=0;i<n;++i)
        a[i]=read();
    for(int i=0;i<n;++i)
        b[i]=read();
    for(int op=1;op<=3;++op)
    {
        FWT(a,n,op);
        FWT(b,n,op);
        for(int i=0;i<n;++i)
            c[i]=mul(a[i],b[i]);
        IFWT(a,n,op);
        IFWT(b,n,op);
        IFWT(c,n,op);
        for(int i=0;i<n;++i)
            printf("%d ",c[i]);
        puts("");
    }
    return 0;
}

三模數 \(NTT\)

  • 取三個具備 \(P=p\cdot 2^k+1\) 形式的模數,在每一個模意義下分別求解,最後用 \(CRT​\) 合併.這樣作要求最終答案不超過三個模數之積.

拆係數 \(FFT\)

  • 對任意模數 \(P\) ,記 \(m=\sqrt P\).那麼把每一個係數寫成 \(a\times m+b,a,b<m\) 的形式,而後乘法分配律分別作,這樣每一個係數都不會超過 \(P\) ,最後再合併.

分治 \(NTT/FFT\)

  • 如今要算這樣一個形式的卷積:
    \[ f_i=\sum_{j=1}^{i}f_{i-j}\cdot g_j \]c++

  • 給出初始值 \(f_0\) 以及 \(g\) 的全部值,求 \(f\) 的全部項.
  • 若是每次直接作 \(NTT/FFT\) ,複雜度顯然爲 \(O(n^2logn)\) ,爆炸.
  • 考慮一個像 \(cdq\) 分治的東西,即算左邊,再算左邊對右邊的貢獻,再算右邊,長度爲 \(1\) 時返回.
  • 若當前計算區間 \([l,r)\) ,就把 \(f\) 的前 \(\frac {r-l} 2\) 項與 \(g\) 的前 \(r-l\) 項作卷積,把答案的後 \(\frac {r-l} 2\) 項加入 \(f\) 的對應位置.
  • 這樣作後半段必定是對的,因此就不用關注長度的問題了.時間複雜度爲 \(O(nlog^2n)\) .git

    分治 \(NTT\)

#include<bits/stdc++.h>
using namespace std;
#define ll long long
inline ll read()
{
    ll out=0,fh=1;
    char jp=getchar();
    while ((jp>'9'||jp<'0')&&jp!='-')
        jp=getchar();
    if (jp=='-')
        fh=-1,jp=getchar();
    while (jp>='0'&&jp<='9')
        out=out*10+jp-'0',jp=getchar();
    return out*fh;
}
const int P=998244353,G=3;
inline int add(int a,int b)
{
    return (a + b) % P;
}
inline int mul(int a,int b)
{
    return 1LL * a * b % P;
}
int fpow(int a,int b)
{
    int res=1;
    while(b)
    {
        if(b&1)
            res=mul(res,a);
        a=mul(a,a);
        b>>=1;
    }
    return res;
}
const int MAXN=2e5+10;
int rev[MAXN],omega[MAXN],inv[MAXN];
void rev_init(int n,int lim)
{
    for(int i=0;i<n;++i)
    {
        int t=0;
        for(int j=0;j<lim;++j)
            if((i>>j)&1)
                t|=1<<(lim-j-1);
        rev[i]=t;
    }
}
void omega_init(int n)
{
    for(int l=2;l<=n;l<<=1)
    {
        omega[l]=fpow(G,(P-1)/l);
        inv[l]=fpow(omega[l],P-2);
    }
}
void NTT(int *a,int n,int lim,int invflag)
{
    for(int i=0;i<n;++i)
    {
        if(i<rev[i])
            swap(a[i],a[rev[i]]);
    }
    for(int l=2;l<=n;l<<=1)
    {
        int m=l>>1;
        int gi=omega[l];
        if(invflag)
            gi=inv[l];
        for(int *p=a;p!=a+n;p+=l)
        {
            int g=1;
            for(int i=0;i<m;++i)
            {
                int t=mul(g,p[i+m]);
                p[i+m]=add(p[i],P-t);
                p[i]=add(p[i],t);
                g=mul(g,gi);
            }
        }
    }
    if (invflag)
    {
        int invn=fpow(n,P-2);
        for(int i=0;i<n;++i)
            a[i]=mul(a[i],invn);
    }
}
int f[MAXN],g[MAXN],a[MAXN],b[MAXN];
void solve(int l,int r,int lim)//[l,r)
{
    if(lim<=0)
        return;
    int mid=(l+r)>>1,len=r-l;
    solve(l,mid,lim-1);
    rev_init(len,lim);
    for(int i=0;i<len/2;++i)
        a[i]=f[l+i];
    for(int i=len/2;i<len;++i)
        a[i]=0;
    for(int i=0;i<len;++i)
        b[i]=g[i];
    NTT(a,len,lim,false);
    NTT(b,len,lim,false);
    for(int i=0;i<len;++i)
        a[i]=mul(a[i],b[i]);
    NTT(a,len,lim,true);
    for(int i=len/2;i<len;++i)
        f[i+l]=add(f[i+l],a[i]);
    solve(mid,r,lim-1);
}
int main()
{
    int n=read();
    f[0]=1;
    for(int i=1;i<n;++i)
        g[i]=read();
    int N=1,lim=0;
    while(N<n)
        N<<=1,++lim;
    omega_init(N);
    solve(0,N,lim); 
    for(int i=0;i<n;++i)
        printf("%d ",f[i]);
    puts("");
    return 0;
}
相關文章
相關標籤/搜索