FFT解決字符串匹配問題

FFT解決字符串匹配問題

樸素的字符串匹配

設有兩個字符串 \(a,b\) 長度分別爲 \(n,m(n\geq m)\) ,詢問字符串 \(b\)\(a\) 中出現了幾回。ios

這是一個最基礎的字符串匹配問題,顯然咱們容易想到 \(O(nm)\) 的暴力匹配方法,同時也可使用 \(KMP\) 算法在 \(O(n+m)\) 的複雜度下解決該問題。實際上,咱們還能夠利用 \(FFT\) 來解決這一經典問題,時間複雜度爲 \(O(n\log n)\)c++

構造兩個多項式 \(A(x)=a_0+a_1x^1+a_2x^2+\cdots + a_{n-1}x^{n-1}\)\(B(x)=b_0+b_1x^1+b_2x^2+\cdots + b_{m-1}x^{m-1}\) ,其中 \(a_i\) 表示字符串 \(a\) 中下標爲 \(i\) 的字符對應的數字(例如咱們能夠將 \(a\) 映射爲 \(1\)\(b\) 映射爲 \(2\) ……)。而後咱們定義一個匹配多項式 \(C(x)\) ,該函數的第 \(j\)\(c_jx^j\) 的係數 \(c_j\)\(c_j = \sum_{i=0}^{m-1}(a_{j+i}-b_{i})^2\) 。咱們發現當這個式子的值爲零時,字符串 \(a\) 的子串 \(a_ja_{j+1}\cdots a_{j+m-1}\) 與字符串 \(b\) 匹配成功。算法

而後咱們考慮如何計算這個匹配多項式 \(C(x)\) ,咱們須要將其係數轉換爲卷積形式,所以咱們翻轉字符串 \(b\) ,設字符串 \(r=reverse(b)\) ,即 \(r_i = b_{m-i-1}\) 。因而:函數

\[\begin{aligned} c_j &= \sum_{i=0}^{m-1}(a_{j+i}-b_{i})^2\\ &= \sum_{i=0}^{m-1}(a_{j+i}-r_{m-i-1})^2 \\ &= \sum_{i=0}^{m-1}a_{j+i}^2+r_{m-i-1}^2-2a_{j+i}r_{m-i-1} \\ &= \sum_{i=0}^{m-1}a_{j+i}^2 + \sum_{i=0}^{m-1}r_{m-i-1}^2 - 2\sum_{i=0}^{m-1}a_{j+i}r_{m-i-1} \end{aligned} \]

所以 \(c_j\) 被咱們轉化成了 \(3\) 個和式,咱們依次分析。首先 \(\sum_{i=0}^{m-1}a_{j+i}^2\) 這一項顯然咱們能夠 \(O(n)\) 預處理前綴和,而後 \(O(1)\) 查詢;而後 \(\sum_{i=0}^{m-1}r_{m-i-1}^2\) 這一項是一個定值;最後 \(\sum_{i=0}^{m-1}a_{j+i}r_{m-i-1}\) 是一個卷積式,咱們能夠直接構造多項式 \(P(x)=A(x)*B(x)\) ,那麼就可以獲得該多項式的係數 \(p_j = \sum_{i=0}^ja_ib_{j-i}\) ,進行簡單的下標偏移後便可求出 \(c_j\) ,而後經過 \(FFT\) 加速計算,就能在 \(O(n\log n)\) 的複雜度下計算這個卷積。ui

下面給出一個 \(NTT\) 實現:this

#include <bits/stdc++.h>
using namespace std;

