花花的森林

来源:互联网 发布:世界历史地图软件 编辑:程序博客网 时间:2024/06/09 22:56

题目描述
花花有一棵带n 个顶点的树T,每个节点有一个点权ai。
有一天,他认为拥有两棵树更好一些。所以,他从T 中删去了一条边。
第二天,他认为三棵树或许又更好一些。因此,他又从他拥有的某一棵树中去除了一条边。
如此往复。每一天,花花都会删去一条尚未被删去的边,直到他得到了一个包含了n 棵只有一个点的树的森林。
定义一条简单路径的权值为路径上点权之和,一棵树的直径为树上权值最大的简单路径。
花花认为树最重要的特征就是它的直径。所以他想请你算出任一时刻他拥有的所有树的直径的乘积。因为这个数可能很大,他要求你输出乘积对109+7 取模之后的结果。

输入
输入的第一行包含一个整数n,表示树T 上顶点的数量。
下一行包含n 个空格分隔的整数ai,表示顶点的权值。
之后的n-1 行中,每一行包含两个用空格分隔的整数xi 和yi,表示节点xi 和yi 之间连有一条边,编号为i。
再之后n-1 行中,每一行包含一个整数kj,表示在第j 天里会被删除的边的编号
输出
输出n 行。
在第i 行,输出删除i-1 条边之后,所有树直径的乘积对10^9 + 7 取模的结果。

样例输入
3
1 2 3
1 2
1 3
2
1
样例输出
6
9
6
提示
初始,树的直径为6(由节点2、1 和3 构成的路径)。在第一天之后,得到了两棵直径都为3 的树。第二天之后,得到了三棵直径分别为1,2,3 的树,乘积为6。
• 对于40% 的数据:N<=100
• 另有20% 的数据:N<=1000
• 另有20% 的数据:N<=104
• 对于100% 的数据:N<=105;ai<=104

Solution

根据一贯的套路,这种删除边的问题很多都是倒着添边做的。
那我们就来考虑倒着做
最后树成了n个点的森林
每次合并两棵树,除去原来两棵树的直径,乘上合并后的新直径
除直径用乘逆元
那么现在的问题就是:如何快速得出两棵树合并后产生的新直径。
考试的时候,我这个渣渣当然不会了,于是我就暴力更新一条链,极限接近n2,但快得出奇,100000的随机数据都只需要0.8s
其实,关于直径又有一个套路,当两棵子树合并的时候,x树的直径两端点是a b,y树的直径两端点是c d,则新直径必定是abcd的某个点对。
求路径用LCA
结果我发现正解比暴力慢了1倍

暴力

#include<cstdio>#include<iostream>#include<algorithm>#include<cstring>#define ll long longusing namespace std;const ll mod=1e9+7;int n,x,y,u,tot;int st[100005],ed[100005],a[100005],cut[100005],f[100005];int s[100005],fa[100005],d[100005];int head[100005],Next[200005],to[200005];ll now,ans[100005];void dfs(int k,int pre){    fa[k]=pre;    for(int i=head[k];i!=-1;i=Next[i])     if(to[i]!=pre) dfs(to[i],k);}void add(int x,int y){    tot++;    Next[tot]=head[x];    to[tot]=y;    head[x]=tot;}int get(int x){    if(f[x]==x) return x; else return f[x]=get(f[x]);}void update(int k){    int s1=0,s2=0;    for(int i=head[k];i!=-1;i=Next[i])     if(fa[to[i]]==k)     {        if(s[to[i]]>s1)         {            s2=s1;            s1=s[to[i]];        }        else        if(s[to[i]]>s2) s2=s[to[i]];        d[k]=max(d[k],d[to[i]]);     }    s[k]=s1+a[k];    d[k]=max(d[k],s1+s2+a[k]);    if(k!=u) update(fa[k]); else return;}ll ny(ll x,ll y){    ll p=1;    while(y>0)     {        if(y%2==1) p=(p*x)%mod;        y=y/2;        x=(x*x)%mod;    }    return p;}void prepare(){    cin>>n;    for(int i=1;i<=n;i++)     {        head[i]=-1;        scanf("%d",&a[i]);     }    for(int i=1;i<n;i++)     {        scanf("%d%d",&st[i],&ed[i]);        add(st[i],ed[i]);        add(ed[i],st[i]);    }    dfs(1,0);    for(int i=1;i<n;i++) scanf("%d",&cut[i]);    tot=0;    for(int i=1;i<=n;i++)     {        f[i]=i;        d[i]=(ll)(a[i]);        s[i]=a[i];        head[i]=-1;    }    ans[n]=1;    for(int i=1;i<=n;i++) ans[n]=(ans[n]*a[i])%mod;    now=ans[n]; //现在的直径乘积 }   int main(){    prepare();    for(int i=n-1;i>=1;i--)     {        x=st[cut[i]],y=ed[cut[i]];        add(x,y);        add(y,x);        if(fa[x]==y)         {            u=get(y);            f[x]=u;            now=(now*ny((ll)(d[x]),mod-2))%mod;            now=(now*ny((ll)(d[u]),mod-2))%mod;             update(y);            now=(now*d[u])%mod;        }        else        {            u=get(x);            f[y]=u;            now=(now*ny((ll)(d[y]),mod-2))%mod;            now=(now*ny((ll)(d[u]),mod-2))%mod;             update(x);             now=(now*d[u])%mod;        }        ans[i]=now;    }    for(int i=1;i<=n;i++) printf("%lld\n",ans[i]);     return 0;}

