LG5577 算力訓練 k進制FWT

題意說人話就是給出一個長度爲$n$的數列$a_1,a_2,...,a_n$,求$\prod\limits_{i=1}^n (1+x^{a_i})$,其中卷積的下標加法定義爲$k$進制不進位加法。c++


$k$進制不進位加法不難想到$k$進制FWT,因此咱們須要快速求出$\prod\limits_{i=1}^n \mathrm{DWT}_k(1 + x^{a_i})$,這裏的乘法是點積,最後IDWT回來便可。git

由於要用形式冪級數作到乘單位根,因此咱們得到了一個$O(nk^mmk^2)$的優秀算法,然而並無什麼×用。算法


考慮優化。能夠發現任意一個$\mathrm{DWT}_k(1 + x^{a_i})$中全部數構成的集合都是集合$W = {w_k^i+1|i \in [0,k-1] \cap Z}$的子集,而$|W|=k$並不大,因此咱們能夠考慮設數組

$$x_{i,j} = \sum\limits_{p=1}^n [[x^i]\mathrm{DWT}_k(1+x^{a_p}) = w_k^j + 1]$$函數

那麼這$n$個冪級數的卷積的第$i$項的值就是優化

$$\prod\limits_{p=0}^{k-1} (w_k^p+1)^{x_{i,p}}$$spa

能夠快速計算。code


接下來考慮對於一個肯定的$i$如何求出$x_{i,j}$的值。毫無疑問須要解方程。get

首先咱們顯然有一個式子是it

$$x_{i,0} + x_{i,1} + x_{i,2}+...+x_{i,k-1}=n$$

設冪級數$A$知足$[x^j]A = \sum\limits_{i=1}^n [a_i = j]$,考慮$[x^i] \mathrm{DWT}_k(A)$。對於一個$p$和一個知足$[x^i]\mathrm{DWT}_k(1+x^{a_j}) = w_k^p + 1$的$j$,能夠發現其對$[x^i]\mathrm{DWT}_k(A)$的貢獻是$w_k^p$。因此咱們有

$$w_k^0x_{i,0}+w_k^1x_{i,1}+...+w_k^{k-1}x_{i,k-1} = [x^i]\mathrm{DWT}_k(A)$$


觀察兩個方程的係數向量:

$$1,1,1,1,...,1$$

$$w_k^0,w_k^1,w_k^2,...,w_k^{k-1}$$

有些範德蒙德矩陣的Feeling。那麼咱們能不能求出

$$w_k^0x_{i,0}+w_k^2x_{i,1}+...+w_k^{2(k-1)}x_{i,k-1}$$

也就是原來對某個位置貢獻爲$w_k^p$的數組,在變換以後它的貢獻變爲$w^{2p}_k$。

能夠發現這至關於將FWT過程當中的單位根平方一下。因此咱們只須要把全部$w_k$都變爲$w_k^2$就能夠了。具體來講,以$k=5$爲例,$k$進制FWT使用下面的位矩陣:

$$\left( \begin{array}{cccc} 1 & 1 & 1 & 1 & 1 \ 1 & w_5^1 & w_5^2 & w_5^3 & w_5^4 \ 1 & w_5^2 & w_5^4 & w_5^1 & w_5^3 \ 1 & w_5^3 & w_5^1 & w_5^4 & w_5^2 \ 1 & w_5^4 & w_5^3 &w_5^2 & w_5^1 \end{array} \right)$$

如今把單位根平方,位矩陣就變成下面這樣:

$$\left( \begin{array}{cccc} 1 & 1 & 1 & 1 & 1 \ 1 & w_5^2 & w_5^4 & w_5^1 & w_5^3 \ 1 & w_5^4 & w_5^3 & w_5^2 & w_5^1 \ 1 & w_5^1 & w_5^2 & w_5^3 & w_5^4 \ 1 & w_5^3 & w_5^1 &w_5^4 & w_5^2 \end{array} \right)$$

使用這一個位矩陣進行DWT,則$[x^i]\mathrm{DWT}k(A)=w_k^0x{i,0}+w_k^2x_{i,1}+...+w_k^{2(k-1)}x_{i,k-1}$。

值得注意的是咱們只是求值,因此這個矩陣就算不是合法的位矩陣也能夠這麼作(畢竟你不須要進行逆操做)。

照葫蘆畫瓢地能夠獲得$k$個方程,其係數矩陣是DWT使用的範德蒙德矩陣。因此只要對於獲得的全部結果IDWT一下就能夠獲得全部$x_{i,j}$的值。


