BZOJ 2648 世界树
题目传送门
分析:
喜 闻 乐 见 的虚树
但是建好虚树后的DP也非常的恶心
我们先考虑每个关键点的归哪个点管
先DFS一次计算儿子节点归属父亲
再DFS一次计算父亲节点归属儿子
然后然后我们对于虚树上的每条边计算一下
首先先找到分割点mid
那么向上归属的是红色部分
向下的是绿色部分
对于每条边都算一下就好了
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<vector>
#define maxn 600005
#define INF 1ll<<50
using namespace std;
inline int getint()
{
int num=0,flag=1;char c;
while((c=getchar())<'0'||c>'9')if(c=='-')flag=-1;
while(c>='0'&&c<='9')num=num*10+c-48,c=getchar();
return num*flag;
}
int n;
int fir[maxn],nxt[maxn],to[maxn],cnt;
int f[maxn][21],dpt[maxn],sz[maxn];
int h[maxn],a[maxn];
int bl[maxn],F[maxn],ans[maxn];
int In[maxn],Out[maxn],cur;
int stk[maxn],top;
int pos[maxn],tot;
inline bool cmp(int x,int y){return In[x]<In[y];}
inline void newnode(int u,int v)
{to[++cnt]=v,nxt[cnt]=fir[u],fir[u]=cnt;}
inline void dfs(int u,int fa)
{
sz[u]=1;In[u]=++cur;
for(int i=fir[u];i;i=nxt[i])if(to[i]!=fa)
{
f[to[i]][0]=u,dpt[to[i]]=dpt[u]+1;
dfs(to[i],u);
sz[u]+=sz[to[i]];
}
Out[u]=cur;
}
inline int LCA(int u,int v)
{
if(dpt[u]<dpt[v])swap(u,v);
int d=dpt[u]-dpt[v];
for(int i=20;~i;i--)if(d&(1<<i))u=f[u][i];
if(u==v)return u;
for(int i=20;~i;i--)if(f[u][i]!=f[v][i])u=f[u][i],v=f[v][i];
return f[u][0];
}
inline int getdis(int u,int v)
{return dpt[u]+dpt[v]-2*dpt[LCA(u,v)];}
inline void dfs1(int u,int fa)
{
pos[++tot]=u;F[u]=sz[u];
for(int i=fir[u];i;i=nxt[i])if(to[i]!=fa)
{
dfs1(to[i],u);
if(!bl[to[i]])continue;
if(!bl[u]){bl[u]=bl[to[i]];continue;}
int tmp1=getdis(bl[to[i]],u),tmp2=getdis(bl[u],u);
if(tmp1<tmp2||(tmp1==tmp2&&bl[to[i]]<bl[u]))bl[u]=bl[to[i]];
}
}
inline void dfs2(int u,int fa)
{
for(int i=fir[u];i;i=nxt[i])if(to[i]!=fa)
{
int tmp1=getdis(bl[u],to[i]),tmp2=getdis(bl[to[i]],to[i]);
if(tmp1<tmp2||(tmp1==tmp2&&bl[u]<bl[to[i]]))bl[to[i]]=bl[u];
dfs2(to[i],u);
}
}
inline void solve(int fa,int u)
{
int son=u,mid=u;
for(int i=20;~i;i--)if(dpt[fa]+(1<<i)<dpt[son])son=f[son][i];
F[fa]-=sz[son];
if(bl[fa]==bl[u]){ans[bl[fa]]+=sz[son]-sz[u];return;}
for(int i=20;~i;i--)
{
int tmp=f[mid][i];
if(dpt[tmp]<=dpt[fa])continue;
int tmp1=getdis(tmp,bl[fa]);
int tmp2=getdis(tmp,bl[u]);
if(tmp2<tmp1||(tmp1==tmp2&&bl[u]<bl[fa]))mid=tmp;
}
ans[bl[fa]]+=sz[son]-sz[mid];
ans[bl[u]]+=sz[mid]-sz[u];
}
inline void solve()
{
int K=getint();top=0;
int tt=K;
for(int i=1;i<=K;i++)a[i]=h[i]=getint(),bl[h[i]]=h[i];h[++K]=1;
sort(h+1,h+K+1,cmp);
for(int i=K-1;i;i--)h[++K]=LCA(h[i],h[i+1]);
sort(h+1,h+K+1,cmp);K=unique(h+1,h+K+1)-h-1;
stk[++top]=h[1];
for(int i=2;i<=K;i++)
{
while(top&&Out[stk[top]]<In[h[i]])top--;
newnode(stk[top],h[i]);
stk[++top]=h[i];
}
dfs1(h[1],h[1]),dfs2(h[1],h[1]);
for(int i=1;i<=K;i++)
for(int j=fir[pos[i]];j;j=nxt[j])
solve(pos[i],to[j]);
for(int i=1;i<=K;i++)ans[bl[pos[i]]]+=F[pos[i]];
for(int i=1;i<=tt;i++)printf("%d%c",ans[a[i]],(i==tt)?'\n':' ');
for(int i=1;i<=K;i++)F[h[i]]=ans[h[i]]=fir[h[i]]=bl[h[i]]=0;
cnt=tot=0;
}
int main()
{
n=getint();
for(int i=1;i<n;i++)
{
int u=getint(),v=getint();
newnode(u,v),newnode(v,u);
}
dfs(1,1);
for(int j=1;j<=20;j++)for(int i=1;i<=n;i++)f[i][j]=f[f[i][j-1]][j-1];
memset(fir,0,sizeof fir);cnt=0;
int m=getint();
while(m--)solve();
}
更多精彩