LibreOJ #2325. 「清華集訓 2017」小Y和恐怖的奴隸主(矩陣快速冪優化DP)

  哇這題劇毒,卡了很久常數才過T_Tios

  設$f(i,s)$爲到第$i$輪攻擊,怪物狀態爲$s$時對boss的指望傷害,$sum$爲狀態$s$所表示的怪物個數,獲得樸素的DP方程$f(i,s)=\sum \frac{1}{sum+1}*(f(i+1,s')+[s==s'])$git

  狀態數只有$C_{8+3}^3=165$個,因此就能夠矩乘優化了。再加上一個用於轉移的$1$,矩陣大小是$166*166$的,由於多組詢問,因此能夠先把$2$的全部次冪的矩陣都預處理出來。ide

  而後會發現複雜度是$O(T*166^3*N)$的,沒法承受...優化

  其實答案矩陣只有一列...用它從左往右乘就能把矩陣乘法優化到$O(166^2)$了,總時間複雜度$O(166^3*logn+T*166^2*logn)$spa

  $16$億過$2$秒,長見識了...code

#include<iostream>
#include<cstring>
#include<cstdlib>
#include<cstdio>
#include<cmath>
#include<algorithm>
#define ll long long
using namespace std;
const int maxn=200, mod=998244353;
const ll inf=8226880250554875800;
struct mtx{int mp[maxn][maxn], n, m;mtx(){memset(mp, 0, sizeof(mp)); n=m=0;}}
base[60];
mtx operator * (mtx a, mtx b)
{
    mtx c; c.n=a.n; c.m=b.m;
    for(int i=0;i<=a.n;i++)
    for(int j=0;j<=b.m;j++)
    {
        ll s=0;
        for(int k=0;k<=a.m;k++)
        s+=1ll*a.mp[i][k]*b.mp[k][j], s>inf && (s%=mod);
        c.mp[i][j]=s%mod;
    }
    return c;
}
int T, m, K, tott;
ll n;
int st[maxn], mi[maxn], pos[1<<10];
inline int power(int a, int b)
{
    int ans=1;
    for(;b;b>>=1, a=1ll*a*a%mod)
    if(b&1) ans=1ll*ans*a%mod;
    return ans;
}
int main()
{
    scanf("%d%d%d", &T, &m, &K);
    mi[0]=1; for(int i=1;i<=m;i++) mi[i]=mi[i-1]*(K+1);
    for(int i=0;i<mi[m];i++) 
    {
        int sum=0;
        for(int j=0;j<m;j++) sum+=i/mi[j]%(K+1);
        if(sum<=K) st[tott]=i, pos[i]=tott++;
    }
    base[0].mp[tott][tott]=1;
    base[0].n=base[0].m=tott;
    for(int i=0;i<tott;i++)
    {
        int sum=0;
        for(int j=0;j<m;j++) sum+=st[i]/mi[j]%(K+1);
        int inv=power(sum+1, mod-2);
        base[0].mp[i][tott]=base[0].mp[i][i]=inv;
        for(int j=0;j<m;j++)
        if(st[i]/mi[j]%(K+1))
        {
            int x=st[i]-mi[j];
            if(j) x+=mi[j-1];
            if(j && sum<K) x+=mi[m-1];
            base[0].mp[i][pos[x]]=1ll*inv*(st[i]/mi[j]%(K+1))%mod;
        }
    }
    for(int i=1;i<60;i++) base[i]=base[i-1]*base[i-1];
    while(T--)
    {
        scanf("%lld", &n); mtx ans; ans.n=tott; ans.mp[tott][0]=1;
        int digit=0; for(;n;n>>=1, digit++) if(n&1) ans=base[digit]*ans;
        printf("%d\n", ans.mp[pos[mi[m-1]]][0]);
    }
}
View Code
相關文章
相關標籤/搜索