constexpr int mod = 998244353;
std::vector<int> rev, roots{ 0, 1 };
int powmod(int a, long long b) {
	int res = 1;
	for (; b; b >>= 1, a = 1ll * a * a % mod)
		if (b & 1)
			res = 1ll * res * a % mod;
	return res;
}
void dft(std::vector<int>& a) {
	int n = a.size();
	if (int(rev.size()) != n) {
		int k = __builtin_ctz(n) - 1;
		rev.resize(n);
		for (int i = 0; i < n; ++i) rev[i] = rev[i >> 1] >> 1 | (i & 1) << k;
	}
	for (int i = 0; i < n; ++i)
		if (rev[i] < i)
			std::swap(a[i], a[rev[i]]);
	if (int(roots.size()) < n) {
		int k = __builtin_ctz(roots.size());
		roots.resize(n);
		while ((1 << k) < n) {
			int e = powmod(3, (mod - 1) >> (k + 1));
			for (int i = 1 << (k - 1); i < (1 << k); ++i) {
				roots[2 * i] = roots[i];
				roots[2 * i + 1] = 1ll * roots[i] * e % mod;
			}
			++k;
		}
	}
	for (int k = 1; k < n; k *= 2) {
		for (int i = 0; i < n; i += 2 * k) {
			for (int j = 0; j < k; ++j) {
				int u = a[i + j];
				int v = 1ll * a[i + j + k] * roots[k + j] % mod;
				int x = u + v;
				if (x >= mod)
					x -= mod;
				a[i + j] = x;
				x = u - v;
				if (x < 0)
					x += mod;
				a[i + j + k] = x;
			}
		}
	}
}
void idft(std::vector<int>& a) {
	int n = a.size();
	std::reverse(a.begin() + 1, a.end());
	dft(a);
	int inv = powmod(n, mod - 2);
	for (int i = 0; i < n; ++i) a[i] = 1ll * a[i] * inv % mod;
}
struct Poly {
	std::vector<int> a;
	Poly() {}
	Poly(int a0) {
		if (a0)
			a = { a0 };
	}
	Poly(const std::vector<int>& a1) : a(a1) {
		while (!a.empty() && !a.back()) a.pop_back();
	}
	int size() const { return a.size(); }
	int operator[](int idx) const {
		if (idx < 0 || idx >= size())
			return 0;
		return a[idx];
	}
	Poly mulxk(int k) const {
		auto b = a;
		b.insert(b.begin(), k, 0);
		return Poly(b);
	}
	Poly modxk(int k) const {
		k = std::min(k, size());
		return Poly(std::vector<int>(a.begin(), a.begin() + k));
	}
	Poly divxk(int k) const {
		if (size() <= k)
			return Poly();
		return Poly(std::vector<int>(a.begin() + k, a.end()));
	}
	friend Poly operator+(const Poly a, const Poly& b) {
		std::vector<int> res(std::max(a.size(), b.size()));
		for (int i = 0; i < int(res.size()); ++i) {
			res[i] = a[i] + b[i];
			if (res[i] >= mod)
				res[i] -= mod;
		}
		return Poly(res);
	}
	friend Poly operator-(const Poly a, const Poly& b) {
		std::vector<int> res(std::max(a.size(), b.size()));
		for (int i = 0; i < int(res.size()); ++i) {
			res[i] = a[i] - b[i];
			if (res[i] < 0)
				res[i] += mod;
		}
		return Poly(res);
	}
	friend Poly operator*(Poly a, Poly b) {
		int sz = 1, tot = a.size() + b.size() - 1;
		while (sz < tot) sz *= 2;
		a.a.resize(sz);
		b.a.resize(sz);
		dft(a.a);
		dft(b.a);
		for (int i = 0; i < sz; ++i) a.a[i] = 1ll * a[i] * b[i] % mod;
		idft(a.a);
		return Poly(a.a);
	}
	Poly& operator+=(Poly b) { return (*this) = (*this) + b; }
	Poly& operator-=(Poly b) { return (*this) = (*this) - b; }
	Poly& operator*=(Poly b) { return (*this) = (*this) * b; }
	Poly deriv() const {  // 求導
		if (a.empty())
			return Poly();
		std::vector<int> res(size() - 1);
		for (int i = 0; i < size() - 1; ++i) res[i] = 1ll * (i + 1) * a[i + 1] % mod;
		return Poly(res);
	}
	Poly integr() const {  // 積分
		if (a.empty())
			return Poly();
		std::vector<int> res(size() + 1);
		for (int i = 0; i < size(); ++i) res[i + 1] = 1ll * a[i] * powmod(i + 1, mod - 2) % mod;
		return Poly(res);
	}
	Poly inv(int m) const {  // 逆
		Poly x(powmod(a[0], mod - 2));
		int k = 1;
		while (k < m) {
			k *= 2;
			x = (x * (2 - modxk(k) * x)).modxk(k);
		}
		return x.modxk(m);
	}
	Poly log(int m) const { return (deriv() * inv(m)).integr().modxk(m); }
	Poly exp(int m) const {
		Poly x(1);
		int k = 1;
		while (k < m) {
			k *= 2;
			x = (x * (1 - x.log(k) + modxk(k))).modxk(k);
		}
		return x.modxk(m);
	}
	Poly mulT(Poly b) const {  // 卷積
		if (b.size() == 0)
			return Poly();
		int n = b.size();
		std::reverse(b.a.begin(), b.a.end());
		return ((*this) * b).divxk(n - 1);
	}
	std::vector<int> eval(std::vector<int> x) const {  // 求值
		if (size() == 0)
			return std::vector<int>(x.size(), 0);
		const int n = std::max(int(x.size()), size());
		std::vector<Poly> q(4 * n);
		std::vector<int> ans(x.size());
		x.resize(n);
		std::function<void(int, int, int)> build = [&](int p, int l, int r) {
			if (r - l == 1) {
				q[p] = std::vector<int>{ 1, (mod - x[l]) % mod };
			}
			else {
				int m = (l + r) / 2;
				build(2 * p, l, m);
				build(2 * p + 1, m, r);
				q[p] = q[2 * p] * q[2 * p + 1];
			}
		};
		build(1, 0, n);
		std::function<void(int, int, int, const Poly&)> work = [&](int p, int l, int r, const Poly& num) {
			if (r - l == 1) {
				if (l < int(ans.size()))
					ans[l] = num[0];
			}
			else {
				int m = (l + r) / 2;
				work(2 * p, l, m, num.mulT(q[2 * p + 1]).modxk(m - l));
				work(2 * p + 1, m, r, num.mulT(q[2 * p]).modxk(r - m));
			}
		};
		work(1, 0, n, mulT(q[1].inv(n)));
		return ans;
	}
};