樸素實現複雜度大概是$O(k^{m+4}m)$的($k$次FWT,每一次$k^mmk$次乘法,乘法複雜度$k^2$)比較慢。下面是一些實現細節:

  1. 能夠發現將單位根乘方以後對冪級數進行FWT獲得的每一位的值構成的集合必定是對其進行$k$進制FWT獲得的值的集合的子集。能夠實現一個函數計算對單位根進行乘方後FWT獲得的每一位的值分別對應進行$k$進制FWT後哪一位的值。這樣能夠把一個$k$摘掉,複雜度變爲$O(k^{m+3}m)$。這一部分實現能夠參考代碼中的getid函數。

  2. 單位根在模意義下不存在因此要擴域,即將一個數表示爲$a_0w^0+a_1w^1+...+a_{k-1}w^{k-1}$,能夠發現它是封閉的。若是直接這樣作有一個很是大的好處是全部單位根都只有一個位置有值,能夠作到$O(k)$乘單位根。有一個bug是最後的答案並非$a_0$因此並無這樣實現。

  3. 延續點2中的問題,咱們最後的答案不是$a_0$的緣由是有一些單位根它們的和爲$0$,好比說$\sum\limits_{i=0}^{k-1} w_k^{k-1}=0$,或者在$2 \mid k$時$w_k^i = w_k^{i+\frac{k}{2}}$,這意味着一個整數在這個域上的表示不是惟一的。這就是爲何我寫成了二合一:

  • 對於$k=5$的狀況將$w^4 = -(w^0+w^1+w^2+w^3)$代入,將一個數表示爲$a_0w^0+a_1w^1+a_2w^2+a_3w^3$的形式進行求解。
  • 對於$k=6$的狀況先使用$w_1=-w_4,w_3=-w_0,w_5=-w_2$,這樣就只剩下$w_0,w_2,w_4$,而後代入$w_4=-w_0-w_2$,這樣咱們能夠只用將數表示爲$a_0w_0+a_1w_2$的形式就能夠求解了。

以這樣的形式求解最後獲得的$a_0$就是答案。

然而我很想知道爲何這麼消了以後一個整數就必定能被惟一表示……


code:

#include<bits/stdc++.h>
using namespace std;

int read(){
	int a = 0; char c = getchar(); while(!isdigit(c)) c = getchar();
	while(isdigit(c)){a = a * 10 + c - 48; c = getchar();} return a;
}

const int MOD = 998244353;
int upd(int x){return x + (x >> 31 & MOD);}
int add(int x , int y){return upd(x + y - MOD);}
int sub(int x , int y){return upd(x - y);}
int mul(int x , int y){return x <= 1 || y <= 1 ? x * y : 1ll * x * y % MOD;}

int N , K , M , arr[1000003] , pwK[10]; long long IVK;
int id[100003];
void getid(int L , int pw){
	if(L == 1) return (void)(id[0] = 0);
	int p = L / K; getid(p , pw);
	for(int i = 0 ; i < p ; ++i){
		int t = id[i];
		for(int j = 0 ; j < K ; ++j , t = (t + p * pw) % L)
			id[i + j * p] = t;
	}
}

template < typename op >
void FWT(op *now , op *tmp , op *w , int tp){
	for(int i = 0 ; i < M ; ++i){
		for(int j = 0 ; j < pwK[M] ; j += pwK[i + 1])
			for(int k = 0 ; k < pwK[i] ; ++k){
				for(int l = 0 ; l < K ; ++l) tmp[j + k + pwK[i] * l] = op();
				for(int l = 0 ; l < K ; ++l)
					for(int p = 0 ; p < K ; ++p)
						tmp[j + k + pwK[i] * l] = tmp[j + k + pwK[i] * l] +
							(!l || !p ? now[j + k + pwK[i] * p] : now[j + k + pwK[i] * p] * w[(tp == 1 ? l : K - l) * p % K]);
			}
		for(int j = 0 ; j < pwK[M] ; ++j) now[j] = tmp[j];
	}
}

namespace solve1{
	struct op{
		int arr[4];
		op(int _a = 0 , int _b = 0 , int _c = 0 , int _d = 0){arr[0] = _a; arr[1] = _b; arr[2] = _c; arr[3] = _d;}
		int& operator [](int x){return arr[x];}
		friend op operator +(op x , op y){op t; for(int i = 0 ; i < 4 ; ++i) t[i] = add(x[i] , y[i]); return t;}
		friend op operator -(op x , op y){op t; for(int i = 0 ; i < 4 ; ++i) t[i] = sub(x[i] , y[i]); return t;}
		friend op operator *(op x , op y){
			op t; int sum4 = 0;
			for(int j = 0 ; j < 4 ; ++j)
				if(y[j])
					for(int i = 0 ; i < 4 ; ++i){
						int id = i + j >= 5 ? i + j - 5 : i + j;
						(id == 4 ? sum4 : t[id]) = add(id == 4 ? sum4 : t[id] , mul(x[i] , y[j]));
					}
			if(sum4) for(int i = 0 ; i < 4 ; ++i) t[i] = sub(t[i] , sum4);
			return t;
		}
	}now[100003] , tmp[100003] , val[5][100003] , w[5]{op(1),op(0,1),op(0,0,1),op(0,0,0,1),op(MOD-1,MOD-1,MOD-1,MOD-1)} , pww1[5][1003] , pww2[5][1003];

