Description
你有 $n$ 个二元组,你可以将 $(a_1, b_1)$ 与 $(a_2, b_2)$ 合并成:
,需要的花费为 $(ka + b_1 + b_2)$。
现在你有 $n$ 个二元组 $(a_1, 0), (a_2, 0), . . . ,(a_n, 0)$,你需要按照一个排列 $p_1, p_2, p_3, . . . , p_n$ 的顺序合并,就是先将 $(a_{p_1} , 0)$ 与 $(a_{p_2} , 0)$ 合并,再将所得结果与 $(a_{p_3} , 0)$ 合并,以此类推。
你希望最后结果 $(a, b)$ 中 $a$ 最大,并希望求出在 $a$ 最大的条件下所有合法排列的合并代价总和对 $(10^9 + 7)$ 取模后的结果。
时间限制 2s,空间限制 512MB。
Solution
看到 $n$ 这么小,考虑状压 DP。
设 $f_{S, i}$ 表示当前已经合并了的状态为 $S$,结果的 $a$ 值为 $i$ 的代价和。
因为 $1 \leq a_i \leq 10^9$,于是我们考虑把所有可能合并出来的数求出来,然后按照这个对 $a_i$ 进行离散化即可,由于显然这些数的总数小于 $2n$,然后暴力转移即可。
比赛的时候因为数组开小和没有取模而 FST 了,出题人实在是高。
据说有 $O(n^2)$ 做法,然而我并不会。
Code
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159
| #include <map> #include <cstdio> #include <cstring> #define maxN 18 #define mod 1000000007 std::map<int, int> find, cnt; long long a[maxN + 1], b[maxN + 1], c[maxN + 1]; long long count[1 << maxN], f[1 << maxN][37], js[1 << maxN][37]; long long max (long long x, long long y) { return x > y ? x : y; } void px1 (long long l, long long r) { long long x = l, y = r, mid = a[(l + r) >> 1]; while(x <= y) { while(a[x] < mid) { x++; } while(a[y] > mid) { y--; } if(x <= y) { long long t = a[x]; a[x] = a[y]; a[y] = t; x++; y--; } } if(l < y) { px1(l, y); } if(x < r) { px1(x, r); } } void px2 (long long l, long long r) { long long x = l, y = r, mid = b[(l + r) >> 1]; while(x <= y) { while(b[x] < mid) { x++; } while(b[y] > mid) { y--; } if(x <= y) { long long t = b[x]; b[x] = b[y]; b[y] = t; x++; y--; } } if(l < y) { px2(l, y); } if(x < r) { px2(x, r); } } int main () { long long n = 0, k = 0; scanf("%lld %lld", &n, &k); for(long long i = 1;i <= n; i++) { scanf("%lld", &a[i]); b[i] = a[i]; cnt[a[i]]++; } px1(1, n); b[0] = n; long long Ans1 = 0, Ans2 = 0; for(long long i = 1;i <= n; i++) { if(cnt[a[i]] >= 2) { if(!cnt[a[i] + 1]) { b[++b[0]] = a[i] + 1; } cnt[a[i] + 1]++; } } px2(1, b[0]); Ans1 = b[b[0]]; c[++c[0]] = b[1]; for(long long i = 2;i <= b[0]; i++) { if(b[i] != b[i - 1]) { c[++c[0]] = b[i]; } } for(long long i = 1;i <= c[0]; i++) { find[c[i]] = i; } memset(f, -127 / 3, sizeof(f)); for(long long i = 1;i <= n; i++) { f[1 << (i - 1)][find[a[i]]] = 0; js[1 << (i - 1)][find[a[i]]] = 1; } const long long Ma = (1 << n) - 1; for(long long S = 0;S <= Ma; S++) { long long x = S; while(x) { count[S] += (x & 1); x >>= 1; } } for(long long S = 0;S <= Ma; S++) { for(long long i = 1;i <= n; i++) { if(S & (1 << (i - 1))) { continue; } for(long long x = 1;x <= 36; x++) { if(f[S][x] < 0) { continue; } long long dq = find[a[i]]; long long now = max(x, dq) + (x == dq); if(f[S | (1 << (i - 1))][now] < 0) { f[S | (1 << (i - 1))][now] = 0; } f[S | (1 << (i - 1))][now] += f[S][x] + (k * c[now] + (1 << (count[S] - 1)) - 1) % mod * js[S][x] % mod; f[S | (1 << (i - 1))][now] %= mod; js[S | (1 << (i - 1))][now] += js[S][x]; js[S | (1 << (i - 1))][now] %= mod; } } } Ans2 = f[Ma][c[0]]; printf("%lld %lld", Ans1, Ans2); return 0; }
|