Description
给你一个由 A,B,C,....,L
组成的字符串,统计有多少个子串中字符 A,B,C
数量相同。
为了方便表述,设字符串长度为 $n$ 。
Solution
为了方便表述,下文将用 $(l, r)$ 来表示由字符串中第 $l$ 个字符到第 $r$ 个字符按原顺序组成的子串。
考虑 $O(n^2)$ 的暴力,用前缀和即可。
设 $su(i, 0/1/2)$ 表示字符串中前 $i$ 个字符中 A/B/C
的个数。
那么对于一个子串 $(l, r)$ ,它若合法,则满足:
移项后可以得到:
若我们用 $A_i$ 来表示 $[su(i, 0) - su(i, 1)]$ ,且用 $B_i$ 来表示 $[su(i, 0) - su(i, 2)]$ ,那么这个式子可以被表示成:
若我们设 $f(i, j)$ 表示第 $k$ 个之前 $A_x = i$ 且 $B_x = j$ 的不同的 $x$ 的个数,那么我们可以得到以 $k$ 为右端点的合法子串的数量是 $f(A_k, B_k)$ 个,换句话说,就是有 $f(A_k, B_k)$ 个不同的 $l$ ,满足 $(l + 1, k)$ 是合法的。
但是我们发现 $A_k$ 和 $B_k$ 都很大,且有可能为负数,直接存的话空间复杂度是 $O(n^2)$ 的。
但是我们还知道最多只有 $n$ 个位置,不存在另一个位置使得这两个位置的 $A$ 和 $B$ 都相同。
所以 离散化 一下即可,时间复杂度 $O(n \log n)$ 。
此题还有 $O(n)$ 的哈希做法,交由读者练习。
Code
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145
| #include <cstdio> #include <cstring> #define maxN 1000010 char st[maxN]; long long su[maxN][3], dq2[maxN], dq[maxN], a[maxN], b[maxN], c[maxN], d[maxN], f[maxN]; long long lendq = 0, lenc = 0, lend = 0; void px (long long l, long long r, long long f[]) { long long x = l, y = r, mid = f[(l + r) / 2]; while(x <= y) { while(f[x] < mid) { x++; } while(f[y] > mid) { y--; } if(x <= y) { long long t = f[x]; f[x] = f[y]; f[y] = t; x++; y--; } } if(l < y) { px(l, y, f); } if(x < r) { px(x, r, f); } } long long findc (long long x) { long long l = 1, r = lenc; while(l < r) { long long mid = (l + r) / 2; if(x > c[mid]) { l = mid + 1; } else { r = mid; } } return l; } long long findd (long long x) { long long l = 1, r = lend; while(l < r) { long long mid = (l + r) / 2; if(x > d[mid]) { l = mid + 1; } else { r = mid; } } return l; } long long finddq (long long x) { long long l = 1, r = lendq; while(l < r) { long long mid = (l + r) / 2; if(x > dq[mid]) { l = mid + 1; } else { r = mid; } } return l; } int main () { scanf("%s", st + 1); long long n = strlen(st + 1); for(long long i = 1;i <= n; i++) { su[i][0] = su[i - 1][0]; su[i][1] = su[i - 1][1]; su[i][2] = su[i - 1][2]; if(st[i] >= 'A' && st[i] <= 'C') { su[i][st[i] - 'A']++; } } for(long long i = 0;i <= n; i++) { a[i] = su[i][0] - su[i][1]; b[i] = su[i][0] - su[i][2]; } px(0, n, a); px(0, n, b); c[++lenc] = a[0]; d[++lend] = b[0]; for(long long i = 1;i <= n; i++) { if(a[i] != a[i - 1]) { c[++lenc] = a[i]; } if(b[i] != b[i - 1]) { d[++lend] = b[i]; } } for(long long i = 0;i <= n; i++) { dq2[i] = (findc(su[i][0] - su[i][1]) - 1) * lend + findd(su[i][0] - su[i][2]); } px(0, n, dq2); dq[++lendq] = dq2[0]; for(long long i = 1;i <= n; i++) { if(dq2[i] != dq2[i - 1]) { dq[++lendq] = dq2[i]; } } long long ans = 0; for(long long i = 0;i <= n; i++) { long long x = finddq((findc(su[i][0] - su[i][1]) - 1) * lend + findd(su[i][0] - su[i][2])); ans += f[x]; f[x]++; } printf("%lld", ans); return 0; }
|