abcdeffa's Blog

当局者迷,旁观者清。

0%

GMOJ S5257 【小X的佛光】

Description

小 X 作为一名远近闻名、热爱数学的学佛、男孩子。他给出了你一个有 $n$ 个点的树,以及 $m$ 个询问,每次询问你从点 $a_i$ 到点 $b_i$ 的简单路径和从点 $c_i$ 到点 $b_i$ 的简单路径有多少个公共点。

Solution

猜了个结论:答案为这三个点中 LCA 深度最大的那个 LCA 到 $b_i$ 的距离。

简单地把各种情况都试了一下发现没有问题就用了。

因为 OJ 的系统栈比较小,所以遍历树的时候要用 BFS。

时间复杂度 $O(m \log n)$。

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
#include <cstdio>
#define maxN 2000010
struct edge{ int x, y, g; } b[maxN << 1];
int len = 0;
int f[maxN][20], w[maxN][20];
int dep[maxN], dl[maxN], h[maxN];
int max (int x, int y)
{
return x > y ? x : y;
}
void ins (int x, int y)
{
len++;
b[len].x = x;
b[len].y = y;
b[len].g = h[x];
h[x] = len;
}
void bfs (int x)
{
int l = 1, r = 2;
dl[1] = x;
while(l < r)
{
int x = dl[l];
for(int i = h[x];i;i = b[i].g)
{
int y = b[i].y;
if(!dep[y])
{
f[y][0] = x;
w[y][0] = 1;
dep[y] = dep[x] + 1;
dl[r] = y;
r++;
}
}
l++;
}
}
int lca (int x, int y)
{
if(dep[x] < dep[y])
{
int t = x;
x = y;
y = t;
}
for(int i = 19;i >= 0; i--)
{
if(f[x][i] && dep[f[x][i]] >= dep[y])
{
x = f[x][i];
}
}
if(x == y)
{
return x;
}
for(int i = 19;i >= 0; i--)
{
if(f[x][i] && f[y][i] && f[x][i] != f[y][i])
{
x = f[x][i];
y = f[y][i];
}
}
return f[x][0];
}
int dis (int x, int y)
{
if(dep[x] < dep[y])
{
int t = x;
x = y;
y = t;
}
int ans = 0;
for(int i = 19;i >= 0; i--)
{
if(f[x][i] && dep[f[x][i]] >= dep[y])
{
ans += w[x][i];
x = f[x][i];
}
}
if(x == y)
{
return ans;
}
for(int i = 19;i >= 0; i--)
{
if(f[x][i] && f[y][i] && f[x][i] != f[y][i])
{
ans += w[x][i];
ans += w[y][i];
x = f[x][i];
y = f[y][i];
}
}
ans += w[x][0];
ans += w[y][0];
return ans;
}
int main ()
{
int n = 0, T = 0;
scanf("%d %d %*d", &n, &T);
for(int i = 1;i < n; i++)
{
int x = 0, y = 0;
scanf("%d %d", &x, &y);
ins(x, y);
ins(y, x);
}
int root = 1;
dep[root] = 1;
bfs(root);
for(int j = 1;j <= 19; j++)
{
for(int i = 1;i <= n; i++)
{
f[i][j] = f[f[i][j - 1]][j - 1];
w[i][j] = w[f[i][j - 1]][j - 1] + w[i][j - 1];
}
}
while(T--)
{
int A = 0, B = 0, C = 0;
scanf("%d %d %d", &A, &B, &C);
int fA = lca(A, B);
int fB = lca(A, C);
int fC = lca(B, C);
int depA = dep[fA];
int depB = dep[fB];
int depC = dep[fC];
int depMax = max(depA, max(depB, depC));
if(depA == depMax)
{
printf("%d\n", dis(fA, B) + 1);
}
else if(depB == depMax)
{
printf("%d\n", dis(fB, B) + 1);
}
else if(depC == depMax)
{
printf("%d\n", dis(fC, B) + 1);
}
}
return 0;
}