abcdeffa's Blog

当局者迷,旁观者清。

0%

GMOJ S6893 【小 T 与灵石】

Description

给你一棵有 $n$ 个点的树,边权均为 1。再给你 $m$ 个点集,第 $i$ 个点集有 $k_i$ 个点,定义 $f_{x, i}$ 表示点 $x$ 到第 $i$ 个点集中所有点的最大距离。

定义 $g_x$ 表示 $\min_{i = 1}^{m} f_{x, i}$,试求 $g_1$ ~ $g_n$。

Solution

发现对于一个点集,如果我们称其中距离最远的点对为这个点集的直径的两端,那么一个点到这个点集中所有点的最大距离一定是到中点的距离加上直径长度的一半。

对于中点在一个点上的情况,我们把直径的长度 $d$ 的一半当作这个点的点权,如果有多次赋值则取最小值,并标记一下它。

对于中点在一条边上的情况,我们把 $\lceil \dfrac{d}{2} \rceil$ 给到中点所在的边的两端的点,并标记一下它们。

然后我们考虑怎么求出一个点集的直径。

当 $k = 1$ 时,直径退化为一个点。

当 $k \geq 2$ 时,考虑先钦定点集中的前两个数为直径的两端 $l$ 和 $r$,然后看对于一个新点 $x$,它能否成为当前直径的其中一段,这个看一下 $x$ 到 $l$ 和 $r$ 的距离比不比当前直径优就好了。

然后问题就转化为了求所有点到标记点的最短距离,这个可以用换根 DP 或最短路算法(如 SPFA)解决。

求解树上两点间距离建议采用树链剖分而非倍增 LCA,因为前者在实际表现上更优秀(如对于一条链的情况),而后者我卡不过去。

由于 OJ 的系统栈较小,应采用 BFS 而非 DFS。

官方题解提供了一种新建点然后连向直径中点,边权为 $\dfrac{d}{2}$ 的做法,并为了避免小数,将边权乘 2。若在边上则拆去原有边,再开一个新点,建两条边权为 1 的边,连向刚刚删除边的两端即可。

我试着写了一下这种做法,但感觉比较复杂?

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
#include <cstdio>
#include <cstring>
#define maxN 900010
struct Edge{ int x, y, g; } b[maxN << 1];
int len = 1, n = 0, cnt = 0;
int f[maxN][20];
int top[maxN], size[maxN], son[maxN], val[maxN], dep[maxN], dl[maxN], fa[maxN], h[maxN], d[maxN], v[maxN], a[maxN];
int read ()
{
int x = 0;
char c = getchar();
while(c < '0' || c > '9')
{
c = getchar();
}
while(c >= '0' && c <= '9')
{
x = x * 10 + (c - '0');
c = getchar();
}
return x;
}
int min (int x, int y)
{
return x < y ? x : y;
}
void ins (int x, int y)
{
len++;
b[len].x = x;
b[len].y = y;
b[len].g = h[x];
h[x] = len;
len++;
b[len].x = y;
b[len].y = x;
b[len].g = h[y];
h[y] = len;
}
void work (int S)
{
int head = 1, tail = 2;
dl[head] = S;
while(head < tail)
{
int x = dl[head++];
for(int i = h[x];i;i = b[i].g)
{
int y = b[i].y;
if(!dep[y])
{
fa[y] = x;
f[y][0] = x;
dep[y] = dep[x] + 1;
dl[tail++] = y;
}
}
}
for(int i = tail - 1;i > 1; i--)
{
int x = dl[i];
size[x]++;
size[fa[x]] += size[x];
if(size[x] > size[son[fa[x]]])
{
son[fa[x]] = x;
}
}
}
int dis (int x, int y)
{
int X = x, Y = y;
while(top[x] != top[y])
{
if(dep[top[x]] < dep[top[y]])
{
int t = x;
x = y;
y = t;
}
x = fa[top[x]];
}
if(dep[x] > dep[y])
{
int t = x;
x = y;
y = t;
}
return dep[X] + dep[Y] - 2 * dep[x];
}
void bfs ()
{
int head = 1, tail = 1;
for(int i = 1;i <= cnt; i++)
{
d[i] = val[i], v[i] = 0;
if(val[i] < 707406378)
{
dl[tail++] = i, v[i] = 1;
}
}
while(head < tail)
{
int x = dl[head++];
for(int i = h[x];i;i = b[i].g)
{
int y = b[i].y;
if(d[x] + 1 < d[y])
{
d[y] = d[x] + 1;
if(!v[y])
{
v[y] = 1;
dl[tail++] = y;
}
}
}
v[x] = 0;
}
}
void solve (int S)
{
int head = 1, tail = 2;
dl[head] = S;
while(head < tail)
{
int x = dl[head++];
if(son[x])
{
top[son[x]] = top[x];
dl[tail++] = son[x];
}
for(int i = h[x];i;i = b[i].g)
{
int y = b[i].y;
if(y != son[x] && dep[y] == dep[x] + 1)
{
top[y] = y;
dl[tail++] = y;
}
}
}
}
int main ()
{
cnt = n = read();
for(int i = 2;i <= n; i++)
{
ins(read(), i);
}
int root = 1;
dep[root] = 1;
work(root);
for(int j = 1;j <= 19; j++)
{
for(int i = 1;i <= n; i++)
{
f[i][j] = f[f[i][j - 1]][j - 1];
}
}
top[root] = root;
solve(root);
memset(val, 127 / 3, sizeof(val));
int Q = read();
while(Q--)
{
int k = read(), D = 0, l = 0, r = 0;
for(int i = 1;i <= k; i++)
{
a[i] = read();
}
if(k == 1)
{
val[a[1]] = 0;
continue;
}
D = dis(a[1], a[2]), l = a[1], r = a[2];
for(int i = 3;i <= k; i++)
{
int ll = l, rr = r;
int now = dis(l, a[i]);
if(now > D)
{
D = now;
ll = l;
rr = a[i];
}
now = dis(a[i], r);
if(now > D)
{
D = now;
ll = a[i];
rr = r;
}
l = ll, r = rr;
}
if(dep[l] < dep[r])
{
int t = l;
l = r;
r = t;
}
int res = 0;
const int Dist = D / 2;
for(int i = 19;i >= 0; i--)
{
if(f[l][i] && res + (1 << i) <= Dist)
{
res += (1 << i);
l = f[l][i];
}
}
if(D & 1)
{
val[l] = min(val[l], (D + 1) / 2);
val[fa[l]] = min(val[fa[l]], (D + 1) / 2);
}
else
{
val[l] = min(val[l], D / 2);
}
}
bfs();
for(int i = 1;i <= n; i++)
{
printf("%d\n", d[i]);
}
return 0;
}