拆係數FFT

學習內容:國家集訓隊2016論文 - 再談快速傅里葉變換html

模板題:http://uoj.ac/problem/34c++

1.基本介紹

對長度爲L的\(A(x),B(x)\)進行DFT,能夠利用web

\[ \begin{align} P(x)=A(x)+iB(x) \tag{1} \\ Q(x)=A(x)-iB(x) \tag{2} \end{align} \]算法

\(P(x)\)進行DFT,獲得\(F_p\)學習

\(Q(x)\)的結果 DFT\(F_q[k]=!(F_p[2L-k])\),(!表示取共軛)(證實見論文)。優化

spa

\[ \begin{align} DFT(A[k])=\frac{F_p[k]+F_q[k]} 2 \tag{3} \\ DFT(B[k])=-i\frac{F_p[k]-F_q[k]} 2 \tag{4} \end{align} \]code

這就是兩兩合併計算DFT的方法,2次DFT優化爲了1次。htm

IDFT的計算有兩種方法,一種是帶入\(-w_n^k\),另外一種是將序列[1..n-1]翻轉,再進行FFT,兩種方法結果都要除以n。get

//495ms
#include <bits/stdc++.h>
#define rep(i,l,r) for(int i=l,ed=r;i<ed;++i)
typedef long long ll;
const double PI = acos(-1);
const int N = 1<<20;
const int BUF_SIZE=33554431;
using namespace std;

struct buf{
    char a[BUF_SIZE],b[BUF_SIZE],*s,*t;
    buf():s(a),t(b){a[fread(a,1,sizeof a,stdin)]=0;}
    ~buf(){fwrite(b,1,t-b,stdout);}
    operator int(){
        int x=0;
        while(*s<48)++s;
        while(*s>32)
            x=x*10+*s++-48;
        return x;
    }
    void out(int x){
        static char c[12];
        char*i=c;
        if(!x)*t++=48;
        else{
            while(x){
                int y=x/10;
                *i++=x-y*10+48,x=y;
            }
            while(i!=c)*t++=*--i;
        }
        *t++=10;
    }
}it;
struct cp{
    double x,y;
    cp(double _x=0,double _y=0):x(_x),y(_y){}
    cp operator +(const cp&amp; b)const{return cp(x+b.x,y+b.y);}
    cp operator -(const cp&amp; b)const{return cp(x-b.x,y-b.y);}
    cp operator *(const cp&amp; b)const{return cp(x*b.x-y*b.y,x*b.y+y*b.x);}
    cp operator !()const{return cp(x,-y);}
}w[N];
void fft(cp p[],int n){
    for(int i=0,j=0;i<n;++i){
        if(i>j)swap(p[i],p[j]);
        for(int l=n>>1;(j^=l)<l;l>>=1);
    }
    for(int i=2;i<=n;i<<=1)
    for(int j=0,m=i>>1;j<n;j+=i)
        rep(k,0,m){
            cp b=w[n/i*k]*p[j+m+k];
            p[j+m+k]=p[j+k]-b;
            p[j+k]=p[j+k]+b;
        }
}
void conv(int n,ll *x,ll *y,ll *z){
    static cp p[N],q[N],h(0,-0.25);
    rep(i,0,n){
        w[i]=cp(cos(2*PI*i/n),sin(2*PI*i/n));
        p[i]=cp(x[i],y[i]);
    }
    fft(p,n);
    rep(i,0,n){
        int j=i?(n-i):0;
        q[j]=(p[i]*p[i]-!p[j]*!p[j])*h;
    }
    fft(q,n);
    rep(i,0,n)z[i]=q[i].x/n+0.5;
}
int n,m,p;
ll a[N],b[N],c[N];
int main(){
    n=it+1;m=it+1;
    rep(i,0,n) a[i]=it;
    rep(i,0,m) b[i]=it;
    for(n+=m-1,p=1;p<n;p<<=1);
    conv(p,a,b,c);
    rep(i,0,n)it.out(c[i]);
    return 0;
}

2.更快的卷積

\(A(x)\)表示爲\(A_0(x^2)+xA_1(x^2)\)\(、A_0(x^2)、xA_1(x^2)\)分別是偶次項、奇次項的和。

那麼

\[ \begin{align} A(x)B(x)&=(A_0(x^2)+xA_1(x^2))(B_0(x^2)+xB_1(x^2))\\ &=A_0(x^2)B_0(x^2)+x(A_0(x^2)B_1(x^2)+A_1(x^2)B_0(x^2))+x^2A_1(x^2)B_1(x^2) \end{align} \]

能夠分別對\(A_0(x)、A_1(x)、B_0(x)、B_1(x)\)計算DFT,而後再把上式\(x^0,x^1,x^2\)的係數算出來,再進行3次IDFT。共7次。

