lca最近公共祖先(st表/倍增)

来源:互联网 发布:java软件工程培训 编辑:程序博客网 时间:2024/06/11 15:10
大体思路

1.求出每个元素在树中的深度

2.用st表预处理的方法处理出f[i][j],f[i][j]表示元素i上方第2^j行对应的祖先是谁 

3.将较深的点向上挪,直到两结点的深度相同

4.深度相同后,祖先可能就在上方,再走几步就到了,于是两个点同时向上移

具体的方法和代码贴在下面 ↓

具体

1.求出每个元素在树中的深度

//求每个节点在树中的深度void dfs(int pos,int pre)//pre是pos的父节点 {    for(int i=0;i<v[pos].size;i++)//枚举pos的子节点     {        register int t=v[pos][i];        if(t==pre)continue;//防止死循环         f[t][0]=pos;dep[t]=dep[pos]+1;        dfs(t,pos);//再从子节点向后枚举     }}

2.用st表预处理的方法处理出f[i][j]

//求f数组(st表预处理) for(int i=1;(1<<i)<=n;i++)    for(int j=1;j<=n;j++)        f[j][i]=f[f[j][i-1]][i-1];//f[i][j]表示元素i上方第2^j行对应的祖先是谁 

3.先比较a,b两点哪个较深,将较深的点赋值给a

//把a节点变为a,b中较深的一个节点 int query(int a,int b){    if(dep[a]<dep[b])swap(a,b);}

将较深的点向上挪,直到两结点的深度相同

//使a和b在同一个深度上 for(int i=20;i>=0;i--)    if(dep[f[a][i]]>=dep[b])        a=f[a][i];//倒着循环是因为向上走的步数只会越来越小 

4.深度相同后,祖先可能就在上方,再走几步就到了,于是两个点同时向上移

//同一深度后,再向上找公共祖先 for(int i=20;i>=0;i--)    if(f[a][i]!=f[b][i])    {        a=f[a][i];        b=f[b][i];     } 

 全部代码

#include <cstdio>#include <cstring>#include <iostream>#include <vector>using namespace std;vector<int> v[41000];vector<int> w[41000];int f[41000][21];//f[i][j]表示i点向上2^j层的祖先 int g[41000][21];//g[i][j]表示i点到从i向上2^j层的祖先的距离 int dep[41000];int n,m;void dfs(int pos,int pre,int depth){    dep[pos]=depth;    for(int i=0;i<v[pos].size();i++)    {        int t=v[pos][i];        if(t==pre) continue;        f[t][0]=pos;        g[t][0]=w[pos][i];        dfs(t,pos,depth+1);    }}int query(int a,int b){    int sum=0;    if(dep[a]<dep[b]) swap(a,b);//深度较深的点     for(int i=20;i>=0;i--)//找到a在深度dep[b]处的祖先     {        if(dep[f[a][i]]>=dep[b])        {            sum+=g[a][i];//a到该祖先的距离             a=f[a][i];        }    }    if(a==b) return sum;//挪到相同深度后如果在同一点直接return     int x;    for(int i=20;i>=0;i--)//否则a和b一起往上蹦跶     {        if(f[a][i]!=f[b][i])        {            sum+=g[a][i];            a=f[a][i];            sum+=g[b][i];            b=f[b][i];        }    }    return sum+g[a][0]+g[b][0];//最后蹦跶到最近公共祖先的下一层,所以要再加上上一层 }int main(){    int T;    cin>>T;    while(T--)    {        scanf("%d%d",&n,&m);        memset(dep,-1,sizeof dep);//多组数据我们初始化         memset(f,0,sizeof f);        memset(g,0,sizeof g);        for(int i=0;i<n;i++)//md            v[i].clear(),w[i].clear();        for(int i=1;i<n;i++)        {            int x,y,c;            cin>>x>>y>>c;            v[x].push_back(y);            w[x].push_back(c);            v[y].push_back(x);            w[y].push_back(c);        }        int xxx=v[1].size();        dfs(1,0,1);//dfs处理出每个点的深度,以及各种...             for(int i=1;1<<i<=n;i++)            for(int j=1;j<=n;j++)                f[j][i]=f[f[j][i-1]][i-1],                g[j][i]=g[f[j][i-1]][i-1]+g[j][i-1];        for(int i=1;i<=m;i++)        {            int x,y;            cin>>x>>y;            if(x==y) cout<<"0"<<endl;            else cout<<query(x,y)<<endl;        }    }    return 0;}

 

0 0
原创粉丝点击