int main() {
	ios::sync_with_stdio(false);
	cin.tie(nullptr);
	cout.tie(nullptr);
	string s1, s2;
	cin >> s1 >> s2;
	reverse(s2.begin(), s2.end());
	int n = s1.length();
	int m = s2.length();
	vector<int> va(n);
	vector<int> vb(m);
	vector<int> pre(n);
	for (int i = 0; i < n; i++)
		va[i] = s1[i] - 'a' + 1;
	for (int i = 0; i < m; i++)
		vb[i] = s2[i] - 'a' + 1;
	int b = 0;
	pre[0] = va[0];
	for (int i = 1; i < n; i++)
		pre[i] = pre[i - 1] + va[i] * va[i];
	for (int i = 0; i < m; i++)
		b += vb[i] * vb[i];
	Poly pa(va);
	Poly pb(vb);
	pa = pa * pb;

	vector<int> res;
	for (int i = 0; i <= n - m; i++) {
		int val = pre[i + m - 1] - (i == 0 ? 0 : pre[i - 1]) + b;
		val -= 2 * pa[i + m - 1];
		if (val == 0)
			res.push_back(i + 1);
	}
	cout << (int)res.size() << '\n'; // 數量
	for (int i = 0; i < (int)res.size(); i++)
		cout << res[i] << " \n"[i == (int)res.size() - 1]; // 匹配位置
	return 0;
}

帶通配符的字符串匹配

使用 \(FFT\) 求解字符串匹配問題雖然時間複雜度與 \(KMP\) 算法差距較大,可是具備更強的擴展性,例以下面這題:spa

設有兩個字符串 \(a,b\) 長度分別爲 \(n,m(n\geq m)\) ,詢問字符串 \(b\)\(a\) 中出現了幾回。而且這兩個字符串中均含有通配符,能夠與任意字符完成匹配。code

題目來源:洛谷 \(P4173\) 殘缺的字符串 https://www.luogu.com.cn/problem/P4173ci

咱們考慮到,通配符能夠匹配任意一個字符,所以若是咱們繼續採用上方給出的樸素匹配算法,那麼通配符這一位的計算結果必然始終爲零,所以構造一個新的匹配多項式 \(C(x)\) ,第 \(j\) 項的係數 \(c_j = \sum_{i=0}^{m-1}(a_{j+i}-b_{i})^2a_{j+i}b_i\) ;而且預處理多項式時,通配符的對應數值爲零。一樣地,咱們翻轉字符串 \(b\)\(r\) ,而後展開:字符串

\[\begin{aligned} c_j &= \sum_{i=0}^{m-1}(a_{j+i}-b_{i})^2a_{j+i}b_i \\ &= \sum_{i=0}^{m-1}(a_{j+i}-r_{m-i-1})^2a_{j+i}r_{m-i-1} \\ &= \sum_{i=0}^{m-1}a_{j+i}^3r_{m-i-1} - 2a_{j+i}^2r_{m-i-1}^2+a_{j+i}r_{m-i-1}^3 \end{aligned} \]

顯然,這是 \(3\) 個卷積式,進行 \(3\) 次卷積便可解決這一問題。

#include <bits/stdc++.h>
using namespace std;