DFT能夠兩兩合併優化爲2次,且是兩次長度爲L(原來是2L)的DFT。

IDFT時也能夠兩兩合併,因而就須要2次長度L的IDFT。共4次。

若是這兩次IDFT還能夠兩兩合併,那就只要計算一次IDFT。共3次長度L的計算。

推導以下:

\(A_0(x^2)B_1(x^2)+A_1(x^2)B_0(x^2)\)的 IDFT 結果就是奇數項的係數。\(A_0(x^2)B_0(x^2)\)\(x^2A_1(x^2)B_1(x^2)\) 則是偶數項的係數。

\(A_0(x^2)B_0(x^2)\)\(x^2A_1(x^2)B_1(x^2)\)看作是關於\(x^2\)的多項式,能夠兩兩合併計算。令

\[ g=DFT(A_0)\cdot DFT(B_0)+w[k]DFT(A_1)\cdot DFT(B_1)\\ f=DFT(A_0)\cdot DFT(B_1)+DFT(A_1)\cdot DFT(B_0) \]

\(xA(x)\)就是\(w_n^k\cdot DFT(A)\)。咱們只要計算出\(IDFT(g)\)\(IDFT(f)\)便可。

若是 IDFT 的結果是實數,那麼兩個 IDFT 就能夠合併計算,令

\[ P(x)=g+i\cdot f \]

那麼

\[ IDFT(P(x))=IDFT(f)+i \cdot IDFT(g) \]

因而取實部和虛部分別做爲奇數和偶數項的係數便可。

\[ j=\begin{cases} 0& k=0\\ n-k& k\neq 0 \end{cases} \]

那麼

\[ \begin{aligned} g&=\frac {P_k+!P_j}{2}\cdot \frac {Q_k+!Q_j}{2}+w[k]\cdot \frac {P_k-!P_j}{-2i}\cdot \frac {Q_k-!Q_j}{-2i}\\ &=\frac 1 4 [(P_k+!P_j)\cdot(Q_k+!Q_j)-w[k]\cdot(P_k-!P_j)\cdot(Q_k-!Q_j)]\\ \\ f&=\frac {P_k+!P_j} 2 \cdot \frac{Q_k-!Q_j}{-2}i+\frac {Q_k+!Q_j} 2 \cdot \frac{P_k-!P_j}{-2}i\\ &=\frac i{-4}[2\cdot P_k\cdot Q_k-2\cdot !P_j\cdot !Q_j] \end{aligned} \]

因而

\[ \begin{aligned} g+f\cdot i&=\frac 1 4 [(P_k+!P_j)\cdot(Q_k+!Q_j)-w[k]\cdot(P_k-!P_j)\cdot(Q_k-!Q_j)-2\cdot P_k\cdot Q_k+2 !(P_j\cdot Q_j)]\\ &=\frac 1 4 [-(P_k-!P_j)\cdot(Q_k-!Q_j)+2\cdot (P_k\cdot Q_k+!(P_j\cdot Q_j))\\ &-w[k]\cdot(P_k-!P_j)\cdot(Q_k-!Q_j)+2\cdot P_k\cdot Q_k-2\cdot !(P_j\cdot Q_j)]\\ &=Q_k\cdot P_k-\frac 1 4[(1+w[k])\cdot (P_k-!P_j)\cdot(Q_k-!Q_j)]\\ \end{aligned} \]

//325ms
#include <bits/stdc++.h>
#define rep(i,l,r) for(int i=l,ed=r;i<ed;++i)
typedef long long ll;
const double PI = acos(-1);
const int N = 1<<20;
const int BUF_SIZE=33554431;
using namespace std;

