[SDOI2015]序列統計

嘟嘟嘟


此題很可作。


首先從一個暴力的dp入手:令\(dp[i][j]\)表示第\(i\)個數爲\(j\)時的數列個數,因而有\(dp[i][j *a[k] \% M] += dp[i - 1][j]\)
但這個彷佛只能拿10分。


一個顯然的優化是改爲倍增快速冪的形式,上述dp方程顯然是能夠合併的,即\(dp[x + y][i * j \% M] += dp[x][i] * dp[y][j]\)
乘的時候暴力乘,這樣\(O(k ^ 2logn)\)能拿到60分。


如今瓶頸在於多項式的乘法。這東西和卷積很像,只不過卷積是加,他倒是乘。那麼怎麼才能把乘變成加呢?取log啊!
而後我就卡在了這裏:取完log下標都不是整數,那不gg了……


最後問了衡水的巨佬,他說你對原根取對數啊!問了一大頓纔想起來,原根的定義是一個數\(g\),知足\(g ^ 0, g ^ 1, g ^ 2 \ldots g ^ {p - 2}\)恰好能湊出\([1, p - 1]\)的全部整數。因而這題就完事了啊!
看來仍是本身原根學的很差,一下子趕快複習一下。


數據範圍挺可愛的,規定了\(x\)不能夠取\(0\),要否則還得分來討論把\(0\)單出來算。
然而集合\(S\)中卻有\(0\)……這得特判一下……ios

#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cmath>
#include<cstring>
#include<cstdlib>
#include<cctype>
#include<queue>
#include<vector>
#include<ctime>
#include<assert.h>
using namespace std;
#define enter puts("")
#define space putchar(' ')
#define Mem(a, x) memset(a, x, sizeof(a))
#define In inline
#define forE(i, x, y) for(int i = head[x], y; (y = e[i].to) && ~i; i = e[i].nxt)
typedef long long ll;
typedef double db;
const int INF = 0x3f3f3f3f;
const db eps = 1e-8;
const int maxn = 1e3 + 5;
const int maxM = 4e4 + 5;
const ll mod = 1004535809;
const ll G = 3;
In ll read()
{
    ll ans = 0;
    char ch = getchar(), las = ' ';
    while(!isdigit(ch)) las = ch, ch = getchar();
    while(isdigit(ch)) ans = (ans << 1) + (ans << 3) + ch - '0', ch = getchar();
    if(las == '-') ans = -ans;
    return ans;
}
In void write(ll x)
{
    if(x < 0) putchar('-'), x = -x;
    if(x >= 10) write(x / 10);
    putchar(x % 10 + '0');
}
In void MYFILE()
{
#ifndef mrclr
    freopen("ha.in", "r", stdin);
    freopen("bf.out", "w", stdout);
#endif
}

int n, M, X, S, a[maxM], pos[maxM];
int len = 1, lim = 0, rev[maxM];

In ll inc(ll a, ll b) {return a + b < mod ? a + b : a + b - mod;}
In ll quickpow(ll a, ll b, ll mod)
{
    ll ret = 1;
    for(; b; b >>= 1, a = a * a % mod)
        if(b & 1) ret = ret * a % mod;
    return ret;
}

In ll phi(ll n)
{
    ll ret = n;
    for(int i = 2; i * i <= n; ++i)
    {
        if(n % i) continue;
        ret = ret / i * (i - 1);
        while(n % i == 0) n /= i;
    }
    if(n > 1) ret = ret / n * (n - 1);
    return ret;
}
int p[1000], pcnt = 0;
In ll getRoot(ll m)
{
    ll Phi = phi(m); pcnt = 0;
    for(int i = 2; i * i <= Phi; ++i) if(Phi % i == 0)
    {
        p[++pcnt] = i;
        if(Phi / i != i) p[++pcnt] = Phi / i;
    }
    for(int g = 2; g <= Phi; ++g)
    {
        bool flg = 1;
        if(quickpow(g, Phi, m) ^ 1) continue;
        for(int i = 1; i <= pcnt && flg; ++i)
            if(quickpow(g, p[i], m) == 1) flg = 0;
        if(flg) return g;
    }
    return -1;
}
In void init()
{
    int g = getRoot(M), tp = 1;
    for(int i = 0; i < M - 1; ++i, tp = tp * g % M) pos[tp] = i;
    while(len < M + M) len <<= 1, ++lim;
    for(int i = 0; i < len; ++i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (lim - 1));
}

In void ntt(ll* a, int len, bool flg)
{
    for(int i = 0; i < len; ++i) if(i < rev[i]) swap(a[i], a[rev[i]]);
    for(int i = 1; i < len; i <<= 1)
    {
        ll gn = quickpow(G, (mod - 1) / (i << 1), mod);
        for(int j = 0; j < len; j += (i << 1))
        {
            ll g = 1;
            for(int k = 0; k < i; ++k, g = g * gn % mod)
            {
                ll tp1 = a[j + k], tp2 = a[j + k + i] * g % mod;
                a[j + k] = (tp1 + tp2) % mod, a[j + k + i] = (tp1 - tp2 + mod) % mod;
            }
        }
    }
    if(flg) return;
    reverse(a + 1, a + len); ll inv = quickpow(len, mod - 2, mod);
    for(int i = 0; i < len; ++i) a[i] = a[i] * inv % mod;
}

ll c[maxM], A[maxM], B[maxM];
In void mul(ll* a, ll* b)
{
    for(int i = 0; i < len; ++i)    //必定要複製到另外一個數組再NTT!由於傳過來的數組a和b多是同一個!(debug到頭禿)
    {
        A[i] = i < M - 1 ? a[i] : 0;
        B[i] = i < M - 1 ? b[i] : 0;
    }
    ntt(A, len, 1), ntt(B, len, 1);
    for(int i = 0; i < len; ++i) a[i] = A[i] * B[i] % mod;
    ntt(a, len, 0);
    for(int i = 0; i < M - 1; ++i) a[i] = inc(a[i], a[i + M - 1]);
}

ll f[maxM], g[maxM];
In ll Quickpow(int n)
{
    f[pos[1]] = 1;
    for(int i = 1; i <= S; ++i) if(a[i]) g[pos[a[i]]] = 1;
    for(; n; n >>= 1, mul(g, g)) 
        if(n & 1) mul(f, g);
    return f[pos[X]];
}

int main()
{
//  MYFILE();
    n = read(), M = read(), X = read(), S = read();
    for(int i = 1; i <= S; ++i) a[i] = read();
    init();
    write(Quickpow(n)), enter;
    return 0;
}
相關文章
相關標籤/搜索