点分治 + bitset + 乱搞

模板题当然是要乱搞的啦,这边提供一种使用点分治,bitset 的乱搞方法。

考虑需要在线查询,所以建出点分树。

设当前查询点为 $(u,v)$ ,其在点分树上的 LCA 为 $rt$ ,即在一次点分治中以 $rt$ 为分治重心找到处在不同子树内的 $u,v$。此时只需要将 $(u,rt),(v,rt)$ 的答案合并一下就可以了。这个答案可以用 bitset 预处理。

这样做时间复杂度是 $O(\frac{(m+n\log n)n}{w})$,问题不大。

然而空间是 $O(\frac{n^2\log n}{w})$ 的,很遗憾无法通过。

这个时候借用一下正解撒点的套路,如果我们只记录某些点的答案,中间暴力往 bitset 中加入答案就好了。

这边我直接按深度每隔 $8$ 个点打一次答案,这样查询的时候也就只需要向上跳 $8$ 次答案就能找到了。

时空复杂度是玄学的,因为很有可能某个深度的点非常多,也许再随机一点会更好。不过即使不开 O2 这样也能过。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
#include<bits/stdc++.h>
#define rep(a,b,c) for(int c(a);c<=(b);++c)
#define drep(a,b,c) for(int c(a);c>=(b);--c)
#define grep(b,c) for(int c(head[b]);c;c=nxt[c])
using namespace std;
inline int read()
{
int res=0;char ch=getchar();while(ch<'0'||ch>'9')ch=getchar();
while(ch<='9'&&ch>='0')res=res*10+(ch^48),ch=getchar();return res;
}
inline void print(int x){if(x>9)print(x/10);putchar(x%10+'0');}
const int N=4e4+10,INF=0x3f3f3f3f;int head[N],des[N<<1],nxt[N<<1],cgt;
inline void ins(int x,int y)
{
nxt[++cgt]=head[x];des[head[x]=cgt]=y;
nxt[++cgt]=head[y];des[head[y]=cgt]=x;
}
typedef unsigned long long uLL;
struct Bitset
{
uLL a[625];inline void operator |= (const Bitset &b){for(int i=0;i<625;++i)a[i]|=b.a[i];}
inline int count(){int ans=0;for(int i=0;i<625;++i)ans+=__builtin_popcountll(a[i]);return ans;}
inline void ins(int x){a[x>>6]|=1ull<<(x&63);}
inline void del(int x){a[x>>6]^=1ull<<(x&63);}
inline bool operator [] (int x){return a[x>>6]>>(x&63)&1;}
inline void clear(){memset(a,0,sizeof(a));}
};
unordered_map<int,int> Fa[N];
unordered_map<int,Bitset> s[N];int n,c[N],P[N],p;
inline int bnd(int x){int l(1),r(p),m;while(l<=r)P[m=(l+r)>>1]<=x?l=m+1:r=m-1;return l-1;}
int Mn,Rt,Sz,siz[N],fa[N],dep[N];
inline void getRt(int u,int fa=0)
{
siz[u]=1;int mx=0;
grep(u,i)if(!dep[des[i]]&&des[i]!=fa)
{
int v=des[i];getRt(v,u);
siz[u]+=siz[v];mx=max(mx,siz[v]);
}
mx=max(mx,Sz-siz[u]); mx<Mn?Mn=mx,Rt=u:0;
}
inline void getSz(int u,int fa=0){siz[u]=1;grep(u,i)if(!dep[des[i]]&&des[i]!=fa)getSz(des[i],u),siz[u]+=siz[des[i]];}
Bitset A;int RT;
inline void dfs(int u,int fa=0,int dd=1)
{
grep(u,i)if(des[i]!=fa&&!dep[des[i]])
{
bool flag=0;int v=des[i];
if(!A[c[v]])flag=true,A.ins(c[v]);
if(dd%8==0)
{
s[RT][v]=A;
Fa[RT][v]=u|(1<<30);
}
else Fa[Rt][v]=u;
dfs(v,u,dd+1);
if(flag)A.del(c[v]);
}
}
inline void Cnt(int u){s[u][u].ins(c[u]);A=s[u][u];Fa[u][u]=1<<30;RT=u;dfs(u);}
inline void Solve(int u)
{
Cnt(u);grep(u,i)if(!dep[des[i]])
{
int v=des[i];getSz(v);Mn=INF;Sz=siz[v];getRt(v);
fa[Rt]=u;dep[Rt]=dep[u]+1;Solve(Rt);
}
}
inline int LCA(int u,int v){while(u!=v){dep[u]<dep[v]?v=fa[v]:u=fa[u];}return u;}
Bitset ans;
inline int qry(int u,int v)
{
int lc=LCA(u,v);ans.clear();
while(true){int tmp=Fa[lc][u];if(tmp>>30&1){ans|=s[lc][u];break;}ans.ins(c[u]);u=tmp;}
while(true){int tmp=Fa[lc][v];if(tmp>>30&1){ans|=s[lc][v];break;}ans.ins(c[v]);v=tmp;}
return ans.count();
}
int main()
{
n=read();int Q=read();rep(1,n,i)P[i]=c[i]=read();
sort(P+1,P+n+1);p=unique(P+1,P+n+1)-P-1;rep(1,n,i)c[i]=bnd(c[i])-1;
rep(2,n,i)ins(read(),read());Sz=n;Mn=INF;getRt(1);dep[Rt]=1;Solve(Rt);
int lst=0;while(Q--){int u=lst^read(),v=read();print(lst=qry(u,v));putchar('\n');}
return 0;
}