[JZOJ5082].【GDSOI2017第三轮模拟】Informatics Training

来源:互联网 发布:proteus怎么仿真单片机 编辑:程序博客网 时间:2024/06/10 20:17

题目描述

这里写图片描述
操作数<=300000

分析

首先这是一道数据结构裸题。
面对这种码农题我们需要有较清晰的思路。
下面是我打题前的思路:
大概是用splay维护h,s,关键字就是编号。然后两个set维护其他东西。
T3
set:
mntr 维护每个splay的根的mn_s ,以及id ,就一个pair嘛
ntr 维护点编号
splay:
vtr维护

mn_h
mn_s
sum_h
siz
tag_s
tag_h

join 新建splay,点值s加入set1
practice , splay i , 减x,加y
discuss 给 splay i ,整体减x加y
判断体力是否低于0:
当前节点x低于, 那么左右子树merge,自己扔掉,
若左右子树中有mn低于0,那么继续递归
cooperate,去set2中找最大点和最小点,
test,set1中找到最小值,然后find(mn),并删除。
query直接输出。
注意每次操作完之前后,splay, set维护。
splay至根之前,删掉 mntr,旋完之后再加入
任何遍历都要update,down

然而还是因为不够清晰打错了不少地方,特别是由于mntr维护的是splay根们的信息,每次splay的时候要先删去,之后再加上。

其实这种题想到做法并不难,重点是要优化代码,像我这种三个数据结构,又臭又长。而王之栋的做法十分优美,只用了2个set。虽然我们都超时了····
而鞋垫的treap跑得飞起。
值得一提的是set的operator重定义,必须要把多个关键字都考虑进去,不然会find错····

代码