constexpr int mod = 998244353;
std::vector<int> rev, roots{ 0, 1 };
int powmod(int a, long long b) {
	int res = 1;
	for (; b; b >>= 1, a = 1ll * a * a % mod)
		if (b & 1)
			res = 1ll * res * a % mod;
	return res;
}
void dft(std::vector<int>& a) {
	int n = a.size();
	if (int(rev.size()) != n) {
		int k = __builtin_ctz(n) - 1;
		rev.resize(n);
		for (int i = 0; i < n; ++i) rev[i] = rev[i >> 1] >> 1 | (i & 1) << k;
	}
	for (int i = 0; i < n; ++i)
		if (rev[i] < i)
			std::swap(a[i], a[rev[i]]);
	if (int(roots.size()) < n) {
		int k = __builtin_ctz(roots.size());
		roots.resize(n);
		while ((1 << k) < n) {
			int e = powmod(3, (mod - 1) >> (k + 1));
			for (int i = 1 << (k - 1); i < (1 << k); ++i) {
				roots[2 * i] = roots[i];
				roots[2 * i + 1] = 1ll * roots[i] * e % mod;
			}
			++k;
		}
	}
	for (int k = 1; k < n; k *= 2) {
		for (int i = 0; i < n; i += 2 * k) {
			for (int j = 0; j < k; ++j) {
				int u = a[i + j];
				int v = 1ll * a[i + j + k] * roots[k + j] % mod;
				int x = u + v;
				if (x >= mod)
					x -= mod;
				a[i + j] = x;
				x = u - v;
				if (x < 0)
					x += mod;
				a[i + j + k] = x;
			}
		}
	}
}
void idft(std::vector<int>& a) {
	int n = a.size();
	std::reverse(a.begin() + 1, a.end());
	dft(a);
	int inv = powmod(n, mod - 2);
	for (int i = 0; i < n; ++i) a[i] = 1ll * a[i] * inv % mod;
}
struct Poly {
	std::vector<int> a;
	Poly() {}
	Poly(int a0) {
		if (a0)
			a = { a0 };
	}
	Poly(const std::vector<int>& a1) : a(a1) {
		while (!a.empty() && !a.back()) a.pop_back();
	}
	int size() const { return a.size(); }
	int operator[](int idx) const {
		if (idx < 0 || idx >= size())
			return 0;
		return a[idx];
	}
	Poly mulxk(int k) const {
		auto b = a;
		b.insert(b.begin(), k, 0);
		return Poly(b);
	}
	Poly modxk(int k) const {
		k = std::min(k, size());
		return Poly(std::vector<int>(a.begin(), a.begin() + k));
	}
	Poly divxk(int k) const {
		if (size() <= k)
			return Poly();
		return Poly(std::vector<int>(a.begin() + k, a.end()));
	}
	friend Poly operator+(const Poly a, const Poly& b) {
		std::vector<int> res(std::max(a.size(), b.size()));
		for (int i = 0; i < int(res.size()); ++i) {
			res[i] = a[i] + b[i];
			if (res[i] >= mod)
				res[i] -= mod;
		}
		return Poly(res);
	}
	friend Poly operator-(const Poly a, const Poly& b) {
		std::vector<int> res(std::max(a.size(), b.size()));
		for (int i = 0; i < int(res.size()); ++i) {
			res[i] = a[i] - b[i];
			if (res[i] < 0)
				res[i] += mod;
		}
		return Poly(res);
	}
	friend Poly operator*(Poly a, Poly b) {
		int sz = 1, tot = a.size() + b.size() - 1;
		while (sz < tot) sz *= 2;
		a.a.resize(sz);
		b.a.resize(sz);
		dft(a.a);
		dft(b.a);
		for (int i = 0; i < sz; ++i) a.a[i] = 1ll * a[i] * b[i] % mod;
		idft(a.a);
		return Poly(a.a);
	}
	Poly& operator+=(Poly b) { return (*this) = (*this) + b; }
	Poly& operator-=(Poly b) { return (*this) = (*this) - b; }
	Poly& operator*=(Poly b) { return (*this) = (*this) * b; }
	Poly deriv() const {  // 求導
		if (a.empty())
			return Poly();
		std::vector<int> res(size() - 1);
		for (int i = 0; i < size() - 1; ++i) res[i] = 1ll * (i + 1) * a[i + 1] % mod;
		return Poly(res);
	}
	Poly integr() const {  // 積分
		if (a.empty())
			return Poly();
		std::vector<int> res(size() + 1);
		for (int i = 0; i < size(); ++i) res[i + 1] = 1ll * a[i] * powmod(i + 1, mod - 2) % mod;
		return Poly(res);
	}
	Poly inv(int m) const {  // 逆
		Poly x(powmod(a[0], mod - 2));
		int k = 1;
		while (k < m) {
			k *= 2;
			x = (x * (2 - modxk(k) * x)).modxk(k);
		}
		return x.modxk(m);
	}
	Poly log(int m) const { return (deriv() * inv(m)).integr().modxk(m); }
	Poly exp(int m) const {
		Poly x(1);
		int k = 1;
		while (k < m) {
			k *= 2;
			x = (x * (1 - x.log(k) + modxk(k))).modxk(k);
		}
		return x.modxk(m);
	}
	Poly mulT(Poly b) const {  // 卷積
		if (b.size() == 0)
			return Poly();
		int n = b.size();
		std::reverse(b.a.begin(), b.a.end());
		return ((*this) * b).divxk(n - 1);
	}
	std::vector<int> eval(std::vector<int> x) const {  // 求值
		if (size() == 0)
			return std::vector<int>(x.size(), 0);
		const int n = std::max(int(x.size()), size());
		std::vector<Poly> q(4 * n);
		std::vector<int> ans(x.size());
		x.resize(n);
		std::function<void(int, int, int)> build = [&](int p, int l, int r) {
			if (r - l == 1) {
				q[p] = std::vector<int>{ 1, (mod - x[l]) % mod };
			}
			else {
				int m = (l + r) / 2;
				build(2 * p, l, m);
				build(2 * p + 1, m, r);
				q[p] = q[2 * p] * q[2 * p + 1];
			}
		};
		build(1, 0, n);
		std::function<void(int, int, int, const Poly&)> work = [&](int p, int l, int r, const Poly& num) {
			if (r - l == 1) {
				if (l < int(ans.size()))
					ans[l] = num[0];
			}
			else {
				int m = (l + r) / 2;
				work(2 * p, l, m, num.mulT(q[2 * p + 1]).modxk(m - l));
				work(2 * p + 1, m, r, num.mulT(q[2 * p]).modxk(r - m));
			}
		};
		work(1, 0, n, mulT(q[1].inv(n)));
		return ans;
	}
};

