【ZJOI 2019】麻將(dp of dp)

這是我第一次寫$dp \; of \; dp$,大體思路參考了xyx的作法,可能有些地方不太同樣,但也許會詳細一點。c++

考慮給你一副牌,如何判斷這副牌是不是胡的。ide

容易發現相同的順子不會選三個以上,因而考慮花色從小到大進行$dp$。令$dp_{0/1, i, j, k}$表示是否有對子,考慮了前$i$種花色的牌,以花色$i - 1$的牌爲開頭的順子準備選$j$個,以花色$i$的牌爲開頭的順子準備選$k$個,此時能選的最大的面子數。轉移只要枚舉以花色$i + 1$的牌開頭的順子準備選幾個,剩下的牌組成刻子就好了(具體細節能夠看代碼中結構體$State$中的$Trans$函數)。函數

斷定胡牌的條件有兩種:$\exists (j,k), dp_{1, n, j, k} \ge 4$;記$cnt$爲牌數$\ge 2$的花色數,$cnt \ge 7$。spa

能夠發現這個$dp$和狀態$i$這一維關係不是很大,$dp_{0}$和$dp_{1}$的轉移方式也都是同樣的。只考慮有$dp_{*,j,k}$這個$3 \times 3$的矩陣,其轉移本質上是接收一個數字(表示下一種花色的牌的數量),而後變成了另外一個$3 \times 3$的矩陣,咱們能夠把它當作一個自動機的模型。事實上這樣的$dp$狀態($3 \times 3$的矩陣)並不不少。咱們用結構體$State$來形容這個矩陣。爲了表示整個當前胡牌狀態,咱們須要兩個$State$分別表示有沒有選對子(即$dp_0$和$dp_1$),以及$cnt$。咱們把這三個東西放到結構體$Mahjong$中,咱們很容易能夠獲得$Mahjong$的$Trans$關係。經搜索$Mahjong$的狀態數總共有$3956$個。code

回過來看咱們要算的答案:$ans = \sum\limits_{a = 13}^{4n} p(a)$,其中$p(a)$表示你總共摸了$a$張牌以後仍然不能胡的機率。因而咱們只要算摸了$a$張牌仍然不能胡的排列數便可。令$f_{i, j, k}$表示咱們考慮了前$i$種花色的牌,當前胡牌狀態爲$j$($j$是一個$Mahjong$類),已經摸了$k$張牌的排列數有多少。轉移枚舉摸了花色$i + 1$的牌數$z$,則由$f_{i, j, k}$轉移到$f_{i + 1, trans(j,z), k + z}$,乘上的係數就是$(4 - org_{i + 1})^{\underline{z - org{i + 1}}} \binom{k + z - sum_{i + 1}}{z - org_{i + 1}}$。$org_i$表示原有的$13$張牌中花色爲$i$的有幾張,$sum$則是$org$的前綴和,這個式子的意思就是咱們須要在沒被選過的$4 - org_{i + 1}$張牌中選$z - org_{i + 1}$張的排列,而且插入到以前的排列中,但前$13$張牌的順序是固定的。blog

最後$p(a)$是很好算的,$p(a) = \frac{\sum\limits_{j \; can \; not \; win } f_{n, j, a}}{ (4n - 13)^{\underline{a - 13}} }$。it

複雜度爲$O(3956 * n^2)$。event

#include <bits/stdc++.h>

using namespace std;

const int N = 1e2 + 5;
const int M = 4e3 + 5;
const int MOD = 998244353;

namespace {
  int ch[M][5], fac[5];
  int Add(int a, int b) { return (a += b) >= MOD? a - MOD : a; }
  void Upd(int &a, int b) { a = Add(a, b); }
  int Mul(int a, int b) { return (long long)a * b % MOD; }
  int Inv(int x) { return (x == 1)? 1 : Mul(MOD - MOD / x, Inv(MOD % x)); }
  void Math_init() {
    for (int i = 0; i < M; ++i) {
      ch[i][0] = 1;
      for (int j = 1; j <= min(i, 4); ++j) {
        ch[i][j] = Add(ch[i - 1][j - 1], ch[i - 1][j]);
      }
    }
    fac[0] = 1;
    for (int i = 1; i < 5; ++i) {
      fac[i] = Mul(fac[i - 1], i);
    }
  }
}