正解

#include<cstdio>#include<iostream>#include<algorithm>#include<cstring>#define ll long longusing namespace std;const ll mod=1e9+7;int n,x,y,ty,i1,i2,dis,dad,tot;int u[100005],v[100005],b[5];int st[100005],ed[100005],a[100005],cut[100005];int fa[100005][20],d[100005],f[100005],s[100005];int head[100005],Next[200005],to[200005],deep[100005];ll now,ans[100005];void dfs(int k,int pre){    fa[k][0]=pre;    deep[k]=deep[pre]+1;    s[k]=s[pre]+a[k];    for(int i=head[k];i!=-1;i=Next[i])     if(to[i]!=pre) dfs(to[i],k);}void add(int x,int y){    tot++;    Next[tot]=head[x];    to[tot]=y;    head[x]=tot;}int get(int x){    if(f[x]==x) return x; else return f[x]=get(f[x]);}ll ny(ll x,ll y){    ll p=1;    while(y>0)     {        if(y%2==1) p=(p*x)%mod;        y=y/2;        x=(x*x)%mod;    }    return p;}void prepare(){    cin>>n;    for(int i=1;i<=n;i++)     {        head[i]=-1;        scanf("%d",&a[i]);     }    for(int i=1;i<n;i++)     {        scanf("%d%d",&st[i],&ed[i]);        add(st[i],ed[i]);        add(ed[i],st[i]);    }    dfs(1,0);    for(int i=1;(1<<i)<=n;i++)     for(int j=1;j<=n;j++) fa[j][i]=fa[fa[j][i-1]][i-1];    for(int i=1;i<n;i++) scanf("%d",&cut[i]);    tot=0;    for(int i=1;i<=n;i++)     {        f[i]=i;        u[i]=v[i]=i;        d[i]=(ll)(a[i]);        head[i]=-1;    }    ans[n]=1;    for(int i=1;i<=n;i++) ans[n]=(ans[n]*a[i])%mod;    now=ans[n]; //现在的直径乘积 }int LCA(int x,int y){    if(deep[x]<deep[y]) swap(x,y);    for(int i=18;i>=0;i--)     if(deep[x]-(1<<i)>=deep[y]) x=fa[x][i];    if(x==y) return x;    for(int i=18;i>=0;i--)     if(fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];    return fa[x][0];}int main(){    prepare();    for(int i=n-1;i>=1;i--)     {        x=st[cut[i]],y=ed[cut[i]];        add(x,y);        add(y,x);        if(fa[x][0]==y)         {            ty=get(y);            f[x]=ty;            b[1]=u[x],b[2]=v[x],b[3]=u[ty],b[4]=v[ty];            i1=i2=dis=0;            for(int j=1;j<=4;j++)             for(int k=1;k<=4;k++)             {                dad=LCA(b[j],b[k]);                if(s[b[j]]+s[b[k]]-2*s[dad]+a[dad]>dis)                 {                    dis=s[b[j]]+s[b[k]]-2*s[dad]+a[dad];                    i1=b[j];                    i2=b[k];                }            }            now=(now*ny((ll)(d[x]),mod-2))%mod;            now=(now*ny((ll)(d[ty]),mod-2))%mod;             d[ty]=dis,u[ty]=i1,v[ty]=i2;            now=(now*d[ty])%mod;        }        else        {            ty=get(x);            f[y]=ty;            b[1]=u[y],b[2]=v[y],b[3]=u[ty],b[4]=v[ty];            i1=i2=dis=0;            for(int j=1;j<=4;j++)             for(int k=1;k<=4;k++)             {                dad=LCA(b[j],b[k]);                if(s[b[j]]+s[b[k]]-2*s[dad]+a[dad]>dis)                 {                    dis=s[b[j]]+s[b[k]]-2*s[dad]+a[dad];                    i1=b[j];                    i2=b[k];                }            }            now=(now*ny((ll)(d[y]),mod-2))%mod;            now=(now*ny((ll)(d[ty]),mod-2))%mod;            d[ty]=dis,u[ty]=i1,v[ty]=i2;             now=(now*d[ty])%mod;        }        ans[i]=now;    }    for(int i=1;i<=n;i++) printf("%lld\n",ans[i]);     return 0;}
0 0
原创粉丝点击