int main() {
	ios::sync_with_stdio(false);
	cin.tie(nullptr);
	cout.tie(nullptr);
	int n, m;
	string s1, s2;
	cin >> m >> n >> s2 >> s1;
	reverse(s2.begin(), s2.end());
	vector<int> va1(n);
	vector<int> va2(n);
	vector<int> va3(n);
	vector<int> vb1(m);
	vector<int> vb2(m);
	vector<int> vb3(m);
	for (int i = 0; i < n; i++) {
		if (s1[i] == '*') {
			va1[i] = va2[i] = va3[i] = 0;
		} else {
			va1[i] = s1[i] - 'a' + 1;
			va2[i] = (s1[i] - 'a' + 1) * va1[i];
			va3[i] = (s1[i] - 'a' + 1) * va2[i];
		}
	}
	for (int i = 0; i < m; i++) {
		if (s2[i] == '*') {
			vb1[i] = vb2[i] = vb3[i] = 0;
		} else {
			vb1[i] = s2[i] - 'a' + 1;
			vb2[i] = (s2[i] - 'a' + 1) * vb1[i];
			vb3[i] = (s2[i] - 'a' + 1) * vb2[i];
		}
	}
	vector<int> ans(n - m + 1, 0);
	vector<int> res;
	Poly pa, pb;
	pa = Poly(va1), pb = Poly(vb3);
	pa *= pb;
	for (int i = 0; i <= n - m; i++)
		ans[i] += pa[i + m - 1];
	pa = Poly(va3), pb = Poly(vb1);
	pa *= pb;
	for (int i = 0; i <= n - m; i++)
		ans[i] += pa[i + m - 1];
	pa = Poly(va2), pb = Poly(vb2);
	pa *= pb;
	for (int i = 0; i <= n - m; i++)
		ans[i] -= 2 * pa[i + m - 1];
	for (int i = 0; i <= n - m; i++) {
		if (ans[i] == 0) {
			res.push_back(i + 1);
		}
	}
	cout << (int)res.size() << '\n';
	for (int i = 0; i < (int)res.size(); i++)
		cout << res[i] << " \n"[i == (int)res.size() - 1];
	return 0;
}

字符集較小時的字符串匹配

設有兩個字符串 \(a,b\) 長度分別爲 \(n,m(n\geq m)\) ,詢問字符串 \(b\)\(a\) 中出現了幾回。而且,這兩個字符串均由字符集 \(S\) 構成。

當這個字符集 \(S\) 很小時,咱們利用 \(FFT\) 解決匹配問題又會多出一種新的方法。咱們換一個角度從新思考字符串匹配問題,構造匹配多項式 \(C(x)\) ,該多項式的第 \(j\) 項係數 \(c_j = \sum^m_{i=0}[a_{i+j}=b_i]\)

那麼,兩個字符串 \(a,b\) 在位置 \(j\) 成功匹配的充要條件就是 \(c_j=m\) 。可是這個式子的計算在字符較多時比較困難,由於難以控制布爾運算式的狀態。可是若是字符集較小時,咱們就能用 \(0,1\) 兩種狀態來實現這個思路。不妨假設這個字符集僅含有兩個字母 \(A,B\) ,咱們經過這個例子來介紹一種利用 \(01\) 多項式的方法。

