【算法】Matrix - Tree 矩陣樹定理 & 題目總結

  最近集中學習了一下矩陣樹定理,本身其實仍是沒有太明白原理(證實)類的東西,但想在這裏總結一下應用中的一些細節,矩陣樹定理的一些引伸等等。html

  首先,矩陣樹定理用於求解一個圖上的生成樹個數。實現方式是:\(A\)爲鄰接矩陣,\(D\)爲度數矩陣,則基爾霍夫(Kirchhoff)矩陣即爲:\(K = D - A\)。具體實現中,記 \(a\) 爲Kirchhoff矩陣,則若存在 \(E(u, v)\) ,則\(a[u][u] ++, a[v][v] ++, a[u][v] --, a[v][u] --\) 。即\(a[i][i]\) 爲 \(i\) 點的度數,\(a[i][j]\) 爲 \(i, j\)之間邊的條數的相反數。ios

  這樣構成的矩陣的行列式的值,就爲生成樹的個數。而求解行列式的快速方法爲使用高斯消元進行消元消處上三角矩陣,則有對角線上的值的乘積 = 行列式的值。通常而言求解生成樹個數的題目數量會很是龐大,須要取模處理。取模處理中,不能出現小數,因而使用展轉相除法:(其中由於消的是行列式,因此與消方程有所不一樣。交換兩行行列式的值變號,且消元只能將一行的數 * k 以後加到別的行上。)c++

int Gauss()
{
    int ans = 1;
    for(int i = 1; i < tot; i ++)
    {
        for(int j = i + 1; j < tot; j ++)
            while(f[j][i])
            {
                int t = f[i][i] / f[j][i];
                for(int k = i; k < tot; k ++)
                    f[i][k] = (f[i][k] - t * f[j][k] + mod) % mod;
                swap(f[i], f[j]);
                ans = - ans;
            }
        ans = (ans * f[i][i]) % mod;
    }
    return (ans + mod) % mod;
}

  變元矩陣樹定理:求全部生成樹的總邊積的和。和矩陣樹的求法相同,不過行列式中 \(a[i][i]\) 記錄的是總的邊權和,\(a[i][j]\) 記錄 \(i, j\) 之間邊權的相反數。學習

  如下爲幾道題目:spa

    1.HEOI2015 小Z的房間     2.SHOI2016 黑暗前的幻想鄉code

    3.SDOI2014 重建         4.JSOI2008 最小生成樹計數htm

  1.HEOI2015 小Z的房間(妥妥的模板題一個)blog

#include <bits/stdc++.h>
using namespace std;
#define maxn 90
#define int long long 
#define mod 1000000000
int n, m, f[maxn][maxn];
int tot, Map[maxn][maxn];

int read()
{
    int x = 0, k = 1;
    char c;
    c = getchar();
    while(c < '0' || c > '9') { if(c == '-') k = -1; c = getchar(); }
    while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
    return x * k;
}

void add(int x, int y)
{
    if(x > y) return;
    f[x][x] ++, f[y][y] ++;
    f[x][y] --, f[y][x] --;
}

int Gauss()
{
    int ans = 1;
    for(int i = 1; i < tot; i ++)
    {
        for(int j = i + 1; j < tot; j ++)
            while(f[j][i])
            {
                int t = f[i][i] / f[j][i];
                for(int k = i; k < tot; k ++)
                    f[i][k] = (f[i][k] - t * f[j][k] + mod) % mod;
                swap(f[i], f[j]);
                ans = - ans;
            }
        ans = (ans * f[i][i]) % mod;
    }
    return (ans + mod) % mod;
}

