abcdeffa's Blog

当局者迷,旁观者清。

0%

GMOJ S4816 【label】

Description

有一棵有 $n$ 个节点的树,节点的编号从 $1$ 到 $n$ 。

现在要给树上的每一个节点赋一个 $[1, m]$ 之间的权值,要求相邻两个节点的权值差的绝对值大于等于 $k$ ,求合法的方案数,答案对 $(10^9+7)$ 取模。

共有 $T$ 组数据。

Solution

考虑使用 树形DP 来解题。

设 $f[i][j]$ 表示以 $i$ 为根的子树中 $i$ 号点填数 $j$ 的方案数。

用前缀和优化一下就可以做到 $m \leq 10^4$ 。

发现对于同一个点的 DP 值,都是对称的,并且中间有一段相同的,一边上的不同的数的个数小于等于 $(n - 1)k$ ,因此只需要存储前 $[(n - 1)k + 1]$ 个即可。

然后就可以做 $m \leq 10^9$ 了。

建议 $m \leq 10^4$ 和 $m \leq 10^9$ 的分开做。

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
#include <cstdio>
#include <cstring>
#define maxN 110
#define maxM 20010
#define mod 1000000007LL
struct node{ long long x, y, g; } b[maxN << 1];
long long su[maxN][maxM], f[maxN][maxM], son[maxN][maxN], h[maxN];
long long len = 0, n = 0, m = 0, k = 0, p = 0;
long long min (long long x, long long y)
{
return x < y ? x : y;
}
long long max (long long x, long long y)
{
return x > y ? x : y;
}
void ins (long long x, long long y)
{
len++;
b[len].x = x;
b[len].y = y;
b[len].g = h[x];
h[x] = len;
}
void dfs (long long x, long long fa)
{
for(long long i = h[x];i;i = b[i].g)
{
long long y = b[i].y;
if(y != fa)
{
son[x][++son[x][0]] = y;
dfs(y, x);
}
}
}
long long sum (long long x, long long K)
{
long long res = 0;
long long t = min(p, K);
res = su[x][t] % mod;
long long c = K - (m - p);
if(c > 0)
{
K -= c;
res += (su[x][p] + mod - su[x][p - c]) % mod;
res %= mod;
}
K -= t;
res += f[x][p] * K;
res %= mod;
return res;
}
void dp (long long x)
{
if(!son[x][0])
{
for(long long i = 1;i <= m; i++)
{
f[x][i] = 1;
su[x][i] = su[x][i - 1] + f[x][i];
su[x][i] %= mod;
}
return ;
}
for(long long i = 1;i <= son[x][0]; i++)
{
dp(son[x][i]);
}
for(long long j = 1;j <= m; j++)
{
f[x][j] = 1;
bool flag = false;
for(long long i = 1;i <= son[x][0]; i++)
{
long long da = 0;
if(j > k)
{
flag = true;
da += su[son[x][i]][j - k];
da %= mod;
}
long long L = (j + k) + (j - k == j + k);
if(L <= m)
{
flag = true;
da += (su[son[x][i]][m] + mod - su[son[x][i]][L - 1]) % mod;
da %= mod;
}
f[x][j] *= da;
f[x][j] %= mod;
}
f[x][j] *= flag;
su[x][j] = (su[x][j - 1] + f[x][j]) % mod;
}
}
void dp2 (long long x)
{
if(!son[x][0])
{
for(long long i = 1;i <= p; i++)
{
f[x][i] = 1;
su[x][i] = su[x][i - 1] + f[x][i];
su[x][i] %= mod;
}
return ;
}
for(long long i = 1;i <= son[x][0]; i++)
{
dp2(son[x][i]);
}
for(long long j = 1;j <= p; j++)
{
f[x][j] = 1;
bool flag = false;
for(long long i = 1;i <= son[x][0]; i++)
{
long long da = 0;
if(j > k)
{
flag = true;
da += sum(son[x][i], j - k);
da %= mod;
}
long long L = (j + k) + (j - k == j + k);
if(L <= m)
{
flag = true;
da += (sum(son[x][i], m) + mod - sum(son[x][i], L - 1)) % mod;
da %= mod;
}
f[x][j] *= da;
f[x][j] %= mod;
}
f[x][j] *= flag;
su[x][j] = (su[x][j - 1] + f[x][j]) % mod;
}
}
int main ()
{
long long T = 0;
scanf("%lld", &T);
while(T--)
{
memset(f, 0, sizeof(f));
scanf("%lld %lld %lld", &n, &m, &k);
for(long long i = 1;i <= n - 1; i++)
{
long long x = 0, y = 0;
scanf("%lld %lld", &x, &y);
ins(x, y);
ins(y, x);
}
dfs(1, 0);
p = m;
if(m > 2 * (n - 1) * k)
{
p = (n - 1) * k + 1;
dp2(1);
}
else
{
dp(1);
}
long long ans = 0;
for(long long i = 1;i <= p; i++)
{
ans += f[1][i];
ans %= mod;
}
if(m > 2 * (n - 1) * k)
{
ans = ans * 2 % mod;
ans = (ans + f[1][p] * (m - 2 * p)) % mod;
}
printf("%lld\n", ans);
len = 0;
for(long long i = 1;i <= n; i++)
{
h[i] = 0;
son[i][0] = 0;
}
}
return 0;
}