#include<cstdio>#include<algorithm>#include<cstring>#include<cmath>#include<set>using namespace std;typedef long long ll;typedef double db;#define fo(i,j,k) for(i=j;i<=k;i++)#define fd(i,j,k) for(i=j;i>=k;i--)const int N=300005;struct pi{    int val,id;    pi(int i_=0,int v_=0) {val=v_,id=i_;}}tmp;bool operator <(pi a,pi b) {    return a.val<b.val||a.val==b.val&&a.id<b.id;}struct rec{    int mn_h,mn_s,siz,h,s,tag_h;    ll sum,tag_s;}vtr[N];int tr[N][2],fa[N],st,sta[N];int n,x,y,z,t,l,m,kan,tp,k1;char ch;multiset<pi> mntr;set<int> :: iterator it;set<pi> :: iterator kk;set<int> ntr;void down(int x){    if (!x||(!vtr[x].tag_h&&!vtr[x].tag_s)) return;    vtr[x].mn_s+=vtr[x].tag_s;    vtr[x].s+=vtr[x].tag_s;    vtr[x].mn_h+=vtr[x].tag_h;    vtr[x].h+=vtr[x].tag_h;    vtr[x].sum+=1ll*vtr[x].tag_s*vtr[x].siz;    int i;    fo(i,0,1)        if (tr[x][i])        {            vtr[tr[x][i]].tag_s+=vtr[x].tag_s;            vtr[tr[x][i]].tag_h+=vtr[x].tag_h;        }    vtr[x].tag_h=vtr[x].tag_s=0;}void update(int x){    if (!x)         return ;    down(x);    down(tr[x][0]);    down(tr[x][1]);    vtr[x].mn_h=vtr[x].h;    vtr[x].mn_s=vtr[x].s;    vtr[x].siz=vtr[tr[x][0]].siz+vtr[tr[x][1]].siz+1;    vtr[x].sum=vtr[tr[x][0]].sum+vtr[tr[x][1]].sum+vtr[x].s;    int i;    fo(i,0,1)    if (tr[x][i])     {        vtr[x].mn_h=min(vtr[x].mn_h,vtr[tr[x][i]].mn_h);        vtr[x].mn_s=min(vtr[x].mn_s,vtr[tr[x][i]].mn_s);    }}int pd(int x){    return tr[fa[x]][1]==x;}void rotate(int x){    int y=fa[x],z=pd(x);    down(y);    down(x);    fa[x]=fa[y];    if (fa[y]) tr[fa[y]][pd(y)]=x;    tr[y][z]=tr[x][1-z];    if (tr[x][1-z]) fa[tr[x][1-z]]=y;    fa[y]=x;    tr[x][1-z]=y;    update(y);    update(x);}void downdate(int x,int y){    st=0;    while (x!=y)     {        sta[++st]=x;        x=fa[x];    }    if (!y)         mntr.erase(mntr.find(pi(sta[st],vtr[sta[st]].mn_s)));    while (st)    {        down(sta[st]);        st--;    }}void splay(int x,int y){    if (!x)  return;    downdate(x,y);    while (fa[x]!=y)    {        if (fa[fa[x]]!=y)        {            if (pd(x)==pd(fa[x])) rotate(fa[x]);            else rotate(x);        }        rotate(x);    }    if (!y) mntr.insert(pi(x,vtr[x].mn_s));}int go(int x,int sig){    while (tr[x][sig])     {        update(x);        x=tr[x][sig];    }    update(x);    return x;}int merge(int x,int y,int z,int t,int sig)// middle lson,rson, father{    if (sig&&x)        mntr.erase(mntr.find(pi(x,vtr[x].mn_s)));    y=go(y,1);    z=go(z,0);    if (y||z)    {        if (y) splay(y,x);        if (z) splay(z,x);        if (y)        {            if (z)             {                // erase z                if (sig&&!x)                     mntr.erase(mntr.find(pi(z,vtr[z].mn_s)));                fa[z]=y;            }            tr[y][1]=z;        }else y=z;        // erase y        if (sig&&!x)             mntr.erase(mntr.find(pi(y,vtr[y].mn_s)));        update(y);        // insert y        if (sig) mntr.insert(pi(y,vtr[y].mn_s));    }    if (t) tr[t][pd(x)]=y;    fa[y]=t;    if (t) update(t);    return y;}int del(int x,int sig)// remember to update fa[x], actually merge tr[x][0\1]{    it=ntr.find(x);    kan=*it;    ntr.erase(ntr.find(x));    int y=merge(x,tr[x][0],tr[x][1],fa[x],sig);    return y;}void Merge(int x,int y){    splay(x,0);    splay(y,0);    int z=x;    while (fa[z]!=0) z=fa[z];    if (z==y) return;    z=merge(0,x,y,0,1);    splay(z,0);}void join(){    scanf("%d %d",&x,&y);    n++;    ntr.insert(n);    vtr[n].siz=1;    vtr[n].mn_h=vtr[n].h=x;    vtr[n].sum=y;    vtr[n].mn_s=y;    vtr[n].s=y;    mntr.insert(pi(n,y));}void prac(){    scanf("%d %d %d",&x,&y,&z);    splay(x,0);    mntr.erase(mntr.find(pi(x,vtr[x].mn_s)));    vtr[x].h-=y;    vtr[x].s+=z;    update(x);    mntr.insert(pi(x,vtr[x].mn_s));    if (vtr[x].h<=0) del(x,1);}int dfs(int x)// 删除不合法的蒟蒻 {    down(x);    int y;    if (vtr[x].h<=0)    {        y=del(x,0);        if (y) return dfs(y);    }else    {        int i;        fo(i,0,1)            if (tr[x][i])            {                down(tr[x][i]);                if (vtr[tr[x][i]].mn_h<=0) dfs(tr[x][i]);            }        update(x);        return x;    }}void diss(){    scanf("%d %d %d",&x,&y,&z);    splay(x,0);    mntr.erase(mntr.find(pi(x,vtr[x].mn_s)));    vtr[x].tag_h-=y;    vtr[x].tag_s+=z;    update(x);    t=dfs(x);    if (t)     {        mntr.insert(pi(t,vtr[t].mn_s));        splay(t,0);    }}void elim(){    scanf("%d",&x);    splay(x,0);    del(x,1);}int find(int x,int y){    while (x)    {        update(x);        if (vtr[x].s==y) return x;        if (vtr[tr[x][0]].mn_s==y) x=tr[x][0];else        x=tr[x][1];    }}int main(){    freopen("training.in","r",stdin);    freopen("training.out","w",stdout);    scanf("%d\n",&m);    fo(l,1,m)    {        scanf("%c",&ch);        if (ch=='J') join();else        if (ch=='P') prac();else        if (ch=='D') diss();else        if (ch=='E') elim();else        if (ch=='M')        {            scanf("%d %d",&x,&y);            Merge(x,y);        }else if (ch=='C')        {            kan=ntr.size();            if (!kan) continue;            x=*ntr.begin();            it=ntr.end();            it--;            y=*it;            Merge(x,y);        }else if (ch=='T')        {            kan=ntr.size();            if (!kan) continue;            tmp=(*mntr.begin());            tp=tmp.val;            while (tmp.val==tp)            {                k1=find(tmp.id,tp);                splay(k1,0);                del(k1,1);                tmp=(*mntr.begin());            }        }else        {            scanf("%d",&x);            splay(x,0);            printf("%d %lld\n",vtr[x].siz,vtr[x].sum);        }        scanf("\n");    }}
0 0