poj3237--Tree(树链剖分+线段树)

来源:互联网 发布:在线音乐网站源码 编辑:程序博客网 时间:2024/06/02 15:54

题目链接:poj--3237

题意很简单,给出n个节点的一棵树,有三种操作:

1、C修改第i条边的值为v

2、N:将节点a到节点b路径上所有边的权值取反(改变符号)

3、Q:询问节点a到节点b路径上边权的最大值

首先进行树链剖分,把每条边的权值映射到线段树上(边权存放在该边较深端点对应的位置)。线段树数组为cl:因为存在取反操作,取反后的最大值来自取反前的最小值,所以需要同时维护最大值和最小值——cl[i][0]记录第i段的最大值,cl[i][1]记录第i段的最小值。lazy为懒标记,表示该段是否有尚未下传给子区间的取反操作;连续取反两次等价于没有取反。

 

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;

/*
 * POJ 3237 "Tree": heavy-light decomposition + segment tree.
 *
 * Every edge weight is stored at the segment-tree position of the edge's
 * deeper endpoint.  The tree keeps both the maximum and the minimum of each
 * range so that a range negation can be applied in O(1) per node
 * (new max = -old min, new min = -old max).
 */

#define INF 0x3f3f3f3f
#define maxn 11000

/* Shorthands for the (left bound, right bound, node index) triple. */
#define int_now int l,int r,int rt
#define now l,r,rt
#define lson l,(l+r)/2,rt<<1
#define rson (l+r)/2+1,r,rt<<1|1

/* Adjacency-list edge; each undirected edge i (0-based) occupies slots 2i and 2i+1. */
struct node {
    int u, v, w;
    int next;
} edge[maxn << 1];

int head[maxn], cnt, vis[maxn];
/* num: subtree size, dep: depth, fa: parent, son: heavy child (-1 if leaf),
 * top: head of the heavy chain, w: segment-tree position of the edge to the parent. */
int num[maxn], dep[maxn], fa[maxn], son[maxn], top[maxn], w[maxn], step;
/* cl[i][0]: maximum of segment i, cl[i][1]: minimum of segment i.
 * lazy[i] == 1: the children of i still have one negation pending. */
int cl[maxn << 2][2], lazy[maxn << 2], n;

/* Append the directed edge u -> v with weight w to u's adjacency list. */
void add(int u, int v, int w) {
    edge[cnt].u = u;
    edge[cnt].v = v;
    edge[cnt].w = w;
    edge[cnt].next = head[u];
    head[u] = cnt++;
}

/* First pass: compute dep, fa, num and the heavy child son for the subtree of u. */
void dfs1(int u) {
    num[u] = 1;
    son[u] = -1;
    for (int i = head[u]; i != -1; i = edge[i].next) {
        int v = edge[i].v;
        if (vis[v]) {
            continue;
        }
        vis[v] = 1;
        dep[v] = dep[u] + 1;
        fa[v] = u;
        dfs1(v);
        num[u] += num[v];
        if (son[u] == -1 || num[son[u]] < num[v]) {
            son[u] = v;
        }
    }
}

/* Second pass: assign chain heads and positions.  The heavy child is visited
 * first so that every heavy chain occupies a contiguous range of positions. */
void dfs2(int u) {
    if (son[u] == -1) {
        return;
    }
    top[son[u]] = top[u];
    w[son[u]] = step++;
    vis[son[u]] = 1;
    dfs2(son[u]);

    for (int i = head[u]; i != -1; i = edge[i].next) {
        int v = edge[i].v;
        if (vis[v]) {
            continue;
        }
        vis[v] = 1;
        top[v] = v;          /* a light child starts a new chain */
        w[v] = step++;
        dfs2(v);
    }
}

/* Run both decomposition passes from root 1.  Afterwards positions 1..step-1 are in use. */
void dfs() {
    memset(vis, 0, sizeof(vis));
    vis[1] = 1;
    dep[1] = 1;
    fa[1] = -1;
    dfs1(1);

    memset(vis, 0, sizeof(vis));
    vis[1] = 1;
    top[1] = 1;
    step = 1;
    dfs2(1);
}

/* Negate the range stored in node rt: max and min swap roles with flipped signs. */
void swap1(int rt) {
    cl[rt][0] = -cl[rt][0];
    cl[rt][1] = -cl[rt][1];
    swap(cl[rt][0], cl[rt][1]);
}

/* Recompute node rt from its two children, re-applying a negation that the
 * children have not received yet. */
void push_up(int_now) {
    cl[rt][0] = max(cl[rt << 1][0], cl[rt << 1 | 1][0]);
    cl[rt][1] = min(cl[rt << 1][1], cl[rt << 1 | 1][1]);
    if (lazy[rt]) {
        swap1(rt);
    }
}