signed main()
{
    n = read(), m = read();
    for(int i = 1; i <= n; i ++)
    {
        char c;
        for(int j = 1; j <= m; j ++)
        {
            cin >> c;
            if(c == '.') Map[i][j] = ++ tot;
        }
    }
    for(int i = 1; i <= n; i ++)
        for(int j = 1; j <= m; j ++)
        {
            int tem, u;
            if(!(u = Map[i][j])) continue;
            if(tem = Map[i - 1][j]) add(u, tem);
            if(tem = Map[i + 1][j]) add(u, tem);
            if(tem = Map[i][j - 1]) add(u, tem);
            if(tem = Map[i][j + 1]) add(u, tem);
        }
    printf("%lld\n", Gauss());
    return 0;
}

  2.SHOI2016黑暗前的幻想鄉ci

  容斥+矩陣樹定理。與模板的不一樣之處在於每一家公司都要參與修建,則合法方案數 = 總的方案數 - 有一個公司未修建的方案數 + 有兩個公司未修建的方案數……暴力重構矩陣求解便可。get

#include <bits/stdc++.h>
using namespace std;
#define ll long long
const int mod = 1000000007;
int n;
ll g[20][20];
vector<pair<int , int > > q[20];

int read()
{
    int x = 0, k = 1;
    char c;
    c = getchar();
    while(c < '0' || c > '9') { if(c == '-') k = -1; c = getchar(); }
    while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
    return x * k;
}

int Gauss()
{
    ll ans = 1;
    for(int i = 1; i < n; i ++)
    {
        for(int j = i + 1; j < n; j ++)
            while(g[j][i])
            {
                ll t = g[i][i] / g[j][i];
                for(int k = i; k < n; k ++)
                    g[i][k] = (g[i][k] - g[j][k] * t) % mod;
                swap(g[i], g[j]);
                ans = -ans;
            }
        ans = (ans * g[i][i]) % mod;
        if(!ans) return 0;
    }
    return (ans + mod) % mod;
}

int main()
{
    n = read();
    for(int i = 1; i < n; i ++)
    {
        int m = read();
        for(int j = 1; j <= m; j ++)
        {
            int x = read(), y = read();
            q[i].push_back(make_pair(x, y));
        }
    }
    int ans = 0, CNST = 1 << (n - 1);
    for(int i = 0; i < CNST; i ++)
    {
        int cnt = 0; memset(g, 0, sizeof(g));
        for(int j = 1; j < n; j ++)
            if(i & (1 << (j - 1)))
            {
                for(int k = 0; k < q[j].size(); k ++)
                {
                    int x = q[j][k].first, y = q[j][k].second;
                    g[x][x] ++, g[y][y] ++;
                    g[x][y] --, g[y][x] --;
                }
                cnt ++;
            }
        if((n - cnt) & 1) ans = (ans + Gauss()) % mod;
        else ans = (ans - Gauss() + mod) % mod;
    }
    printf("%d\n", ans);
    return 0;
}

  3.SDOI2014重建

  化式子 + 變元矩陣樹定理。將機率的式子寫出來變形便可獲得矩陣樹定理求 \(\prod \frac{p(u, v)}{1 - p(u, v)}\)

#include <bits/stdc++.h>
using namespace std;
#define maxn 100
#define db double
#define eps 0.000001
int n;
db ans = 1.0, a[maxn][maxn];

db Gauss(int n)
{
    db ans = 1.0;
    for(int i = 1; i <= n; i ++)
    {    
        for(int j = i + 1; j <= n; j ++)
        {
            int t = i;
            if(fabs(a[j][i]) > fabs(a[t][i])) t = j;
            if(t != i) swap(a[t], a[i]), ans = -ans;
        } 
        for(int j = i + 1; j <= n; j ++)
        {
            db t = a[j][i] / a[i][i];
            for(int k = i; k <= n; k ++)
                a[j][k] -= t * a[i][k];
        }
        ans *= a[i][i];
    }
    return fabs(ans);
}

int main()
{
    scanf("%d", &n);
    for(int i = 1; i <= n; i ++)
        for(int j = 1; j <= n; j ++)
            scanf("%lf", &a[i][j]);
    for(int i = 1; i <= n; i ++)
        for(int j = 1; j <= n; j ++)
        {
            db t = fabs(1.0 - a[i][j]) < eps ? eps : (1.0 - a[i][j]);
            if(i < j) ans *= t;
            a[i][j] = a[i][j] / t;
        }
    for(int i = 1; i <= n; i ++)
        for(int j = 1; j <= n; j ++)
            if(i != j) { a[i][i] += a[i][j], a[i][j] = -a[i][j]; }
    printf("%.10lf\n", Gauss(n - 1) * ans);
    return 0;
} 

  4.JSOI2008最小生成樹計數

  這題雖然最先年,然而也最強啊……我的認爲這位博主解釋得很好了 Z-Y-Y-S的博客

  兩個性質 mark 一下:

  

  

 

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
using namespace std;
#define maxn 200
#define mod 31011
int n, m, ans = 1, tmp[maxn];
int sum, fa[maxn], set[maxn];
int a[maxn][maxn];

