[BZOJ3611][Heoi2014]大工程 虚树+树形DP

原题链接
考前刷水调整心态.
每次给出k个点balabala,建虚树没跑了.
第一个问题统计一下每个子树中所有关键点到根节点的距离和,以及关键点数量,然后枚举路径转折点即可;
第二个第三个直接记录最值和次值,最后取最大/最小.
当前节点为关键点的情况简单判一下就可以了.
教训:多组询问时,注意所有初始化的细节.

#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#define N 1002333
#define M 5002333
using namespace std;
typedef long long ll;
const int inf=0x3bbbbbbb;
inline char nc()
{
    static char buf[100000],*p1,*p2;
    return p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++;
}
inline void read(int &x)
{
    x=0;char c=nc();
    while(!isdigit(c))c=nc();
    while(isdigit(c))x=x*10+c-'0',c=nc();
}
int n,m,head[N],to[M],nxt[M],val[M],tot;
inline void add(int x,int y,int z)
{
    to[++tot]=y;
    nxt[tot]=head[x];
    head[x]=tot;
    val[tot]=z;
}
int fa[N][21],Log[N],dfn[N],deep[N],tim;
void getfa(int x)
{
    dfn[x]=++tim,deep[x]=deep[fa[x][0]]+1;
    int i,y;
    for(i=head[x];i;i=nxt[i])if((y=to[i])!=fa[x][0])
    {
        fa[y][0]=x;
        getfa(y);
    }
}
void lca_init()
{
    getfa(1);
    int i,j;
    for(i=2;i<=n;++i)Log[i]=Log[i>>1]+1;
    for(i=1;i<=Log[n];++i)
        for(j=1;j<=n;++j)fa[j][i]=fa[fa[j][i-1]][i-1];
}
inline int lca(int x,int y)
{
    if(deep[x]<deep[y])swap(x,y);
    int i;
    for(i=Log[deep[x]];i>=0;--i)if(deep[fa[x][i]]>=deep[y])
        x=fa[x][i];
    if(x==y)return x;
    for(i=Log[deep[x]];i>=0;--i)if(fa[x][i]!=fa[y][i])x=fa[x][i],y=fa[y][i];
    return fa[x][0];
}
int k,p[N],v[N],now,z[N],top,junk;
int max1[N],max2[N],min1[N],min2[N],num[N],vis[N];
ll sum[N];
int ans1,ans2;
ll ans3;
void getans(int x,int pre)
{
    num[x]=(vis[x]==now),sum[x]=0;
    int i,y;
    max1[x]=max2[x]=-inf,min1[x]=min2[x]=inf;
    if(num[x])min1[x]=max1[x]=0;
    for(i=head[x];i>junk;i=nxt[i])if((y=to[i])!=pre)
    {
        getans(y,x);
        num[x]+=num[y],sum[x]+=sum[y]+(ll)num[y]*val[i];
        if(max1[y]+val[i]>max1[x])max2[x]=max1[x],max1[x]=max1[y]+val[i];
        else max2[x]=max(max2[x],max1[y]+val[i]);
        if(min1[y]+val[i]<min1[x])min2[x]=min1[x],min1[x]=min1[y]+val[i];
        else min2[x]=min(min2[x],min1[y]+val[i]);
    }
    ans1=max(ans1,max1[x]+max2[x]);
    ans2=min(ans2,min1[x]+min2[x]);
    // printf("x=%d sum=%lld num=%d\n",x,sum[x],num[x]);
    for(i=head[x];i>junk;i=nxt[i])if((y=to[i])!=pre)
    {
        ans3+=(sum[y]+(ll)num[y]*val[i])*(num[x]-num[y]);
        // printf("ans3+=%lld\n",(sum[y]+(ll)num[y]*val[i])*(num[x]-num[y]));
    }
}
inline bool cmp(int x,int y)
{
    return dfn[x]<dfn[y];
}
inline void query()
{
    //build a tree
    now+=2,junk=tot;
    int i,cnt=k,x;
    for(i=1;i<=k;++i)vis[p[i]]=now;
    sort(p+1,p+k+1,cmp);
    for(i=1;i<k;++i)
    {
        x=lca(p[i],p[i+1]);
        if(vis[x]!=now&&vis[x]!=now+1)vis[x]=now+1,p[++cnt]=x;
    }
    sort(p+1,p+cnt+1,cmp);
    z[top=1]=p[1];
    for(i=2;i<=cnt;++i)
    {
        while(lca(p[i],z[top])!=z[top])top--;
        add(z[top],p[i],deep[p[i]]-deep[z[top]]),z[++top]=p[i];
    }
    //get the answer
    ans1=-inf,ans2=inf,ans3=0;
    getans(p[1],0);
    printf("%lld %d %d\n",ans3,ans2,ans1);
}
int main()
{
    read(n);
    int i,j,x,y;
    for(i=1;i<n;++i)
    {
        read(x),read(y);
        add(x,y,0),add(y,x,0);
    }
    lca_init();
    read(m);
    for(i=1;i<=m;++i)
    {
        read(k);
        for(j=1;j<=k;++j)read(p[j]);
        query();
    }
    return 0;
}
/*
7
1 2
2 3
2 4
1 5
5 6
6 7

4
2
2 3
3
4 7 5
4
3 2 1 7
3
3 4 2

*/

发表评论

电子邮件地址不会被公开。 必填项已用*标注