abcdeffa's Blog

当局者迷,旁观者清。

0%

GMOJ S6879 【T1 出了个大阴间题】

Description

你有 $n$ 个二元组,你可以将 $(a_1, b_1)$ 与 $(a_2, b_2)$ 合并成:

,需要的花费为 $(ka + b_1 + b_2)$。

现在你有 $n$ 个二元组 $(a_1, 0), (a_2, 0), . . . ,(a_n, 0)$,你需要按照一个排列 $p_1, p_2, p_3, . . . , p_n$ 的顺序合并,就是先将 $(a_{p_1} , 0)$ 与 $(a_{p_2} , 0)$ 合并,再将所得结果与 $(a_{p_3} , 0)$ 合并,以此类推。

你希望最后结果 $(a, b)$ 中 $a$ 最大,并希望求出在 $a$ 最大的条件下所有合法排列的合并代价总和对 $(10^9 + 7)$ 取模后的结果。

时间限制 2s,空间限制 512MB。

Solution

看到 $n$ 这么小,考虑状压 DP。

设 $f_{S, i}$ 表示当前已经合并了的状态为 $S$,结果的 $a$ 值为 $i$ 的代价和。

因为 $1 \leq a_i \leq 10^9$,于是我们考虑把所有可能合并出来的数求出来,然后按照这个对 $a_i$ 进行离散化即可,由于显然这些数的总数小于 $2n$,然后暴力转移即可。

比赛的时候因为数组开小和没有取模而 FST 了,出题人实在是高。

据说有 $O(n^2)$ 做法,然而我并不会。

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
#include <map>
#include <cstdio>
#include <cstring>
#define maxN 18
#define mod 1000000007
std::map<int, int> find, cnt;
long long a[maxN + 1], b[maxN + 1], c[maxN + 1];
long long count[1 << maxN], f[1 << maxN][37], js[1 << maxN][37];
long long max (long long x, long long y)
{
return x > y ? x : y;
}
void px1 (long long l, long long r)
{
long long x = l, y = r, mid = a[(l + r) >> 1];
while(x <= y)
{
while(a[x] < mid)
{
x++;
}
while(a[y] > mid)
{
y--;
}
if(x <= y)
{
long long t = a[x];
a[x] = a[y];
a[y] = t;
x++;
y--;
}
}
if(l < y)
{
px1(l, y);
}
if(x < r)
{
px1(x, r);
}
}
void px2 (long long l, long long r)
{
long long x = l, y = r, mid = b[(l + r) >> 1];
while(x <= y)
{
while(b[x] < mid)
{
x++;
}
while(b[y] > mid)
{
y--;
}
if(x <= y)
{
long long t = b[x];
b[x] = b[y];
b[y] = t;
x++;
y--;
}
}
if(l < y)
{
px2(l, y);
}
if(x < r)
{
px2(x, r);
}
}
int main ()
{
long long n = 0, k = 0;
scanf("%lld %lld", &n, &k);
for(long long i = 1;i <= n; i++)
{
scanf("%lld", &a[i]);
b[i] = a[i];
cnt[a[i]]++;
}
px1(1, n);
b[0] = n;
long long Ans1 = 0, Ans2 = 0;
for(long long i = 1;i <= n; i++)
{
if(cnt[a[i]] >= 2)
{
if(!cnt[a[i] + 1])
{
b[++b[0]] = a[i] + 1;
}
cnt[a[i] + 1]++;
}
}
px2(1, b[0]);
Ans1 = b[b[0]];
c[++c[0]] = b[1];
for(long long i = 2;i <= b[0]; i++)
{
if(b[i] != b[i - 1])
{
c[++c[0]] = b[i];
}
}
for(long long i = 1;i <= c[0]; i++)
{
find[c[i]] = i;
}
memset(f, -127 / 3, sizeof(f));
for(long long i = 1;i <= n; i++)
{
f[1 << (i - 1)][find[a[i]]] = 0;
js[1 << (i - 1)][find[a[i]]] = 1;
}
const long long Ma = (1 << n) - 1;
for(long long S = 0;S <= Ma; S++)
{
long long x = S;
while(x)
{
count[S] += (x & 1);
x >>= 1;
}
}
for(long long S = 0;S <= Ma; S++)
{
for(long long i = 1;i <= n; i++)
{
if(S & (1 << (i - 1)))
{
continue;
}
for(long long x = 1;x <= 36; x++)
{
if(f[S][x] < 0)
{
continue;
}
long long dq = find[a[i]];
long long now = max(x, dq) + (x == dq);
if(f[S | (1 << (i - 1))][now] < 0)
{
f[S | (1 << (i - 1))][now] = 0;
}
f[S | (1 << (i - 1))][now] += f[S][x] + (k * c[now] + (1 << (count[S] - 1)) - 1) % mod * js[S][x] % mod;
f[S | (1 << (i - 1))][now] %= mod;
js[S | (1 << (i - 1))][now] += js[S][x];
js[S | (1 << (i - 1))][now] %= mod;
}
}
}
Ans2 = f[Ma][c[0]];
printf("%lld %lld", Ans1, Ans2);
return 0;
}