struct edge
{
    int u, v, w;
}E[maxn * 20], e[maxn * 20];

int read()
{
    int x = 0, k = 1;
    char c;
    c = getchar();
    while(c < '0' || c > '9') { if(c == '-') k = -1; c = getchar(); }
    while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
    return x * k;
}

bool cmp(edge a, edge b) { return a.w < b.w; }

int find(int x) { return set[x] == x ? x : set[x] = find(set[x]); }
int find2(int x) { return fa[x] == x ? x : fa[x] = find2(fa[x]); }

int Gauss(int n)
{
    int ans = 1;
    for(int i = 1; i <= n; i ++)
        for(int j = 1; j <= n; j ++)
            a[i][j] = (a[i][j] + mod) % mod;
    for(int i = 1; i <= n; i ++)
    {
        for(int j = i + 1; j <= n; j ++)
            while(a[j][i])
            {
                int t = a[i][i] / a[j][i];
                for(int k = i; k <= n; k ++)
                    a[i][k] = (a[i][k] - 1ll * t * a[j][k] % mod + mod) % mod;
                swap(a[i], a[j]); ans = - ans;
            }
        ans = 1ll * ans * a[i][i] % mod;
    }
    return (ans + mod) % mod;
}

void Cal(int S, int T)
{
    int cnt = 0;
    for(int i = S; i <= T; i ++)
    {
        e[i] = E[i];
        int p = find(e[i].u), q = find(e[i].v);
        e[i].u = p, e[i].v = q;
        if(p == q) continue;
        tmp[++ cnt] = p, tmp[++ cnt] = q;
    }
    sort(tmp + 1, tmp + 1 + cnt);
    cnt = unique(tmp + 1, tmp + cnt + 1) - tmp - 1;
    memset(a, 0, sizeof(a));
    for(int i = 1; i <= cnt; i ++) fa[i] = i;
    for(int i = S; i <= T; i ++)
    {
        if(e[i].u == e[i].v) continue;
        int p = find(e[i].u), q = find(e[i].v);
        if(p != q) -- sum, set[p] = q;
        int u = lower_bound(tmp + 1, tmp + cnt + 1, e[i].u) - tmp;
        int v = lower_bound(tmp + 1, tmp + cnt + 1, e[i].v) - tmp;
        a[u][u] ++, a[v][v] ++;
        a[u][v] --, a[v][u] --;
        p = find2(u), q = find2(v);
        if(p != q) fa[p] = q;
    }
    for(int i = 2; i <= cnt; i ++)
        if(find2(i) != find2(i - 1))
        {
            int p = find2(i), q = find2(i - 1);
            a[p][p] ++, a[q][q] ++;
            a[p][q] --, a[q][p] --;
            fa[p] = q;
        }
    ans = 1ll * ans * Gauss(cnt - 1) % mod;
}

int main()
{
    n = read(), m = read();
    for(int i = 1; i <= m; i ++)
        E[i].u = read(), E[i].v = read(), E[i].w = read();
    sort(E + 1, E + 1 + m, cmp);
    for(int i = 1; i <= n; i ++) set[i] = i;
    sum = n; 
    for(int i = 1, j; i <= m; i = j)
    {
        for(j = i; j <= m; j ++)
            if(E[j].w != E[i].w) break;
        if(j - i > 1) Cal(i, j - 1);
        else 
        {
            int p = find(E[i].u), q = find(E[i].v);
            if(p != q) set[p] = q;
            sum --;
        }
    }
    if(sum > 1) printf("0");
    else printf("%d\n", ans);
    return 0;
} 
相關文章
相關標籤/搜索