bool Chkmax(int &a, int b) {
  return (a < b)? (a = b, 1) : (0);
}

struct State {
  int dp[3][3]; // (last last bar, last bar)
  State() {
    memset(dp, -1, sizeof dp);
  }
  
  friend bool operator < (State a, State b) {
    for (int i = 0; i < 3; ++i)
      for (int j = 0; j < 3; ++j)
        if (a.dp[i][j] != b.dp[i][j])
          return a.dp[i][j] < b.dp[i][j];
    return 0;
  }
  
  friend State Max(State a, State b) {
    for (int i = 0; i < 3; ++i) 
      for (int j = 0; j < 3; ++j)
        a.dp[i][j] = max(a.dp[i][j], b.dp[i][j]);
    return a;
  }
  
  friend State Trans(State a, int b) {
    State c;
    for (int i = 0; i < 3; ++i)
      for (int j = 0; j < 3; ++j)
        if (~a.dp[i][j])
          for (int k = 0; k < 3 && i + j + k <= b; ++k)
            Chkmax(c.dp[j][k], min(i + a.dp[i][j] + (b - i - j - k) / 3, 4));
    return c;
  }
  
};

struct Mahjong {
  pair<State, State> god;
  int cnt;
  Mahjong() {
    memset(god.first.dp, -1, sizeof god.first.dp);
    memset(god.second.dp, -1, sizeof god.second.dp);
    god.first.dp[0][0] = cnt = 0;
  }
  
  friend bool operator < (Mahjong a, Mahjong b) {
    return a.cnt != b.cnt? a.cnt < b.cnt : a.god < b.god;
  }
  
  friend Mahjong Trans(Mahjong a, int b) {
    a.cnt = min(a.cnt + (b >= 2), 7);
    a.god.second = Trans(a.god.second, b);
    if (b >= 2) {
      a.god.second = Max(a.god.second, Trans(a.god.first, b - 2));
    }
    a.god.first = Trans(a.god.first, b);
    return a;
  }
  
  bool right() {
    if (cnt == 7) return 1;
    for (int i = 0; i < 3; ++i)
      for (int j = 0; j < 3; ++j)
        if (god.second.dp[i][j] == 4) return 1;
    return 0;
  }
  
} mahjong[M];


int n, tot;
map<Mahjong, int> idx;
bool win[M];
int org[N], f[N][M][4 * N], trans[M][5];

void Dfs_mahjong(Mahjong now) {
  if (idx.find(now) != idx.end()) return;
  mahjong[++tot] = now;
  win[tot] = now.right();
  idx[now] = tot;
  for (int i = 0; i <= 4; ++i) {
    Dfs_mahjong(Trans(now, i));
  }
}

int main() {
  Math_init();
  Dfs_mahjong(Mahjong());

  for (int i = 1; i <= tot; ++i) {
    for (int j = 0; j <= 4; ++j) {
      trans[i][j] = idx[Trans(mahjong[i], j)];
    }
  }

  scanf("%d", &n);
  for (int i = 0, x; i < 13; ++i) {
    scanf("%d%*d", &x);
    ++org[x];
  }

  f[0][1][0] = 1;
  for (int i = 0, cp = 0; i < n; ++i) { // consider 1 ... i
    cp += org[i + 1];
    for (int j = 1; j <= tot; ++j) { // mahjong j
      for (int l = org[i + 1]; l <= 4; ++l) { // trans
        int *nf = f[i + 1][trans[j][l]], *ff = f[i][j];
        int tmp = Mul(ch[4 - org[i + 1]][l - org[i + 1]], fac[l - org[i + 1]]);
        for (int k = 0; k + l <= 4 * n; ++k) { // have chosen k cards
          if (!ff[k]) continue;
          Upd(nf[k + l], Mul(ff[k], Mul(ch[k + l - cp][l - org[i + 1]], tmp)));
        }
      }
    }
  }
  
  int ans = 0, dw = 1;
  for (int i = 13; i <= 4 * n; ++i) {
    int up = 0;
    for (int j = 1; j <= tot; ++j) {
      if (!win[j]) Upd(up, f[n][j][i]);
    }
    Upd(ans, Mul(up, Inv(dw)));
    dw = Mul(dw, 4 * n - i);
  }
  printf("%d\n", ans);
  
  return 0;
}
View Code
相關文章
相關標籤/搜索