FFT什麼的

  這裏只有公式&作法,沒有複雜的證實(實際上是由於弱雞yww不會)數組

  參考自國家集訓隊論文&各個博客ide

多項式

​  一個以\(x\)爲變量的多項式定義在一個代數域\(F\)上,將函數\(A(x)\)表示爲形式和:
\[ A(x)=\sum_{j=0}^{n-1}a_jx^j \]
咱們稱\(a_0,a_1,\ldots,a_{n-1}\)爲多項式的係數,全部係數都屬於數域\(F\),典型的情形是負數集合\(C\)函數

  若是一個多項式的最高次的非零係數是\(a_k\),則稱\(A(x)\)的次數是\(k\)。任何嚴格大於一個多項式次數的整數都是該多項式的次數界。所以,對於次數界爲\(n\)的多項式\(C(x)\),其次數能夠是\(0\)~\(n-1\)之間的任何整數,包括\(0\)\(n-1\)優化

​  咱們在多項式上能夠定義不少不一樣的運算。ui

多項式加法

​  若是\(A(x)\)\(B(x)\)是次數界爲\(n\)的多項式,那麼他們的和也是一個次數界爲\(n\)的多項式\(C(x)\)。對於全部屬於定義域的\(x\),都有\(C(x)=A(x)+B(x)\)。也就是說,若
\[ A(x)=\sum_{j=0}^{n-1}a_jx^j\\ B(x)=\sum_{j=0}^{n-1}b_jx^j \]

\[ C(x)=\sum_{j=0}^{n-1}c_jx^j\\ \]
其中
\[ c_j=a_j+b_j \]
​  例如,若是
\[ A(x)=6x^3+7x^2-10x+9,B(x)=-2x^3+4x-5 \]

\[ C(x)=4x^3+7x^2-6x+4 \]spa

多項式乘法

​  若是\(A(x)\)是次數界爲\(n\)的多項式,\(B(x)\)是次數界爲\(m\)的多項式,那麼他們的乘積是一個次數界爲\(n+m\)的多項式\(C(x)\)。其中
\[ c_j=\sum_{k=0}^ja_kb_{j-k} \]
​  例如,若是
\[ A(x)=6x^3+7x^2-10x+9,B(x)=-2x^3+4x-5 \]
​  則
\[ C(x)=-12x^6-14x^5+44x^4-20x^3-75x^2+86x-45 \].net

多項式的表示

係數表達

​  對一個次數界爲\(n\)的多項式\(A(x)=\sum_{j=0}^{n-1}a_jx^j\)而言,其係數表達式一個由係數組成獲得向量\(a=(a_0,a_1,\cdots,a_{n-1})\)code