	void init_pww(){
		for(int i = 0 ; i < K ; ++i){
			pww1[i][0] = pww2[i][0] = op(1); pww1[i][1] = w[i] + op(1);
			for(int j = 2 ; j <= 1000 ; ++j) pww1[i][j] = pww1[i][j - 1] * pww1[i][1];
			pww2[i][1] = pww1[i][1000];
			for(int j = 2 ; j <= 1000 ; ++j) pww2[i][j] = pww2[i][j - 1] * pww2[i][1];
		}
	}

	op getpw(int id , int val){return pww2[id][val / 1000] * pww1[id][val % 1000];}
	
	void work(){
		init_pww();
		for(int i = 1 ; i <= N ; ++i){
			int tmp = arr[i] , t = 0 , tms = 1; while(tmp){t += tmp % 10 * tms; tms *= K; tmp /= 10;} ++now[t][0];
		}
		FWT(now , tmp , w , 1); for(int i = 0 ; i < pwK[M] ; ++i) val[0][i][0] = N;
		for(int i = 1 ; i < K ; ++i){getid(pwK[M] , i); for(int j = 0 ; j < pwK[M] ; ++j) val[i][j] = now[id[j]];}
		for(int i = 0 ; i < pwK[M] ; ++i){
			for(int j = 0 ; j < K ; ++j) tmp[j] = op(0 , 0);
			for(int j = 0 ; j < K ; ++j) for(int l = 0 ; l < K ; ++l) tmp[j] = tmp[j] + val[l][i] * w[l * (K - j) % K];
			now[i] = op(1); for(int j = 0 ; j < K ; ++j) now[i] = now[i] * getpw(j , 1ll * tmp[j][0] * IVK % MOD);
		}
		FWT(now , tmp , w , -1); int iv = 1; for(int i = 0 ; i < M ; ++i) iv = 1ll * iv * IVK % MOD;
		for(int i = 0 ; i < pwK[M] ; ++i) printf("%lld\n" , 1ll * now[i][0] * iv % MOD);
	}
}

namespace solve2{
	struct op{
		int x , y; op(int _x = 0 , int _y = 0) : x(_x) , y(_y){}
		friend op operator +(op x , op y){return op(add(x.x , y.x) , add(x.y , y.y));}
		friend op operator -(op x , op y){return op(sub(x.x , y.x) , sub(x.y , y.y));}
		friend op operator *(op x , op y){int t = mul(x.y , y.y); return op(sub(mul(x.x , y.x) , t) , sub(add(mul(x.y , y.x) , mul(x.x , y.y)) , t));}
	}now[100003] , tmp[100003] , val[6][100003] , w[6]{op(1),op(1,1),op(0,1),op(MOD-1),op(MOD-1,MOD-1),op(0,MOD-1)} , pww1[6][1003] , pww2[6][1003];

	void init_pww(){
		for(int i = 0 ; i < K ; ++i){
			pww1[i][0] = pww2[i][0] = op(1); pww1[i][1] = w[i] + op(1);
			for(int j = 2 ; j <= 1000 ; ++j) pww1[i][j] = pww1[i][j - 1] * pww1[i][1];
			pww2[i][1] = pww1[i][1000];
			for(int j = 2 ; j <= 1000 ; ++j) pww2[i][j] = pww2[i][j - 1] * pww2[i][1];
		}
	}

	op getpw(int id , int val){return pww2[id][val / 1000] * pww1[id][val % 1000];}
	
	void work(){
		init_pww();
		for(int i = 1 ; i <= N ; ++i){
			int tmp = arr[i] , t = 0 , tms = 1;
			while(tmp){t += tmp % 10 * tms; tms *= K; tmp /= 10;}
			++now[t].x;
		}
		FWT(now , tmp , w , 1); for(int i = 0 ; i < pwK[M] ; ++i) val[0][i].x = N;
		for(int i = 1 ; i < K ; ++i){getid(pwK[M] , i); for(int j = 0 ; j < pwK[M] ; ++j) val[i][j] = now[id[j]];}
		for(int i = 0 ; i < pwK[M] ; ++i){
			for(int j = 0 ; j < K ; ++j) tmp[j] = op(0 , 0);
			for(int j = 0 ; j < K ; ++j) for(int l = 0 ; l < K ; ++l) tmp[j] = tmp[j] + val[l][i] * w[l * (K - j) % K];
			now[i] = op(1 , 0); for(int j = 0 ; j < K ; ++j) now[i] = now[i] * getpw(j , 1ll * tmp[j].x * IVK % MOD);
		}
		FWT(now , tmp , w , -1); int iv = 1; for(int i = 0 ; i < M ; ++i) iv = 1ll * iv * IVK % MOD;
		for(int i = 0 ; i < pwK[M] ; ++i) printf("%lld\n" , 1ll * now[i].x * iv % MOD);
	}
}

int main(){
	N = read(); K = read(); M = read(); for(int i = 1 ; i <= N ; ++i) arr[i] = read();
	IVK = MOD + 1; while(IVK % K) IVK += MOD;
	IVK /= K; pwK[0] = 1; for(int i = 1 ; i <= M ; ++i) pwK[i] = pwK[i - 1] * K;
	if(K == 6) solve2::work(); else solve1::work();
	return 0;
}
相關文章
相關標籤/搜索