【OI】倍增求LCA

 

╭(′▽`)╯ide

 

總之,咱們都知道lca是啥,不須要任何基礎也能想出來怎麼用最暴力的方法求LCA,也就是深度深的點先跳到深度淺的點的同一深度,而後一塊兒向上一步步跳。這樣顯然太慢了!spa

因此咱們要用倍增,倍增比較屌,直接2^k速度往上跳,並且複雜度和樹剖lca差很少,那麼步驟分爲兩步code

1.讓兩個點到同一深度blog

2.到了同一深度同步往上跳get

反正我一開始看的時候一直在想,萬一跳過了怎麼辦?哈哈哈,因此說咱們有辦法嘛:同步

 

定義deepv爲v點的深度,設兩個要求lca的點分別爲a,b,且deepa >= deepbstring

因此,枚舉找出最大的k使2^k <= deepa,這就是最大的跳的距離;io

接着讓他們到達同一深度:event

從大到小枚舉k,若是 deepa - 2^k >= deepb就往上跳2^k步,由於若是跳了2^k步的話必定deepa >= deepb模板

因此,咱們跳的第一步必定是能跳的最大的一步,因此接下來只能跳次大的一步,同理跳完以後deepa >= deepb

......

由於k是愈來愈小的,k = 0的時候2^k = 1,所以不管如何最後都會以最大的效率跳到相同的深度

如今跳到了相同的深度,而後要同時向上走找到lca。

假設跳了 2 ^ k步以後它們到的位置不相等,說明lca還在深度更淺的地方,由於若是跳以後到的位置相等了,顯然這個位置必定在lca的上面

因此,只要判斷跳了 2 ^ k步後它們的位置若是不相等,就跳這步,這樣就保證了跳到的深度必定小於lca,最後k = 0時 2 ^ k = 1,

則枚舉完了k,它們所在的深度顯然必定是lca的深度-1,則lca就是它們任意一個的父親。

 

代碼(luogu lca模板):

#include <cstdio>
#include <vector>
#include <cstring>

const int MaxN = 500010;

int n,m,s;
int par[MaxN][30];
int deep[MaxN];
bool vis[MaxN];

struct Edge{
    int to,nxt;
}e[MaxN*2];
int head[MaxN];
int cnt;

void add(int u,int v){
    e[++cnt].to = v;
    e[cnt].nxt = head[u];
    head[u] = cnt;
}

void getdeep(int u){
    vis[u] = 1;
    for(int i = head[u]; i; i = e[i].nxt){
        
        int to = e[i].to;
        if(to == u || vis[to]) continue;
        
        par[to][0] = u;
        
        deep[to] = deep[u] + 1;
        
        getdeep(to);
        
    }
    
}

void getpar(){
    for(int up = 1; (1<<up) <= n; up++){
        for(int i = 1; i <= n ; i++){
            par[i][up] = par[par[i][up-1]][up-1];
        }
        
    }
    
}


int lca(int u,int v){
    if(deep[u] < deep[v] ) std::swap(u,v);
    
    int max_jump = -1;
    
    while(1<<(max_jump+1) <= deep[u]) max_jump++;
    
    for(int i = max_jump; i >= 0; i--){
        if(deep[u] - (1<<i) >= deep[v]){
            u = par[u][i];
        }
        
    }
    
    if(u == v)
        return u;
        
    for(int i = max_jump; i >= 0; i--){
        if(par[u][i] != par[v][i]){
            u = par[u][i];
            v = par[v][i];
            
        }
    }
    
    return par[u][0];
    
    
    
    return 0;
    
    
} 



int main()
{
    scanf("%d%d%d",&n,&m,&s);
    
    for(int i = 1; i < n; i++ ){
        int u,v;
        scanf("%d%d",&u,&v);
        add(u,v);
        add(v,u);
        //par[v][0] = u;
        //par[u][0] = v;
    }

    
    deep[s] = 0;
    
    getdeep(s);
    
    getpar();
    
    for(int i = 1; i <= m; i++){
        int a,b;
        scanf("%d%d",&a,&b);
        printf("%d\n",lca(a,b));
        
    }
    
    //par[i][j] = par[par[i][j-1]][j-1]
    
        
    
    return 0;
}
View Code
相關文章
相關標籤/搜索