struct buf{
    char a[BUF_SIZE],b[BUF_SIZE],*s,*t;
    buf():s(a),t(b){a[fread(a,1,sizeof a,stdin)]=0;}
    ~buf(){fwrite(b,1,t-b,stdout);}
    operator int(){
        int x=0;
        while(*s<48)++s;
        while(*s>32)
            x=x*10+*s++-48;
        return x;
    }
    void out(int x){
        static char c[12];
        char*i=c;
        if(!x)*t++=48;
        else{
            while(x){
                int y=x/10;
                *i++=x-y*10+48,x=y;
            }
            while(i!=c)*t++=*--i;
        }
        *t++=10;
    }
}it;
struct cp{
    double x,y;
    cp(double _x=0,double _y=0):x(_x),y(_y){}
    cp operator +(const cp&amp; b)const{return cp(x+b.x,y+b.y);}
    cp operator -(const cp&amp; b)const{return cp(x-b.x,y-b.y);}
    cp operator *(const cp&amp; b)const{return cp(x*b.x-y*b.y,x*b.y+y*b.x);}
    cp operator *(double b)const{return cp(b*x,b*y);}
    cp operator !()const{return cp(x,-y);}
}w[N];
void fft(cp *p,int n){
    for(int i=0,j=0;i<n;++i){
        if(i>j)swap(p[i],p[j]);
        for(int l=n>>1;(j^=l)<l;l>>=1);
    }
    for(int i=2;i<=n;i<<=1)
    for(int j=0,m=i>>1;j<n;j+=i)
        rep(k,0,m){
            cp b=w[n/i*k]*p[j+m+k];
            p[j+m+k]=p[j+k]-b;
            p[j+k]=p[j+k]+b;
        }
}
void conv(int n,ll *x,ll *y,ll *z){
    static cp p[N],q[N],a[N];
    rep(i,0,n){
        (i&amp;1?p[i>>1].y:p[i>>1].x)=x[i];
        (i&amp;1?q[i>>1].y:q[i>>1].x)=y[i];
    }
    rep(i,0,n>>=1)w[i]=cp(cos(2*PI*i/n),sin(2*PI*i/n));
    fft(p,n);fft(q,n);
    rep(i,0,n){
        int j=i?n-i:0;
        a[j]=p[i]*q[i]-((cp(1,0)+w[i])*(p[i]-!p[j])*(q[i]-!q[j]))*0.25;
    }
    fft(a,n);
    rep(i,0,n)z[i<<1]=a[i].x/n+0.5,z[i<<1|1]=a[i].y/n+0.5;
}
int n,m,p;
ll a[N],b[N],c[N];
int main(){
    n=it+1;m=it+1;
    rep(i,0,n) a[i]=it;
    rep(i,0,m) b[i]=it;
    for(n+=m-1,p=2;p<n;p<<=1);
    conv(p,a,b,c);
    rep(i,0,n)it.out(c[i]);
    return 0;
}

3.拆係數FFT

要計算任意模數的卷積,咱們通常考慮NTT+中國剩餘定理CRT。NTT中須要模數是質數且表示爲\(p=c\cdot 2^k+1\)\(2^k\)要不小於n。

考慮直接算出卷積不取模,那麼每一個數不會超過\(M^2n\)。假設模數\(M\)\(10^9\)級別,n是\(10^5\)級別,那麼結果都是\(10^{23}\)級別,咱們能夠找三個都是\(10^9\)級別知足NTT要求的模數,利用中國剩餘定理就能獲得在\(10^{27}\)級別的模數意義下的結果,再對\(M\)取模便可。

可是這樣常數就要乘3了。效率過低。拆係數FFT就是替代NTT解決模任意數且很是高效的算法。

若是利用FFT計算,浮點數會有偏差,int128是一個方法,可是不是全部場合都能使用。因此須要拆係數。

\(M_0=\lceil \sqrt M\rceil\),設

\[ a_i=k[a_i]M_0+b[a_i]\\ b_i=k[b_i]M_0+b[b_i] \]

其中\(k[a_i],b[a_i]< M_0\)

假設\(K_a(x)\)是以\(k[a_i]\)爲係數的多項式,\(B_a(x)\)是以\(b[a_i]\)爲係數的多項式,\(K_b(x),B_b(x)\)同理,則:

\[ A(x)=K_a(x)M_0+B_a(x)\\ B(x)=K_b(x)M_0+B_b(x)\\ A(x)B(x)=K_a(x)K_b(x)M_0^2+(K_a(x)B_b(x)+K_b(x)B_a(x))M_0+B_a(x)B_b(x) \]

和上面「更快的卷積」同樣分析,兩兩合併能夠將7次DFT及IDFT計算優化爲4次:

\(M_0\)能夠取一個超過\(\sqrt M\)的2的冪次,比較方便計算。

\[ P(x)=K_a(x)+iB_a(x)\\ Q(x)=K_b(x)+iB_b(x) \]

可知

\[ DFT(K_a[k])=\frac {F_p[k]+!(F_p[(n-k)\%n])} 2\\ DFT(B_a[k])=-i\frac {F_p[k]-!(F_p[(n-k)\%n])} 2\\ DFT(K_b[k])=\frac {F_q[k]-!(F_q[(n-k)\%n])} 2\\ DFT(B_b[k])=-i\frac {F_q[k]-!(F_q[(n-k)\%n])} 2\\ \]

因而只要計算出P(x)的DFT:\(F_p(x)\)和Q(x)的DFT:\(F_q(x)\),就能求出\(K_a(x),B_a(x),K_b(x),B_b(x)\)的DFT。

接下來IDFT的兩兩合併,以\(K_a(x)K_b(x)\)\(K_a(x)B_b(x)\)爲例,令