思路很是簡單,咱們首先將兩個字符串中的字符 \(A\) 置爲 \(1\)\(B\) 置爲 \(0\) ,計算出匹配多項式 \(c_j = \sum^m_{i=0}[a_{i+j}=b_i]\) 每一項的值。因爲只有 \(0,1\) 兩種狀態,所以 \(c_j=\sum^m_{i=0}a_{i+j}b_i\) ,將 \(b\) 翻轉後便可卷積,此處略過推導。

而後將兩個字符串中的字符 \(B\) 置爲 \(1\) ,字符 \(A\) 置爲 \(0\) ,再進行卷積,並將結果加到匹配多項式的係數 \(c_j\) 上。最後遍歷一次多項式 \(C\) ,若 \(c_j=m\) 則成功匹配。所以咱們能夠經過枚舉字符集中的每個元素實現這一思路,時間複雜度 \(O(|S|n\log n)\)

下面一樣給出一個實現:

#include <bits/stdc++.h>
using namespace std;

constexpr int mod = 998244353;
std::vector<int> rev, roots{ 0, 1 };
int powmod(int a, long long b) {
	int res = 1;
	for (; b; b >>= 1, a = 1ll * a * a % mod)
		if (b & 1)
			res = 1ll * res * a % mod;
	return res;
}
void dft(std::vector<int>& a) {
	int n = a.size();
	if (int(rev.size()) != n) {
		int k = __builtin_ctz(n) - 1;
		rev.resize(n);
		for (int i = 0; i < n; ++i) rev[i] = rev[i >> 1] >> 1 | (i & 1) << k;
	}
	for (int i = 0; i < n; ++i)
		if (rev[i] < i)
			std::swap(a[i], a[rev[i]]);
	if (int(roots.size()) < n) {
		int k = __builtin_ctz(roots.size());
		roots.resize(n);
		while ((1 << k) < n) {
			int e = powmod(3, (mod - 1) >> (k + 1));
			for (int i = 1 << (k - 1); i < (1 << k); ++i) {
				roots[2 * i] = roots[i];
				roots[2 * i + 1] = 1ll * roots[i] * e % mod;
			}
			++k;
		}
	}
	for (int k = 1; k < n; k *= 2) {
		for (int i = 0; i < n; i += 2 * k) {
			for (int j = 0; j < k; ++j) {
				int u = a[i + j];
				int v = 1ll * a[i + j + k] * roots[k + j] % mod;
				int x = u + v;
				if (x >= mod)
					x -= mod;
				a[i + j] = x;
				x = u - v;
				if (x < 0)
					x += mod;
				a[i + j + k] = x;
			}
		}
	}
}
void idft(std::vector<int>& a) {
	int n = a.size();
	std::reverse(a.begin() + 1, a.end());
	dft(a);
	int inv = powmod(n, mod - 2);
	for (int i = 0; i < n; ++i) a[i] = 1ll * a[i] * inv % mod;
}
struct Poly {
	std::vector<int> a;
	Poly() {}
	Poly(int a0) {
		if (a0)
			a = { a0 };
	}
	Poly(const std::vector<int>& a1) : a(a1) {
		while (!a.empty() && !a.back()) a.pop_back();
	}
	int size() const { return a.size(); }
	int operator[](int idx) const {
		if (idx < 0 || idx >= size())
			return 0;
		return a[idx];
	}
	Poly mulxk(int k) const {
		auto b = a;
		b.insert(b.begin(), k, 0);
		return Poly(b);
	}
	Poly modxk(int k) const {
		k = std::min(k, size());
		return Poly(std::vector<int>(a.begin(), a.begin() + k));
	}
	Poly divxk(int k) const {
		if (size() <= k)
			return Poly();
		return Poly(std::vector<int>(a.begin() + k, a.end()));
	}
	friend Poly operator+(const Poly a, const Poly& b) {
		std::vector<int> res(std::max(a.size(), b.size()));
		for (int i = 0; i < int(res.size()); ++i) {
			res[i] = a[i] + b[i];
			if (res[i] >= mod)
				res[i] -= mod;
		}
		return Poly(res);
	}
	friend Poly operator-(const Poly a, const Poly& b) {
		std::vector<int> res(std::max(a.size(), b.size()));
		for (int i = 0; i < int(res.size()); ++i) {
			res[i] = a[i] - b[i];
			if (res[i] < 0)
				res[i] += mod;
		}
		return Poly(res);
	}
	friend Poly operator*(Poly a, Poly b) {
		int sz = 1, tot = a.size() + b.size() - 1;
		while (sz < tot) sz *= 2;
		a.a.resize(sz);
		b.a.resize(sz);
		dft(a.a);
		dft(b.a);
		for (int i = 0; i < sz; ++i) a.a[i] = 1ll * a[i] * b[i] % mod;
		idft(a.a);
		return Poly(a.a);
	}
	Poly& operator+=(Poly b) { return (*this) = (*this) + b; }
	Poly& operator-=(Poly b) { return (*this) = (*this) - b; }
	Poly& operator*=(Poly b) { return (*this) = (*this) * b; }
	Poly deriv() const {  // 求導
		if (a.empty())
			return Poly();
		std::vector<int> res(size() - 1);
		for (int i = 0; i < size() - 1; ++i) res[i] = 1ll * (i + 1) * a[i + 1] % mod;
		return Poly(res);
	}
	Poly integr() const {  // 積分
		if (a.empty())
			return Poly();
		std::vector<int> res(size() + 1);
		for (int i = 0; i < size(); ++i) res[i + 1] = 1ll * a[i] * powmod(i + 1, mod - 2) % mod;
		return Poly(res);
	}
	Poly inv(int m) const {  // 逆
		Poly x(powmod(a[0], mod - 2));
		int k = 1;
		while (k < m) {
			k *= 2;
			x = (x * (2 - modxk(k) * x)).modxk(k);
		}
		return x.modxk(m);
	}
	Poly log(int m) const { return (deriv() * inv(m)).integr().modxk(m); }
	Poly exp(int m) const {
		Poly x(1);
		int k = 1;
		while (k < m) {
			k *= 2;
			x = (x * (1 - x.log(k) + modxk(k))).modxk(k);
		}
		return x.modxk(m);
	}
	Poly mulT(Poly b) const {  // 卷積
		if (b.size() == 0)
			return Poly();
		int n = b.size();
		std::reverse(b.a.begin(), b.a.end());
		return ((*this) * b).divxk(n - 1);
	}
	std::vector<int> eval(std::vector<int> x) const {  // 求值
		if (size() == 0)
			return std::vector<int>(x.size(), 0);
		const int n = std::max(int(x.size()), size());
		std::vector<Poly> q(4 * n);
		std::vector<int> ans(x.size());
		x.resize(n);
		std::function<void(int, int, int)> build = [&](int p, int l, int r) {
			if (r - l == 1) {
				q[p] = std::vector<int>{ 1, (mod - x[l]) % mod };
			}
			else {
				int m = (l + r) / 2;
				build(2 * p, l, m);
				build(2 * p + 1, m, r);
				q[p] = q[2 * p] * q[2 * p + 1];
			}
		};
		build(1, 0, n);
		std::function<void(int, int, int, const Poly&)> work = [&](int p, int l, int r, const Poly& num) {
			if (r - l == 1) {
				if (l < int(ans.size()))
					ans[l] = num[0];
			}
			else {
				int m = (l + r) / 2;
				work(2 * p, l, m, num.mulT(q[2 * p + 1]).modxk(m - l));
				work(2 * p + 1, m, r, num.mulT(q[2 * p]).modxk(r - m));
			}
		};
		work(1, 0, n, mulT(q[1].inv(n)));
		return ans;
	}
};

