subsequnce----dp

subsequence

題意:給長度爲\(n\), \(m\)的字符串\(s\), \(t\), 字符串由0~9的數字組成,問在十進制意義下\(s\)中比\(t\)串大的子序列個數。c++

\(m\leq n \leq{3000}\).spa

題解:考慮兩種不一樣狀況:子序列長度等於\(t\)串以及子序列長度大於\(t\)串。用\(len[i][j]\)維護\(s\)串中第\(i\)位之前長度爲\(j\)的合法串(無前導零)個數,那麼長度大於\(t\)串的個數爲\(\sum_{i=m+1}^{n} len[n][i]\).code

\(dp1[i][j]\)維護\(s\)串中第\(i\)位之前長度爲\(j\)且嚴格大於\(t\)串中前\(j\)位的子序列個數, \(dp2[i][j]\)維護\(s\)串中第\(i\)位之前長度爲\(j\)且大於等於\(t\)串前\(j\)位的子序列個數,很容易由\(s[i]\)\(t[j]\)的大小關係得出一系列轉移方程。 因而長度等於\(t\)串的個數爲\(dp1[n][m]\).ci

代碼:字符串

#include <bits/stdc++.h>
using namespace std;
const int mod = 998244353;
char a[3005], b[3005];
int n, m;
int dp1[3005][3005], len[3005][3005], dp2[3005][3005];
int main() {
    int T;
    cin >> T;
    while(T--) {
        scanf("%d%d", &n, &m);
        int ans1 = 0, ans2 = 0;
        int maxx = max(n, m);
        for(int i = 0; i <= maxx + 2; i++) for(int j = 0; j <= maxx + 2; j++) dp1[i][j] = dp2[i][j] = 0, len[i][j] = 0;
        scanf("%s%s", a + 1, b + 1);
        dp1[0][0] = 0;
        dp2[0][0] = 1;
        len[0][0] = 1;
        for(int i = 1; i <= n; i++) {
            len[i][0] = 1;
            for(int j = 1; j <= i; j++) len[i][j] = (len[i - 1][j - 1] + len[i - 1][j]) % mod;
            if(a[i] == '0') len[i][1] = len[i - 1][1];
            dp2[i][0] = 1;
            for(int j = 1; j <= i; j++) {
                dp1[i][j] = dp1[i - 1][j];
                dp2[i][j] = dp2[i - 1][j];
                if(a[i] <= b[j]) {
                    dp1[i][j] = (dp1[i][j] + dp1[i - 1][j - 1]) % mod;
                    if(a[i] == b[j]) dp2[i][j] = (dp2[i][j] + dp2[i - 1][j - 1]) % mod;
                    else dp2[i][j] = (dp2[i][j] + dp1[i - 1][j - 1]) % mod;
                }
                if(a[i] > b[j]) {
                    dp1[i][j] = (dp1[i][j] + dp2[i - 1][j - 1]) % mod;
                    dp2[i][j] = (dp2[i][j] + dp2[i - 1][j - 1]) % mod;
                }
                //dp2[i][j] += dp1[i - 1][j - 1];
            }
        }
        for(int i = m + 1; i <= n; i++) ans1 = (ans1 + len[n][i]) % mod;
        ans2 = dp1[n][m];
        int ans = (ans1 + ans2) % mod;
        printf("%d\n", ans);
    }
    return 0;
}
本站公眾號
   歡迎關注本站公眾號,獲取更多信息