\[ dfta[k]=DFT(K_a[k])\cdot DFT(K_b[k])\\ dftb[k]=DFT(K_a[k])\cdot DFT(B_b[k]) \]

咱們須要對\(dfta(x)\)\(dftb(x)\)進行IDFT。注意到這裏IDFT的結果必定是實數,那麼令

\[ p[k]=dfta[k]+i\cdot dftb[k] \]

那麼 \(IDFT(p)\) 的實部除以n就是\(K_a(x)K_b(x)\),虛部除以n就是\(K_a(x)B_b(x)\)

因爲\(、k[x]、b[x]\)都是不超過\(2^{15}\)的數,因而就不容易被卡精度了。計算出來的結果再取模M就是答案了。

//933ms
#include <bits/stdc++.h>
#define rep(i,l,r) for(int i=l,ed=r;i<ed;++i)
typedef long long ll;
const double PI = acos(-1);
const int N = 1<<20;
const ll mod = 1e9+7;
const int BUF_SIZE=33554431;
using namespace std;

struct buf{
    char a[BUF_SIZE],b[BUF_SIZE],*s,*t;
    buf():s(a),t(b){a[fread(a,1,sizeof a,stdin)]=0;}
    ~buf(){fwrite(b,1,t-b,stdout);}
    operator int(){
        int x=0;
        while(*s<48)++s;
        while(*s>32)
            x=x*10+*s++-48;
        return x;
    }
    void out(int x){
        static char c[12];
        char*i=c;
        if(!x)*t++=48;
        else{
            while(x){
                int y=x/10;
                *i++=x-y*10+48,x=y;
            }
            while(i!=c)*t++=*--i;
        }
        *t++=10;
    }
}it;
struct cp{
    double x,y;
    cp(double _x=0,double _y=0):x(_x),y(_y){}
    cp operator +(const cp&amp; b)const{return cp(x+b.x,y+b.y);}
    cp operator -(const cp&amp; b)const{return cp(x-b.x,y-b.y);}
    cp operator *(const cp&amp; b)const{return cp(x*b.x-y*b.y,x*b.y+y*b.x);}
    cp operator !()const{return cp(x,-y);}
}w[N];
void fft(cp p[],int n){
    for(int i=0,j=0;i<n;++i){
        if(i>j)swap(p[i],p[j]);
        for(int l=n>>1;(j^=l)<l;l>>=1);
    }
    for(int i=2;i<=n;i<<=1)
    for(int j=0,m=i>>1;j<n;j+=i)
        rep(k,0,m){
            cp b=w[n/i*k]*p[j+m+k];
            p[j+m+k]=p[j+k]-b;
            p[j+k]=p[j+k]+b;
        }
}
void conv(int n,ll *x,ll *y,ll *z){
    static cp p[N],q[N],a[N],b[N],c[N],d[N];
    static cp r(0.5,0),h(0,-0.5),o(0,1);
    rep(i,0,n){
        w[i]=cp(cos(2*PI*i/n),sin(2*PI*i/n));
        x[i]=(x[i]+mod)%mod,y[i]=(y[i]+mod)%mod;
        p[i]=cp(x[i]>>15,x[i]&amp;32767),q[i]=cp(y[i]>>15,y[i]&amp;32767);
    }
    fft(p,n);fft(q,n);
    rep(i,0,n){
        int j=i?(n-i):0;
        static cp ka,ba,kb,bb;
        ka=(p[i]+!p[j])*r;
        ba=(p[i]-!p[j])*h;
        kb=(q[i]+!q[j])*r;
        bb=(q[i]-!q[j])*h;
        a[j]=ka*kb;b[j]=ka*bb;
        c[j]=kb*ba;d[j]=ba*bb;
    }
    rep(i,0,n){
        p[i]=a[i]+b[i]*o;
        q[i]=c[i]+d[i]*o;
    }
    fft(p,n);fft(q,n);
    rep(i,0,n){
        ll a,b,c,d;
        a=(ll)(p[i].x/n+0.5)%mod;
        b=(ll)(p[i].y/n+0.5)%mod;
        c=(ll)(q[i].x/n+0.5)%mod;
        d=(ll)(q[i].y/n+0.5)%mod;
        z[i]=((a<<30)+((b+c)<<15)+d)%mod;
    }
}
int n,m,p;
ll a[N],b[N],c[N];
int main(){
    n=it+1;m=it+1;
    rep(i,0,n) a[i]=it;
    rep(i,0,m) b[i]=it;
    for(n+=m-1,p=1;p<n;p<<=1);
    conv(p,a,b,c);
    rep(i,0,n)it.out((c[i]+mod)%mod);
    return 0;
}

題目:

待補充

相關文章
相關標籤/搜索