abcdeffa's Blog

当局者迷,旁观者清。

0%

GMOJ S6904 【树上询问】

Description

给你一棵有 $n$ 个点的树,给你 $m$ 个询问,每次问你对于一组给定的 $l_i$ 和 $r_i$,有多少个整数 $k$ 满足从点 $l_i$ 出发沿着从点 $l_i$ 到点 $r_i$ 的简单路径上走 $k$ 步恰好走到点 $k$。

单个测试点的时间限制为 3s。

Solution

有趣但有些套路的题目,比赛的时候降智了没想出来,于是写篇题解来纪念一下。

为了方便表述,在此令 $u = l_i, v = r_i, f = \text{lca}(u, v)$,其中 $\text{lca}(u, v)$ 表示的是点 $u$ 和点 $v$ 的最近公共祖先,下文提到的路径均为简单路径。

考虑点 $k$ 在点 $u$ 到点 $f$ 这一段上的情况,设 $\text{dep} _i$ 表示点 $i$ 的深度,则合法的点 $k$ 需满足:

移项后可以得到:

于是式子的左边对于一个询问来说是定值,式子的右边本身就是一个定值。

于是问题转换为了求一段路径上点权(等于 $(k + \text{dep}_k)$)等于 $\text{dep}_u$ 的点的数量,将询问离线后用树上差分即可做到。

对于点 $k$ 在点 $f$ 到点 $v$ 这一段上的情况,处理方法类似,就不在此赘述了。

注意不能将点 $f$ 同时归在两种情况里面,随便归类到一种情况中即可。

维护一个点上挂着的询问起始点和结束点的时候最好用链式结构维护,用 $\text{map}$ 的话不仅多带一个 $\log$,常数还大,过不去。

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
#include <cstdio>
#include <cstdlib>
#define maxN 600010
struct Edge{ int x, y, g; } b[maxN << 1];
int len1 = 0, len2 = 0, len3 = 0, len4 = 0, len = 0;
int f[maxN][20];
int Ans[maxN], dep[maxN], h[maxN];
int h1[maxN], h2[maxN], h3[maxN], h4[maxN];
int last1[maxN], last2[maxN], last3[maxN], last4[maxN];
int begin1[maxN], end1[maxN], begin2[maxN], end2[maxN];
int val1[maxN], val2[maxN], cnt1[maxN], cnt2[maxN], p1[maxN], p2[maxN];
int read ()
{
int x = 0;
char c = getchar();
while(c < '0' || c > '9')
{
c = getchar();
}
while(c >= '0' && c <= '9')
{
x = x * 10 + (c - '0');
c = getchar();
}
return x;
}
void ins (int x, int y)
{
len++;
b[len].x = x;
b[len].y = y;
b[len].g = h[x];
h[x] = len;
}
void Insert_begin1 (int pos, int x)
{
len1++;
begin1[len1] = x;
last1[len1] = h1[pos];
h1[pos] = len1;
}
void Insert_begin2 (int pos, int x)
{
len2++;
begin2[len2] = x;
last2[len2] = h2[pos];
h2[pos] = len2;
}
void Insert_end1 (int pos, int x)
{
len3++;
end1[len3] = x;
last3[len3] = h3[pos];
h3[pos] = len3;
}
void Insert_end2 (int pos, int x)
{
len4++;
end2[len4] = x;
last4[len4] = h4[pos];
h4[pos] = len4;
}
void dfs (int x)
{
for(int i = h[x];i;i = b[i].g)
{
int y = b[i].y;
if(!dep[y])
{
f[y][0] = x;
dep[y] = dep[x] + 1;
dfs(y);
}
}
}
int lca (int x, int y)
{
if(dep[x] < dep[y])
{
int t = x;
x = y;
y = t;
}
for(int i = 19;i >= 0; i--)
{
if(f[x][i] && dep[f[x][i]] >= dep[y])
{
x = f[x][i];
}
}
if(x == y)
{
return x;
}
for(int i = 19;i >= 0; i--)
{
if(f[x][i] && f[y][i] && f[x][i] != f[y][i])
{
x = f[x][i];
y = f[y][i];
}
}
return f[x][0];
}
void work (int x)
{
cnt1[val1[x]]++;
cnt2[val2[x]]++;
if(h1[x])
{
for(int i = h1[x];i;i = last1[i])
{
int now = begin1[i];
Ans[now] -= cnt1[p1[now]];
}
}
if(h2[x])
{
for(int i = h2[x];i;i = last2[i])
{
int now = begin2[i];
Ans[now] -= cnt2[p2[now]];
}
}
if(h3[x])
{
for(int i = h3[x];i;i = last3[i])
{
int now = end1[i];
Ans[now] += cnt1[p1[now]];
}
}
if(h4[x])
{
for(int i = h4[x];i;i = last4[i])
{
int now = end2[i];
Ans[now] += cnt2[p2[now]];
}
}
for(int i = h[x];i;i = b[i].g)
{
int y = b[i].y;
if(dep[y] == dep[x] + 1)
{
work(y);
}
}
cnt1[val1[x]]--;
cnt2[val2[x]]--;
}
int main ()
{
int n = read(), m = read();
for(int i = 1;i < n; i++)
{
int x = read(), y = read();
ins(x, y), ins(y, x);
}
int root = 1;
dep[root] = 1;
dfs(root);
for(int j = 1;j <= 19; j++)
{
for(int i = 1;i <= n; i++)
{
f[i][j] = f[f[i][j - 1]][j - 1];
}
}
for(int i = 1;i <= n; i++)
{
val1[i] = i + dep[i];
val2[i] = i - dep[i];
}
for(int i = 1;i <= m; i++)
{
int x = 0, y = 0;
scanf("%d %d", &x, &y);
int LCA = lca(x, y);
p1[i] = dep[x];
p2[i] = dep[x] - 2 * dep[LCA];
Insert_begin1(f[LCA][0], i);
Insert_begin2(LCA, i);
Insert_end1(x, i);
Insert_end2(y, i);
}
work(root);
for(int i = 1;i <= m; i++)
{
printf("%d\n", Ans[i]);
}
return 0;
}