int main() {
	ios::sync_with_stdio(false);
	cin.tie(nullptr);
	cout.tie(nullptr);
	string s1, s2;
	cin >> s1 >> s2;
	reverse(s2.begin(), s2.end());
	int n = s1.length();
	int m = s2.length();
	vector<char> st;
	vector<int> cnt(n - m + 1, 0);
	for (auto c : s1)
		st.push_back(c);
	for (auto c : s2)
		st.push_back(c);
	sort(st.begin(), st.end());
	st.erase(unique(st.begin(), st.end()), st.end());

	auto sol = [&](string s1, string s2, char ch) {
		vector<int> va(n, 0);
		vector<int> vb(m, 0);
		for (int i = 0; i < n; i++) {
			if (s1[i] == ch) {
				va[i] = 1;
			}
		}
		for (int i = 0; i < m; i++) {
			if (s2[i] == ch) {
				vb[i] = 1;
			}
		}
		Poly pa(va);
		Poly pb(vb);
		pa *= pb;
		for (int i = m - 1; i <= n - 1; i++)
			cnt[i - m + 1] += pa[i];
	};

	for (auto c : st)
		sol(s1, s2, c);

	vector<int> res;
	for (int i = 0; i <= n - m; i++) {
		if (cnt[i] == m) {
			res.push_back(i + 1);
		}
	}
	cout << (int)res.size() << '\n';
	for (int i = 0; i < res.size(); i++)
		cout << res[i] << " \n"[i == (int)res.size() - 1];
	return 0;
}

