abcdeffa's Blog

当局者迷,旁观者清。

0%

GMOJ S3798 【临洮巨人】

Description

给你一个由 A,B,C,....,L 组成的字符串,统计有多少个子串中字符 A,B,C 数量相同。

为了方便表述,设字符串长度为 $n$ 。

Solution

为了方便表述,下文将用 $(l, r)$ 来表示由字符串中第 $l$ 个字符到第 $r$ 个字符按原顺序组成的子串。

考虑 $O(n^2)$ 的暴力,用前缀和即可。

设 $su(i, 0/1/2)$ 表示字符串中前 $i$ 个字符中 A/B/C 的个数。

那么对于一个子串 $(l, r)$ ,它若合法,则满足:

移项后可以得到:

若我们用 $A_i$ 来表示 $[su(i, 0) - su(i, 1)]$ ,且用 $B_i$ 来表示 $[su(i, 0) - su(i, 2)]$ ,那么这个式子可以被表示成:

若我们设 $f(i, j)$ 表示第 $k$ 个之前 $A_x = i$ 且 $B_x = j$ 的不同的 $x$ 的个数,那么我们可以得到以 $k$ 为右端点的合法子串的数量是 $f(A_k, B_k)$ 个,换句话说,就是有 $f(A_k, B_k)$ 个不同的 $l$ ,满足 $(l + 1, k)$ 是合法的。

但是我们发现 $A_k$ 和 $B_k$ 都很大,且有可能为负数,直接存的话空间复杂度是 $O(n^2)$ 的。

但是我们还知道最多只有 $n$ 个位置,不存在另一个位置使得这两个位置的 $A$ 和 $B$ 都相同。

所以 离散化 一下即可,时间复杂度 $O(n \log n)$ 。

此题还有 $O(n)$ 的哈希做法,交由读者练习。

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
#include <cstdio>
#include <cstring>
#define maxN 1000010
char st[maxN];
long long su[maxN][3], dq2[maxN], dq[maxN], a[maxN], b[maxN], c[maxN], d[maxN], f[maxN];
long long lendq = 0, lenc = 0, lend = 0;
void px (long long l, long long r, long long f[])
{
long long x = l, y = r, mid = f[(l + r) / 2];
while(x <= y)
{
while(f[x] < mid)
{
x++;
}
while(f[y] > mid)
{
y--;
}
if(x <= y)
{
long long t = f[x];
f[x] = f[y];
f[y] = t;
x++;
y--;
}
}
if(l < y)
{
px(l, y, f);
}
if(x < r)
{
px(x, r, f);
}
}
long long findc (long long x)
{
long long l = 1, r = lenc;
while(l < r)
{
long long mid = (l + r) / 2;
if(x > c[mid])
{
l = mid + 1;
}
else
{
r = mid;
}
}
return l;
}
long long findd (long long x)
{
long long l = 1, r = lend;
while(l < r)
{
long long mid = (l + r) / 2;
if(x > d[mid])
{
l = mid + 1;
}
else
{
r = mid;
}
}
return l;
}
long long finddq (long long x)
{
long long l = 1, r = lendq;
while(l < r)
{
long long mid = (l + r) / 2;
if(x > dq[mid])
{
l = mid + 1;
}
else
{
r = mid;
}
}
return l;
}
int main ()
{
scanf("%s", st + 1);
long long n = strlen(st + 1);
for(long long i = 1;i <= n; i++)
{
su[i][0] = su[i - 1][0];
su[i][1] = su[i - 1][1];
su[i][2] = su[i - 1][2];
if(st[i] >= 'A' && st[i] <= 'C')
{
su[i][st[i] - 'A']++;
}
}
for(long long i = 0;i <= n; i++)
{
a[i] = su[i][0] - su[i][1];
b[i] = su[i][0] - su[i][2];
}
px(0, n, a);
px(0, n, b);
c[++lenc] = a[0];
d[++lend] = b[0];
for(long long i = 1;i <= n; i++)
{
if(a[i] != a[i - 1])
{
c[++lenc] = a[i];
}
if(b[i] != b[i - 1])
{
d[++lend] = b[i];
}
}
for(long long i = 0;i <= n; i++)
{
dq2[i] = (findc(su[i][0] - su[i][1]) - 1) * lend + findd(su[i][0] - su[i][2]);
}
px(0, n, dq2);
dq[++lendq] = dq2[0];
for(long long i = 1;i <= n; i++)
{
if(dq2[i] != dq2[i - 1])
{
dq[++lendq] = dq2[i];
}
}
long long ans = 0;
for(long long i = 0;i <= n; i++)
{
long long x = finddq((findc(su[i][0] - su[i][1]) - 1) * lend + findd(su[i][0] - su[i][2]));
ans += f[x];
f[x]++;
}
printf("%lld", ans);
return 0;
}