/* Hand a pending negation of node rt down to both children. */
void push_down(int_now) {
    if (lazy[rt]) {
        lazy[rt] = 0;
        /* Two negations cancel, so the child tag is toggled rather than set. */
        lazy[rt << 1] = 1 - lazy[rt << 1];
        lazy[rt << 1 | 1] = 1 - lazy[rt << 1 | 1];
        swap1(rt << 1);
        swap1(rt << 1 | 1);
    }
}

/* Point assignment: set position k to value x. */
void update1(int k, int x, int_now) {
    if (l == k && r == k) {
        lazy[rt] = 0;
        cl[rt][0] = cl[rt][1] = x;
        return;
    }
    push_down(now);
    if (k <= (l + r) / 2) {
        update1(k, x, lson);
    } else {
        update1(k, x, rson);
    }
    push_up(now);
}

/* Range negation of positions [ll, rr]. */
void update2(int ll, int rr, int_now) {
    if (ll > r || rr < l) {
        return;
    }
    if (ll <= l && rr >= r) {
        lazy[rt] = 1 - lazy[rt];
        swap1(rt);
        return;
    }
    push_down(now);
    update2(ll, rr, lson);
    update2(ll, rr, rson);
    push_up(now);
}

/* Range maximum over positions [ll, rr] without modifying the tree.
 * sum counts the pending negations seen on the ancestors of the current node;
 * if it is odd the stored values are negated, so the maximum is -min.
 * Returns -INF when the range does not intersect [l, r]. */
int query(int ll, int rr, int_now, int sum) {
    if (ll > r || rr < l) {
        return -INF;
    }
    if (ll <= l && rr >= r) {
        if (sum % 2) {
            return -cl[rt][1];
        }
        return cl[rt][0];
    }
    sum += lazy[rt];
    return max(query(ll, rr, lson, sum), query(ll, rr, rson, sum));
}

/* Walk the path u..v chain by chain.
 * k != 0: print the maximum edge weight on the path (0 for an empty path).
 * k == 0: negate every edge weight on the path. */
void solve(int u, int v, int k) {
    int ans = -INF;
    while (u != v) {
        if (dep[u] > dep[v]) {
            swap(u, v);                 /* keep u as the shallower endpoint */
        }
        int f1 = top[u];
        int f2 = top[v];
        if (f1 == f2) {
            /* Same chain: the edges strictly below u down to v. */
            if (k) {
                ans = max(ans, query(w[son[u]], w[v], 1, step - 1, 1, 0));
            } else {
                update2(w[son[u]], w[v], 1, step - 1, 1);
            }
            v = u;
        } else if (dep[f1] > dep[f2]) {
            /* u's chain head is deeper: consume u's chain and jump above it. */
            if (k) {
                ans = max(ans, query(w[f1], w[u], 1, step - 1, 1, 0));
            } else {
                update2(w[f1], w[u], 1, step - 1, 1);
            }
            u = fa[f1];
        } else {
            if (k) {
                ans = max(ans, query(w[f2], w[v], 1, step - 1, 1, 0));
            } else {
                update2(w[f2], w[v], 1, step - 1, 1);
            }
            v = fa[f2];
        }
    }
    if (ans == -INF) {
        ans = 0;
    }
    if (k) {
        printf("%d\n", ans);
    }
}

char str[100];

/* Reads the test cases from stdin and answers each query on stdout.
 * Commands are dispatched on their first letter: 'C' (change edge), 'N'
 * (negate path), 'D' (end of test case), anything else is a path query.
 * Processing stops quietly as soon as the input is truncated or malformed. */
int main() {
    int t, i, j;
    int u, v, s;

    if (scanf("%d", &t) != 1) {
        return 0;
    }
    while (t-- > 0) {
        memset(head, -1, sizeof(head));
        cnt = 0;
        if (scanf("%d", &n) != 1) {
            return 0;
        }
        for (i = 1; i < n; i++) {
            if (scanf("%d %d %d", &u, &v, &s) != 3) {
                return 0;
            }
            add(u, v, s);
            add(v, u, s);
        }
        dfs();

        memset(cl, 0, sizeof(cl));
        memset(lazy, 0, sizeof(lazy));
        for (i = 0; i < n - 1; i++) {
            /* Orient the stored edge so that .v is the deeper endpoint, whose
             * position holds the edge weight. */
            if (dep[edge[i * 2].u] > dep[edge[i * 2].v]) {
                swap(edge[i * 2].u, edge[i * 2].v);
            }
            update1(w[edge[i * 2].v], edge[i * 2].w, 1, step - 1, 1);
        }

        /* scanf returns EOF (non-zero) at end of input, so compare against 1;
         * the width keeps the token inside str. */
        while (scanf("%99s", str) == 1) {
            if (str[0] == 'D') {
                break;
            }
            if (scanf("%d %d", &i, &j) != 2) {
                return 0;
            }
            if (str[0] == 'C') {
                update1(w[edge[(i - 1) * 2].v], j, 1, step - 1, 1);
            } else if (str[0] == 'N') {
                solve(i, j, 0);
            } else {
                solve(i, j, 1);
            }
        }
    }
    return 0;
}


 

1 0