​  咱們能夠用秦久韶算法在\(O(n)\)的時間內求出多項式在給定點\(x_0\)的值,即求值運算:
\[ A(x_0)=a_0+x_0(a_1+a_0(a_2+\cdots+x_0(a_{n-1}+x_0(a_{n-1})\cdots)) \]
​  相似的,對於兩個分別用係數向量\(a=(a_0,a_1,\cdots,a_{n-1}),b=(b_0,b_1,\cdots,b_{n-1})\)表示的多項式進行相加時,所需的時間是\(O(n)\)。咱們只用輸出係數向量\(c=(c_0,c_1,\cdots,c_{n-1})\),其中\(c_i=a_i+b_i\)blog

​  如今來考慮兩個用係數形式表達的次數界爲\(n\)的多項式\(A(x),B(x)\)的乘法運算,所須要的時間是\(O(n^2)\)。係數向量\(c\)也稱爲輸入向量\(a,b\)的卷積。\(c=a\otimes b\)

點值表達

​  一個次數界爲\(n\)的多項式的點值表達就是一個有\(n\)個點值對所組成的集合。
\[ \{(x_0,y_0),(x_1,y_1),\cdots,(x_{n-1},y_{n-1})\} \]
使得對\(k=0,1,\cdots,n-1\),全部\(x_k\)各不相同且\(y_k=A(x_k)\)

​  一個多項式能夠有不少不一樣的點值表達,由於能夠採用\(n\)個不一樣的點構成的集合做爲這種表示方法的基。

​  樸素的求值是\(O(n^2)\)的。

​  求值的逆稱爲插值。當插值多項式的次數界等於已知的點值對的數目時,插值纔是明確的。

​  咱們能夠在用高斯消元在\(O(n^3)\)內插值,也能夠用拉格朗日插值\(O(n^2)\)內插值。

​  以上求值和插值能夠將多項式的係數表達和點值表達進行相互轉化,上面給出的算法的時間複雜度是\(O(n^2)\),但咱們能夠巧妙地選取\(x_k\)來加速這一過程,使其運行時間變爲\(O(nlogn)\)

​  對於許多多項式相關的操做,點值表達式很便利的。

​  對於加法,若是\(C(x)=A(x)+B(x)\)。給定\(A\)的點值表達
\[ \{(x_0,y_0),(x_1,y_1),\cdots,(x_{n-1},y_{n-1})\} \]
\(B\)的點值表達
\[ \{(x_0,y'_0),(x_1,y'_1),\cdots,(x_{n-1},y'_{n-1})\} \]
(注意,\(A\)\(B\)在相同的\(n\)個位置求值),則\(C\)的點值表達是
\[ \{(x_0,y_0+y'_0),(x_1,y_1+y'_1),\cdots,(x_{n-1},y_{n-1}+y'_{n-1})\} \]
所以,對兩個點值形式表示的次數界爲\(n\)的多項式相加,時間複雜度是\(O(n)\)

​  相似的,若是\(C(x)=A(x)B(x)\),咱們須要\(2n\)個點值對才能插出\(C\)。給定\(A\)的點值表達
\[ \{(x_0,y_0),(x_1,y_1),\cdots,(x_{2n-1},y_{2n-1})\} \]
\(B\)的點值表達
\[ \{(x_0,y'_0),(x_1,y'_1),\cdots,(x_{2n-1},y'_{2n-1})\} \]
(注意,\(A\)\(B\)在相同的\(2n\)個位置求值),則\(C\)的點值表達是
\[ \{(x_0,y_0y'_0),(x_1,y_1y'_1),\cdots,(x_{2n-1},y_{2n-1}y'_{2n-1})\} \]
所以,對兩個點值形式表示的次數界爲\(n\)的多項式相乘,時間複雜度是\(O(n)\)

​  最後,咱們考慮一個採用點值表達的多項式,如何求其在某個新點上的值。最簡單的方法是把該多項式轉成係數形式表達,而後在新點處求值。

係數形式表示的多項式的快速乘法

​  若是咱們選\(n\)次單位複數根做爲求值點,咱們能夠在\(O(nlogn)\)內求值和插值。咱們先在對這兩個多項式\(A,B\)求值以前添加\(n\)\(0\),使其次數界加倍爲\(2n\)。如今咱們採用「\(2n\)次單位複數根」做爲求值點。

DFT&FFT&IDFT

單位複數根

​  \(n\)次單位複數根是知足\(w^n=1\)的複數\(w\)\(n\)次單位複數根剛好有\(n\)個,對於\(k=0,1,\cdots,n-1\),這些根是\(e^{\frac{2\pi ik}{n}}\)\(w_n=e^\frac{2\pi i}{n}\)稱爲主\(n\)次單位根,全部其餘\(n\)次單位複數根都是\(w_n\)的冪次。這\(n\)\(n\)次單位複數根在乘法意義下造成了一個羣,即\(w_n^jw_n^k=w_n^{(j+k)mod~n}\),並且這\(n\)\(n\)次單位複數根均勻分佈在以複平面的原點爲圓心的單位半徑的圓周上。(圖片from zjt)

  

​  消去引理:對任何整數\(n\geq 0,k\geq 0,d>0\)
\[ w_{dn}^{dk}=w_n^k \]

DFT

​  回顧一下,咱們但願計算次數界爲\(n\)的多項式\(A(x)\)\(w_n^0,w_n^1,\cdots,w_n^{n-1}\)處的值(即在\(n\)\(n\)次單位複數根處)。對於\(k=0,1,\cdots,n-1\),定義結果\(y_k\)
\[ y_k=A(w_n^k)=\sum_{j=0}^{n-1}a_jw_n^{kj} \]
向量\(y=(y_0,y_1,\cdots,y_{n-1})\)就是係數向量\(a\)的離散傅里葉變換(DFT),咱們也記爲\(y=DFT_n(a)\)

FFT

​  利用單位複數根的特殊性質,咱們能夠在\(O(nlogn)\)內計算出\(DFT_n(a)\)。這裏假設\(n\)\(2\)的冪。

  FFT利用了分治策略。

  咱們令\(a=(a_0,a_1,\cdots,a_{n-1}),a_1=(a_0,a_2,\cdots,a_{n-2}),a_2=(a_1,a_3,\cdots,a_{n-1})\)

  對於\(k<\frac n2\)有:
\[ \begin{align} y_k&=A(w_n^k)\\ &=\sum_{j=0}^{n-1}a_jw_n^{kj}\\ &=\sum_{j=0}^{\frac n2-1}a_{2j}w_n^{2kj}+\sum_{j=0}^{\frac n2-1}a_{2j+1}w_n^{2kj+k}\\ &=\sum_{j=0}^{\frac n2-1}a_{2j}w_n^{2kj}+w_n^k\sum_{j=0}^{\frac n2-1}a_{2j+1}w_n^{2kj}\\ &=\sum_{j=0}^{\frac n2-1}{a_1}_{j}w_{\frac n2}^{kj}+w_n^k\sum_{j=0}^{\frac n2-1}{a_2}_{j}w_{\frac n2}^{kj}\\ &={y_1}_k+w_n^k{y_2}_k \end{align} \]
  對於\(k\geq \frac n2\)有:
\[ \begin{align} y_k&=A(w_n^k)\\ &=\sum_{j=0}^{n-1}a_jw_n^{kj}\\ &=\sum_{j=0}^{\frac n2-1}a_{2j}w_n^{2kj}+\sum_{j=0}^{\frac n2-1}a_{2j+1}w_n^{2kj+k}\\ &=\sum_{j=0}^{\frac n2-1}a_{2j}w_n^{2kj}+w_n^k\sum_{j=0}^{\frac n2-1}a_{2j+1}w_n^{2kj}\\ &=\sum_{j=0}^{\frac n2-1}{a_1}_{j}w_{\frac n2}^{kj}+w_n^k\sum_{j=0}^{\frac n2-1}{a_2}_{j}w_{\frac n2}^{kj}\\ &=\sum_{j=0}^{\frac n2-1}{a_1}_{j}w_{\frac n2}^{(k-\frac n2)j}+w_n^k\sum_{j=0}^{\frac n2-1}{a_2}_{j}w_{\frac n2}^{(k-\frac n2)j}\\ &={y_1}_{k-\frac n2}+w_n^k{y_2}_{k-\frac n2}\\ &={y_1}_{k-\frac n2}-w_n^{k-\frac n2}{y_2}_{k-\frac n2} \end{align} \]
  這樣咱們把\(y_1,y_2\)合併爲\(y\)的時間複雜度是\(O(n)\)。因此總的時間複雜度是
\[ T(n)=2T(\frac n2)+O(n)=O(n\log n) \]

IDFT

​  經過推導公式,咱們獲得:
\[ a_k=\frac1n\sum_{j=0}^{n-1}y_jw_n^{-kj} \]
​  因此咱們能夠用相似FFT的方法在\(O(n\log n)\)內求出\(IDFT_n(y)\)

多項式乘法

​  咱們能夠在\(O(n)\)內補\(0\)\(O(n\log n)\)內求值,\(O(n)\)內點值乘法,\(O(n\log n)\)內插值。因此咱們能夠在\(O(n\log n)\)內求出\(a\otimes b\)
\[ a\otimes b=IDFT_{2n}(DFT_{2n}(a)\cdot DFT_{2n}(b)) \]

蝶形運算

  咱們把由\({y_1}_k,{y_2}_k,w_n^k\)獲得\(y_k,y_{k+\frac n2}\)的過程稱爲蝴蝶操做。

​  咱們發現,遞歸時\(a\)是長這樣的:
\[ 0~~~1~~~2~~~3~~~4~~~5~~~6~~~7\\ 0~~~2~~~4~~~6~|~1~~~3~~~5~~~7\\ 0~~~4~|~2~~~6~|~1~~~5~|~3~~~7\\ 0~|~4~|~2~|~6~|~1~|~5~|~3~|~7 \]
  總的蝶形運算是長這樣的:
  
  

​  能夠發現,最後\(a_i\)是原來的\(a_{rev(i)}\)。因此咱們能夠交換\(a_i,a_{rev(i)}\),而後一層層來作。這樣能夠減少常數。

NTT

​  在某些時候,咱們須要求模\(p\)意義下的卷積。

​  先求出\(p\)的原根\(g\),能夠發現,\(g^{\frac{p-1}{n}}\)\(w_n\)的性質相似。因此咱們能夠用\(g^{\frac{p-1}{n}}\)來代替\(w_n\)

時間上的優化

  當咱們要算兩個多項式 \(A(x), B(x)\) 的乘積的時候,普通的作法是先把 \(a,b\) 兩個序列 DFT,再點乘,再 IDFT 回去。

  可是咱們還有一種方法:

​  令\(t_j=(a_j+b_j)+(a_j-b_j)i,S=T\times T\)

​  \(s_j\)的實部爲
\[ \begin{align} \sum_{k=0}^j(a_k+b_k)(a_{j-k}+b_{j-k})-(a_k-b_k)(a_{j-k}-b_{j-k})&=\sum_{k=0}^j4a_kb_{j-k}=4\sum_{k=0}^ja_kb_{j-k} \end{align} \]
  這樣咱們就能夠求出\(S=T\times T\),而後把\(s_j\)除以\(4\)

  這個方法能夠把\(3\)次DFT改爲\(2\)次DFT。

多項式求導

  給定\(A(x)=\sum_{i\geq 0}a_ix^i\),定義\(A(x)\)的形式導數爲
\[ A'(x)=\sum_{i\geq 1}ia_ix^{i-1} \]

多項式積分

  給定\(A(x)=\sum_{i\geq 0}a_ix^i\),則
\[ \int A(x)=\sum_{i\geq 1}\frac{a_{i-1}}{i}x^i \]

多項式求逆

​  多項式\(A(x)\)存在乘法逆元的充要條件是\(A(x)\)的常數項存在乘法逆元。

​  下面介紹一個\(O(n~log~n)\)計算乘法逆元的算法,它的本質是牛頓迭代法

​  首先求出\(A(x)\)常數項的逆元\(b\),令\(B(x)\)的初始值爲\(b\)

​  假設已求出知足
\[ A(x)B(x)\equiv1~(mod~x^n) \]
\(B(x)\),則
\[ \begin{align} A(x)B(x)-1&\equiv0~(mod~x^n)\\ {(A(x)B(x)-1)}^2&\equiv 0~(mod~x^{2n})\\ A(x)(2B(x)-B(x)^2A(x))&\equiv 1~(mod~x^{2n}) \end{align} \]
​  咱們能夠用\(O(n~log~n)\)的時間計算出\(2B(x)-B(x)^2A(x)\),並將它賦值給\(B(x)\)進行下一次迭代。每迭代一次,\(B(x)\)的有效項數\(n\)都會增長一倍。因而該算法的時間複雜度爲
\[ T(n)=T(n/2)+O(n\log n)=O(n\log n) \]

多項式開根

  已知\(A(x)\),求\(B(x)\)使得
\[ B(x)^2\equiv A(x)~(mod~x^n) \]

  先求出\(A(x)\)常數項的平方根\(b\)(能夠用二次剩餘的東西來算,但我只會暴力算),令\(B(x)\)的初始值爲\(b\)

  假設已求出知足
\[ B(x)^2\equiv A(x)~(mod~x^n) \]
\(B(x)\),則
\[ \begin{align} B(x)^2-A(x)&\equiv 0~(mod~x^n)\\ {(B(x)^2-A(x))}^2&\equiv 0~(mod~x^{2n})\\ B(x)^4-2B(x)^2A(x)+A(x)^2&\equiv 0~(mod~x^{2n})\\ B(x)^4+2B(x)^2A(x)+A(x)^2&\equiv 4B(x)^2A(x)~(mod~x^{2n})\\ {(B(x)^2+A(x))}^2&\equiv {(2B(x))}^2A(x)~(mod~x^{2n})\\ {(\frac{B(x)^2+A(x)}{2B(x)})}^2&\equiv A(x)~(mod~x^{2n}) \end{align} \]
  咱們能夠在\(O(n\log n)\)內算出\(\frac{B(x)^2+A(x)}{2B(x)}=\frac{B(x)}{2}+\frac{A(x)}{2B(x)}\),並把它賦值給\(B(x)\)

  時間複雜度:\(O(n\log n)\)

多項式ln

  給定形式冪級數\(A(x)=\sum_{i\geq 1}a_ix^i\),定義
\[ \ln(1-A(x))=-\sum_{i\geq 1}\frac{{A(x)}^i}{i} \]
  給定多項式\(A(x)=1+\sum_{i\geq 1}a_ix^i\),令
\[ B(x)=\ln(A(x)) \]

\[ B'(x)=\frac{A'(x)}{A(x)} \]
  只須要求出\(A(x)\)的乘法逆元,就能夠求出\(\ln(A(x))\)

多項式exp

  給定形式冪級數\(A(x)=\sum_{i\geq 1}a_ix^i\),定義
\[ \exp(A(x))=\sum_{i\geq 0}\frac{{A(x)}^i}{i!} \]
  令\(f(x)=e^{A(x)}\),可獲得一個關於\(f(x)\)的方程
\[ g(f(x))=\ln(f(x))-A(x)=0 \]
  考慮用牛頓迭代解這一方程。首先\(f(x)\)的常數項是容易肯定的(就是\(1\))。

  設以求得\(f(x)\)的前\(n\)\(f_0(x)\),即
\[ f(x)\equiv f_0(x)~~~(mod~~~x^n) \]
  做泰勒展開得
\[ \begin{align} 0&=g(f(x))\\ &=g(f_0(x))+g'(f_0(x))(f(x)-f_0(x))~~~~~(mod~~~x^{2n}) \end{align} \]

\[ f(x)\equiv f_0(x)-\frac{g(f_0(x))}{g'(f_0(x))}~~~~(mod~~~x^{2n}) \]
  把上面那個式子帶入得
\[ \begin{align} f(x)&=f_0(x)-\frac{\ln(f_0(x))-A(x)}{\frac{1}{f_0(x)}}\\ &=f_0(x)(1-\ln(f_0(x))+A(x)) \end{align} \]
  時間複雜度:\(O(n\log n)\)
  

多項式求冪

  給你\(A(x),k\),求\(A^k(x)\)

  設\(A(x)\)中最低次數項是\(cx^d\),那麼先把整個多項式除以\(cx^d\),再求\(\ln\),把整個多項式乘以\(k\),再求\(\exp\),再乘上\(c^kx^{kd}\)
\[ A^k(x)=\exp(k\ln\frac{A(x)}{cx^d}))c^kx^{kd} \]
  時間複雜度:\(O(n\log n)\)

多項式除法

​  給你\(A(x),B(x)\),求兩個多項式\(D(x),R(x)\)知足
\[ A(x)=D(x)B(x)+R(x) \]
​  若\(A(x)\)是一個\(n\)階多項式,則
\[ A^R(x)=x^nA(\frac1x) \]
  舉個例子:好比說
\[ A(x)=x^3+2x^2+3x+4\\ A^R(x)=1+2x+3x^2+4x^3 \]
​  至關於把\(A(x)\)的係數反轉。

  咱們設\(A(x)\)\(n\)階多項式,\(B(x)\)\(m\)階多項式,\(D(x)\)\(n-m\)階多項式,\(R(x)\)\(m-1\)階多項式。咱們把上個式子的\(x\)\(\frac1x\),而後所有乘上\(x^n\)
\[ x^nA(\frac1x)=x^{n-m}D(\frac1x)x^mB(\frac1x)+x^{n-m+1}x^{m-1}R(\frac1x)\\ A^R(x)=D^R(x)B^R(x)+x^{n-m+1}R^R(x) \]
  而後咱們把這個式子放在模\(x^{n-m+1}\)意義下,獲得
\[ A^R(x)=D^R(x)B^R(x)~(mod~x^{n-m+1})\\ D^R(x)=A^R(x){(B^R(x))}^{-1}~(mod~x^{n-m+1}) \]
  由於\(D(x)\)的次數是\(n-m\),因此不會受模意義的影響。

  而後把\(D(x)\)帶入到原來的式子中,就能夠算出\(R(x)\)了。

  時間複雜度:\(O(n\log n)\)

多點求值

  給你一個多項式\(A(x)\)\(n\)個點\(x_0,x_1,\cdots,x_{n-1}\),求這個多項式在這\(n\)個點處的值,即求\(A(x_0),A(x_1),\cdots,A(x_{n-1})\)

  考慮一個簡單的作法:構造\(B_i(x)=x-x_i,C_i(x)=A(x)~mod~B_i(x)\),那麼\(B_i(x_i)=0\)。因此\(A(x_i)=C_i(x_i)\)。可是計算\(B_i(x)\)\(C_i(x)\)\(O(n)\)的,必須加速這個過程。

  設當前求值的點爲\(X=\{x_0,x_1,\cdots,x_{n-1}\}\),咱們能夠把這\(n\)個點分爲兩半:
\[ X_0=\{x_0,x_1,\cdots,x_{\frac n2-1}\}\\ X_1=\{x_{\frac n2},x_{\frac n2+1},\cdots,x_{n-1}\} \]
  構造多項式
\[ B_0=\prod_{i=0}^{\frac n2-1}(x-x_i)\\ B_1=\prod_{i=\frac n2}^{n-1}(x-x_i)\\ A_0=A~mod~B_0\\ A_1=A~mod~B_1 \]
  那麼當\(x\in X_0\)\(A(x)=A_0(x)\),能夠遞歸計算。當\(x\in X_1\)時同理。

  每一層計算\(B_0,B_1,A_0,A_1\)的時間複雜度都是\(O(n\log n)\)

  總的時間複雜度就是
\[ T(n)=2T(\frac n2)+O(n\log n)=O(n\log^2n) \]

快速插值

  考慮怎麼求\(g_i=\prod_{j=0,j\neq i}^n (x_i-x_j)\),也就是分母。

\[ \begin{align} g_i&=\prod_{j=0,j\neq i}^n (x_i-x_j)\\ &=\lim_{x \to x_i}\frac{\prod_{j=0}^n (x-x_j)}{x-x_i}\\ &=(\prod_{j=0}^n (x-x_j))'|_{x=x_i} \end{align} \]

  能夠分治求出\(\prod_{j=0}^n (x-x_j)\)再求導後在全部\(x_i\)處多點求值。

  分子直接分治求出。

  時間複雜度:\(O(n\log^2n)\)

小技巧1

  好比咱們要計算兩個實數序列的卷積\(A\times B=C\),記\(D_i=(a_i+b_i)+(a_i-b_i)i\),那麼\(C_i=\frac{1}{4}real({D^2}_i)\)
  
  這樣就能夠把三次DFT減小到兩次DFT。
  
  固然,若是\(A=B\)那麼這個優化是沒有效果的。

任意模數FFT

模板

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cstdlib>
#include<ctime>
#include<utility>
#include<cmath>
#include<functional>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
void sort(int &a,int &b)
{
    if(a>b)
        swap(a,b);
}
void open(const char *s)
{
#ifndef ONLINE_JUDGE
    char str[100];
    sprintf(str,"%s.in",s);
    freopen(str,"r",stdin);
    sprintf(str,"%s.out",s);
    freopen(str,"w",stdout);
#endif
}
int rd()
{
    int s=0,c;
    while((c=getchar())<'0'||c>'9');
    do
    {
        s=s*10+c-'0';
    }
    while((c=getchar())>='0'&&c<='9');
    return s;
}
int upmin(int &a,int b)
{
    if(b<a)
    {
        a=b;
        return 1;
    }
    return 0;
}
int upmax(int &a,int b)
{
    if(b>a)
    {
        a=b;
        return 1;
    }
    return 0;
}
const ll p=998244353;
const ll g=3;
ll fp(ll a,ll b)
{
    ll s=1;
    while(b)
    {
        if(b&1)
            s=s*a%p;
        a=a*a%p;
        b>>=1;
    }
    return s;
}
const int maxn=600000;
ll inv[maxn];
namespace ntt
{
    ll w1[maxn];
    ll w2[maxn];
    int rev[maxn];
    int n;
    void init(int m)
    {
        n=1;
        while(n<m)
            n<<=1;
        int i;
        for(i=2;i<=n;i<<=1)
        {
            w1[i]=fp(g,(p-1)/i);
            w2[i]=fp(w1[i],p-2);
        }
        rev[0]=0;
        for(i=1;i<n;i++)
            rev[i]=(rev[i>>1]>>1)|((i&1)*(n>>1));
    }
    void ntt(ll *a,int t)
    {
        int i,j,k;
        ll u,v,w,wn;
        for(i=0;i<n;i++)
            if(rev[i]<i)
                swap(a[i],a[rev[i]]);
        for(i=2;i<=n;i<<=1)
        {
            wn=(t==1?w1[i]:w2[i]);
            for(j=0;j<n;j+=i)
            {
                w=1;
                for(k=j;k<j+i/2;k++)
                {
                    u=a[k];
                    v=a[k+i/2]*w%p;
                    a[k]=(u+v)%p;
                    a[k+i/2]=(u-v)%p;
                    w=w*wn%p;
                }
            }
        }
        if(t==-1)
        {
            u=fp(n,p-2);    
            for(i=0;i<n;i++)
                a[i]=a[i]*u%p;
        }
    }
    ll x[maxn];
    ll y[maxn];
    ll z[maxn];
    void copy_clear(ll *a,ll *b,int m)
    {
        int i;
        for(i=0;i<m;i++)
            a[i]=b[i];
        for(i=m;i<n;i++)
            a[i]=0;
    }
    void copy(ll *a,ll *b,int m)
    {
        int i;
        for(i=0;i<m;i++)
            a[i]=b[i];
    }
    void mul(ll *a,ll *b,ll *c,int m)
    {
        init(m<<1);
        copy_clear(x,a,m);
        copy_clear(y,b,m);
        ntt(x,1);
        ntt(y,1);
        int i;
        for(i=0;i<n;i++)
            x[i]=x[i]*y[i]%p;
        ntt(x,-1);
        copy(c,x,m);
    }
    void inverse(ll *a,ll *b,int m)
    {
        if(m==1)
        {
            b[0]=fp(a[0],p-2);
            return;
        }
        inverse(a,b,m>>1);
        init(m<<1);
        copy_clear(x,a,m);
        copy_clear(y,b,m>>1);
        ntt(x,1);
        ntt(y,1);
        int i;
        for(i=0;i<n;i++)
            x[i]=y[i]*(2-x[i]*y[i]%p)%p;
        ntt(x,-1);
        copy(b,x,m);
    }
    ll c[maxn],d[maxn],e[maxn],f[maxn];
    void sqrt(ll *a,ll *b,int m)
    {
        if(m==1)
        {
            if(a[0]==1)
                b[0]=1;
            else if(a[0]==0)
                b[0]=0;
            else
                //我也不會
                ;
            return;
        }
        sqrt(a,b,m>>1);
//      copy_clear(c,b,m>>1);
        int i;
        for(i=m;i<m<<1;i++)
            b[i]=0;
        inverse(b,d,m);
        init(m<<1);
        for(i=m;i<m<<1;i++)
            b[i]=d[i]=0;
        ll inv2=fp(2,p-2);
        copy_clear(x,a,m);
        ntt(x,1);
        ntt(d,1);
        for(i=0;i<n;i++)
            x[i]=x[i]*d[i]%p;
        ntt(x,-1);
        for(i=0;i<m;i++)
            b[i]=((b[i]+x[i])%p*inv2)%p;
    }
    void derivative(ll *a,ll *b,int m)
    {
        int i;
        for(i=0;i<m-1;i++)
            b[i]=(i+1)*a[i+1]%p;
        b[m-1]=0;
    }
    void differential(ll *a,ll *b,int m)
    {
        int i;
        for(i=m-1;i>=1;i--)
            b[i]=a[i-1]*inv[i]%p;
        b[0]=0;
    }
    void ln(ll *a,ll *b,int m)
    {
        static ll c[maxn],d[maxn];
        derivative(a,c,m);
        inverse(a,d,m);
        init(m<<1);
        int i;
        for(i=m;i<n;i++)
            c[i]=d[i]=0;
        ntt(c,1);
        ntt(d,1);
        for(i=0;i<n;i++)
            c[i]=c[i]*d[i]%p;
        ntt(c,-1);
        differential(c,b,m);
    }
    void exp(ll *a,ll *b,int m)
    {
        if(m==1)
        {
            b[0]=1;
            return;
        }
        exp(a,b,m>>1);
        int i;
        for(i=m>>1;i<m;i++)
            b[i]=0;
        ln(b,y,m);
        init(m<<1);
        copy_clear(x,a,m);
        x[0]++;
        for(i=0;i<m;i++)
            x[i]=(x[i]-y[i])%p;
        copy_clear(y,b,m);
        ntt(x,1);
        ntt(y,1);
        for(i=0;i<n;i++)
            x[i]=x[i]*y[i]%p;
        ntt(x,-1);
        copy(b,x,m);
    }
    void module(ll *a,ll *b,ll *c,int n1,int n2)
    {
        int k=1;
        while(k<=n1-n2+1)
            k<<=1;
        int i;
        for(i=0;i<=n1;i++)
            d[i]=a[i];
        for(i=0;i<=n2;i++)
            e[i]=b[i];
        reverse(d,d+n1+1);
        reverse(e,e+n2+1);
        for(i=n1-n2+1;i<k<<1;i++)
            d[i]=e[i]=0;
        inverse(e,f,k);
        for(i=n1-n2+1;i<k<<1;i++)
            f[i]=0;
        init(k<<1);
        ntt::ntt(d,1);
        ntt::ntt(f,1);
        for(i=0;i<n;i++)
            e[i]=d[i]*f[i]%p;
        ntt::ntt(e,-1);
        for(i=0;i<=n1-n2;i++)
            c[i]=e[i];
        reverse(c,c+n1-n2+1);
    }
};
ll b[maxn];
ll a[maxn];
ll c[maxn];
void get(ll *a,int n)
{
    int i;
    for(i=0;i<n;i++)
        a[i]=rand();
}
int main()
{
//  freopen("fft.txt","w",stdout);
//  srand(time(0));
//  int n=262144;
//  int bg,ed;
//  int i;
//  int times=100,j;
//  double s,s1;
//  inv[0]=inv[1]=1;
//  for(i=2;i<=n;i++)
//      inv[i]=-(p/i)*inv[p%i]%p;
//  s=0;
//  for(j=1;j<=times;j++)
//  {
//      get(a,n);
//      bg=clock();
//      ntt::init(n);
//      ntt::ntt(a,1);
//      ed=clock();
//      s+=double(ed-bg)/CLOCKS_PER_SEC;
//  }
//  printf("ntt :%.10lf\n",s/times);
//  s1=s;
//  s=0;
//  for(j=1;j<=times;j++)
//  {
//      get(a,n);
//      get(b,n);
//      bg=clock();
//      ntt::mul(a,b,c,n);
//      ed=clock();
//      s+=double(ed-bg)/CLOCKS_PER_SEC;
//  }
//  printf("mul :%.10lf %.10lf\n",s/times,s/s1);
//  s=0;
//  for(j=1;j<=times;j++)
//  {
//      get(a,n);
//      bg=clock();
//      ntt::inverse(a,b,n);
//      ed=clock();
//      s+=double(ed-bg)/CLOCKS_PER_SEC;
//  }
//  printf("inv :%.10lf %.10lf\n",s/times,s/s1);
//  s=0;
//  for(j=1;j<=times;j++)
//  {
//      get(a,n);
//      a[0]=1;
//      bg=clock();
//      ntt::sqrt(a,b,n);
//      ed=clock();
//      s+=double(ed-bg)/CLOCKS_PER_SEC;
//  }
//  printf("sqrt:%.10lf %.10lf\n",s/times,s/s1);
//  s=0;
//  for(j=1;j<=times;j++)
//  {
//      get(a,n);
//      a[0]=1;
//      bg=clock();
//      ntt::ln(a,b,n);
//      ed=clock();
//      s+=double(ed-bg)/CLOCKS_PER_SEC;
//  }
//  printf("ln  :%.10lf %.10lf\n",s/times,s/s1);
//  s=0;
//  for(j=1;j<=times;j++)
//  {
//      get(a,n);
//      bg=clock();
//      ntt::exp(a,b,n);
//      ed=clock();
//      s+=double(ed-bg)/CLOCKS_PER_SEC;
//  }
//  printf("exp :%.10lf %.10lf\n",s/times,s/s1);
//  return 0;
}

多點求值+快速插值

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
const ll p=998244353;
const ll g=3;
const int maxw=131072;
const int maxn=150000;
ll fp(ll a,ll b)
{
    ll s=1;
    for(;b;b>>=1,a=a*a%p)
        if(b&1)
            s=s*a%p;
    return s;
}
int rt,cnt,ls[1000010],rs[1000010];
ll vx[100010],vy[100010],va[100010];
ll inv[maxn],w1[maxn],w2[maxn];
int rev[maxn];
void init()
{
    inv[0]=inv[1]=1;
    for(int i=2;i<=maxw;i++)
        inv[i]=-p/i*inv[p%i]%p;
    for(int i=2;i<=maxw;i<<=1)
    {
        w1[i]=fp(g,(p-1)/i);
        w2[i]=fp(w1[i],p-2);
    }
}
ll *f[1000010];
int len[maxn];
void clear(ll *a,int n)
{
    memset(a,0,(sizeof a[0])*n);
}
void ntt(ll *a,int n,int t)
{
    for(int i=1;i<n;i++)
    {
        rev[i]=(rev[i>>1]>>1)|(i&1?n>>1:0);
        if(i>rev[i])
            swap(a[i],a[rev[i]]);
    }
    for(int i=2;i<=n;i<<=1)
    {
        ll wn=(t==1?w1[i]:w2[i]);
        for(int j=0;j<n;j+=i)
        {
            ll w=1;
            for(int k=j;k<j+i/2;k++)
            {
                ll u=a[k];
                ll v=a[k+i/2]*w%p;
                a[k]=(u+v)%p;
                a[k+i/2]=(u-v)%p;
                w=w*wn%p;
            }
        }
    }
    if(t==-1)
    {
        ll inv=fp(n,p-2);
        for(int i=0;i<n;i++)
            a[i]=a[i]*inv%p;
    }
}
void mul(ll *a,ll *b,ll *c,int n,int m)
{
    int k=1;
    while(k<=n+m)
        k<<=1;
    static ll a1[maxn],a2[maxn];
    clear(a1,k);
    clear(a2,k);
    for(int i=0;i<=n;i++)
        a1[i]=a[i];
    for(int i=0;i<=m;i++)
        a2[i]=b[i];
    ntt(a1,k,1);
    ntt(a2,k,1);
    for(int i=0;i<k;i++)
        a1[i]=a1[i]*a2[i]%p;
    ntt(a1,k,-1);
    for(int i=0;i<=n+m;i++)
        c[i]=a1[i];
}
void getinv(ll *a,ll *b,int n)
{
    if(n==1)
    {
        b[0]=fp(a[0],p-2);
        return;
    }
    getinv(a,b,n>>1);
    static ll a1[maxn],a2[maxn];
    clear(a1,n<<1);
    clear(a2,n<<1);
    for(int i=0;i<n;i++)
        a1[i]=a[i];
    for(int i=0;i<n>>1;i++)
        a2[i]=b[i];
    ntt(a1,n<<1,1);
    ntt(a2,n<<1,1);
    for(int i=0;i<n<<1;i++)
        a1[i]=a2[i]*(2-a2[i]*a1[i]%p)%p;
    ntt(a1,n<<1,-1);
    for(int i=0;i<n;i++)
        b[i]=a1[i];
}
void div(ll *a,ll *b,ll *c,int n,int m)
{
    static ll a1[maxn],a2[maxn],a3[maxn];
    int k=1;
    while(k<=2*(n-m))
        k<<=1;
    for(int i=0;i<=n;i++)
        a1[i]=a[i];
    for(int i=0;i<=m;i++)
        a2[i]=b[i];
    reverse(a1,a1+n+1);
    reverse(a2,a2+m+1);
    clear(a1+n-m+1,k-(n-m+1));
    clear(a2+n-m+1,k-(n-m+1));
    getinv(a2,a3,k);
    clear(a3+n-m+1,k-(n-m+1));
    ntt(a1,k,1);
    ntt(a3,k,1);
    for(int i=0;i<k;i++)
        a1[i]=a1[i]*a3[i]%p;
    ntt(a1,k,-1);
    for(int i=0;i<=n-m;i++)
        c[i]=a1[i];
    reverse(c,c+n-m+1);
}
void getmod(ll *a,ll *b,ll *c,int n,int m)
{
    static ll a1[maxn],a2[maxn];
    int k=1;
    while(k<=n)
        k<<=1;
    clear(a1,k);
    clear(a2,k);
    for(int i=0;i<=m;i++)
        a1[i]=b[i];
    div(a,b,a2,n,m);
    ntt(a1,k,1);
    ntt(a2,k,1);
    for(int i=0;i<k;i++)
        a1[i]=a1[i]*a2[i]%p;
    ntt(a1,k,-1);
    for(int i=0;i<m;i++)
        c[i]=(a[i]-a1[i])%p;
}
void divide(int l,int r,int &now)
{
    now=++cnt;
    len[now]=r-l+1;
    f[now]=new ll[len[now]+1];
    if(l==r)
    {
        f[now][1]=1;
        f[now][0]=-vx[l];
        return;
    }
    int mid=(l+r)>>1;
    divide(l,mid,ls[now]);
    divide(mid+1,r,rs[now]);
    mul(f[ls[now]],f[rs[now]],f[now],len[ls[now]],len[rs[now]]);
}
void getv(ll *a,int n,int l,int r,int now)
{
    ll *a1=new ll[len[now]];
    getmod(a,f[now],a1,n,len[now]);
    if(l==r)
    {
        va[l]=a1[0];
        return;
    }
    int mid=(l+r)>>1;
    getv(a1,len[now]-1,l,mid,ls[now]);
    getv(a1,len[now]-1,mid+1,r,rs[now]);
}
ll *s[1000010];
void getpoly(int l,int r,int now)
{
    s[now]=new ll[len[now]];
    if(l==r)
    {
        s[now][0]=va[l];
        return;
    }
    int mid=(l+r)>>1;
    getpoly(l,mid,ls[now]);
    getpoly(mid+1,r,rs[now]);
    int k=1;
    while(k<=len[now])
        k<<=1;
    static ll a1[maxn],a2[maxn],a3[maxn],a4[maxn];
    clear(a1,k);
    clear(a2,k);
    clear(a3,k);
    clear(a4,k);
    for(int i=0;i<len[ls[now]];i++)
        a1[i]=s[ls[now]][i];
    for(int i=0;i<=len[rs[now]];i++)
        a2[i]=f[rs[now]][i];
    for(int i=0;i<len[rs[now]];i++)
        a3[i]=s[rs[now]][i];
    for(int i=0;i<=len[ls[now]];i++)
        a4[i]=f[ls[now]][i];
    ntt(a1,k,1);
    ntt(a2,k,1);
    ntt(a3,k,1);
    ntt(a4,k,1);
    for(int i=0;i<k;i++)
        a1[i]=(a1[i]*a2[i]+a3[i]*a4[i])%p;
    ntt(a1,k,-1);
    for(int i=0;i<len[now];i++)
        s[now][i]=a1[i];
}
int n;
ll a[maxn],b[maxn],c[maxn];
int main()
{
    init();
    scanf("%d",&n);
    for(int i=0;i<=n;i++)
        scanf("%lld%lld",&vx[i],&vy[i]);
    divide(0,n,rt);
    for(int i=0;i<=n;i++)
        a[i]=f[rt][i+1]*(i+1)%p;
    getv(a,n,0,n,rt);
//  for(int i=0;i<=n;i++)
//      printf("%lld ",(va[i]+p)%p);
//  printf("\n");
    for(int i=0;i<=n;i++)
        va[i]=fp(va[i],p-2)*vy[i]%p;
    getpoly(0,n,rt);
    for(int i=0;i<=n;i++)
        printf("%lld ",(s[rt][i]+p)%p);
    printf("\n");
    return 0;
}
相關文章
相關標籤/搜索