題意:給定兩個只由 \(A,C,G,T\) 構成的字符串 \(S,T\) 和一個門限值 \(k\) ,詢問字符串 \(T\)\(S\) 中出現的次數(即匹配次數)。咱們定義兩個字符串在位置 \(j\) 是匹配的當且僅當,對於 \(T\) 中的任意一個字符 \(T_i\) 都至少有一個字符 \(S_k\) 知足 \(S_k=T_i\),其中 \(j+i-l\leq k\leq j+i+k\)

分析:這題因爲字符集較小(大小爲 \(4\) ),所以直接考慮使用上方給出的方法,枚舉字符集中的元素求解。可是本題引入了一個門限值的概念,基礎的字符串匹配問題實際上詢問的是本題中 \(k=0\) 時的特例,那麼對於這個拓展問題咱們如何解決?

本題中的門限值實際上就是將一個字符的可匹配範圍向兩端均延申了 \(k\) 個單位,因而咱們天然能夠想到:只須要將主串中的待匹配字符均向兩端延申 \(k\) 個單位,而後再用模式串進行匹配便可。以本題樣例爲例,假設咱們正在枚舉字符 \(A\)

本題的兩個字符串 AGCAATTCATACAT 先預處理成 \(01\) 多項式:10011000101010

而後將主串中的待匹配字符向外延申 \(k(k=1)\) 個單位:1111110111 ,而後與 1010 進行卷積便可。

AC代碼https://codeforces.com/contest/528/submission/94462497

與此題相似的還有 Gym101667H Rock Paper Scissors。

CF827E Rusty String

題意:給定一個由 \(3\) 種字符 \(V,K,?\) 構成的字符串 \(s\),其中 \(?\) 能夠表示 \(V,K\) 中的任意一種,詢問該字符串全部可能的循環節長度。

分析:咱們考慮求出全部不符合要求的循環節長度,若是一個長度爲 \(d\) 的循環節不知足題意,那麼必然存在 \(s_i\neq s_{i+d}\) ,所以構造一個多項式 \(C(x)\) ,其係數 \(c_j = \sum_{i=0}^{n-1}[s_i\neq s_{i+j}]\) ,若是 \(c_j>0\) 那就說明長度爲 \(j\) 的循環節不知足題意。

因爲字符集較小,咱們仍然考慮使用 \(01\) 多項式的方法解決本題,須要注意的一點是本題中的 \(?\) 能夠近似地看做通配符,所以直接置爲零。

\[\begin{aligned} c_j&=\sum_{i=0}^{n-1}[s_i\neq s_{i+j}] \\ &= \sum_{i=0}^{n-1}s_is'_{n-1-i-j}\\ \end{aligned} \]

上式中,\(s_i\) 表示字符串 \(s\) 中的 \(V\) 所有置 \(1\) 的多項式係數,\(s'_i\) 表示字符串 \(s\) 翻轉後將 \(K\) 所有置 \(1\) 的多項式係數。

而後你就會寫出一個樣例都過不去的代碼,這是由於本題中的 \(?\) 並不等價於通配符,例如樣例中長度爲 \(2\) 的循環節會出現如下狀況:

V ? ? V K

N N V ? ? V K (N表示空位)

若是 \(?\) 爲通配符,那麼本題確實能夠成功匹配,可是本題中 \(?\) 並不能同時爲 \(V\)\(K\) 。本例中出現了 \(s_0=s_2=s_4\) 的狀況,所以存在矛盾。

咱們考慮一下爲何會出現這個問題,這是由於咱們只判斷了 \(s_i\)\(s_{i+d}\) 的關係,可是這個關係並不具備傳遞性,\(s_i=s_{i+d}=s_{i+2d}=\cdots = s_{i+kd}\) 不能遞推獲得 \(d\) 是合法的循環節長度。咱們利用這個性質來完善咱們的方法:對於一個長度爲 \(d\) 的循環節,若是它是合法的,那麼全部長度爲 \(kd(k>1)\) 的循環節也必須合法。所以咱們能夠利用埃氏篩對每一個位置進行可行性檢測。

這裏的實現還須要注意一個 \(wa\) 點,長度爲 \(d\) 的循環節不只對應卷積後的 \(c_{n-1-d}\) ,同時也對應了 \(c_{n-1+d}\) 。舉個例子進行說明:

設字符串 \(s=K?V\) ,那麼通過處理後獲得:\(a=0,0,1;\ b=0,0,1\) ;卷積後的結果爲:\(c=0,0,0,0,1\) ,若是你在判斷 \(d=1\) 時漏判了 \(c_{n-1+d}\) 就會致使答案錯誤,由於 \(c_0\)\(c_4\) 實際上都是 \(d=2\) 的卷積係數。

AC代碼https://codeforces.com/contest/827/submission/94561373

相關文章
相關標籤/搜索