简单题。
直接无脑二维 dp $O(nq)$ 可以通过 Subtask 1 得到 5 分的好成绩,但有点没救
考虑这个期望直接 dp 有点难,不妨设随机变量 $X$ 表示有多少个 $r_i > 0$,不难看出:
$$ \begin{align} \mathbb E[S] &= \sum_{i=0}^n \Pr(X = i) \sum^i_{j=1} j^k\\ &= \sum_{i=0}^n \binom{n}{i} \left(\frac{q}{q+1}\right)^i \left(1 - \frac{q}{q+1}\right)^{n-i} \sum^i_{j=1} j^k\\ &= \frac{1}{(q+1)^n} \sum_{i=0}^n \binom{n}{i} q^i \sum^i_{j=1} j^k \end{align} $$
暴力 $O(n^2)$ 就通过了 Subtask 1-2 得到 15 分,可能算一算 $k$ 次幂和就轻松过了 Subtask 1-3,就水到了 30 分暴力。
记 $$ S_k(n)=\sum^n_{i=1}i^k $$
设其 Newton 级数为:
$$ S_k(n) = \sum^{k+1}_{i=0} \Delta^i S_k(0)\binom{n}{i} $$
不妨设 $\delta_i = \Delta^i S_k(0)$,则原式化为
$$ \begin{align} \mathbb E[S] &= \frac{1}{(q+1)^n} \sum_{i=0}^n \binom{n}{i} q^i \sum^{k+1}_{j=0} \Delta^j S_k(0) \binom{i}{j}\\ &= \frac{1}{(q+1)^n} \sum_{i=0}^n \sum_{j=0}^{k+1} \binom{n}{i} \binom{i}{j} q^i \delta_j\\ &= \frac{1}{(q+1)^n} \sum_{j=0}^{k+1}\delta_j \sum_{i=j}^n \binom{n}{i} \binom{i}{j} q^i\\ &= \frac{1}{(q+1)^n} \sum^{k+1}_{j=0} \delta_j \sum^n_{i=j} \binom{n}{j} \binom{n-j}{i-j} q^i\\ &=\frac{1}{(q+1)^n} \sum^{k+1}_{j=0} \delta_j \binom{n}{j} \sum^n_{i=j} \binom{n-j}{i-j} q^i\\ &=\frac{1}{(q+1)^n} \sum^{k+1}_{j=0} \delta_j \binom{n}{j} \sum^{n-j}_{i=0} \binom{n-j}{i} q^{i+j}\\ &=\frac{1}{(q+1)^n} \sum^{k+1}_{j=0} \delta_j q^j \binom{n}{j} \sum^{n-j}_{i=0} \binom{n-j}{i} q^i 1^{n-j-i}\\ &=\frac{1}{(q+1)^n} \sum^{k+1}_{j=0} \delta_j q^j \binom{n}{j} (q+1)^{n-j}\\ &=\sum^{k+1}_{j=0} \delta_j q^j \binom{n}{j} (q+1)^{-j} \end{align} $$
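注意到 $\delta_j$ 与 $n$ 无关,而 $\binom{n}{j}$ 是 $n$ 的 $j$ 次多项式,所以 $\mathbb E[S]$ 是关于 $n$ 的 $k+1$ 次多项式。暴力算出 $n=1,\dots,k+2$ 处的值,再用 Lagrange 插值求出任意 $n$ 处的值,就做完了~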
顺便贴一下 2.14 Contest 搬的几个原题的代码(题解见题目页面)
T1(Source: Concrete Mathematics):
// By Qingyu
#include <bits/stdc++.h>
typedef long long ll;
// Prime modulus for all arithmetic.
constexpr ll mod = 1e9 + 7;
// Size of every precomputed table; main() indexes up to K + 3, so K + 3 must stay below N.
constexpr ll N = 1e6 + 50;
// n: evaluation point, K: exponent of the power sum, Q: the parameter q of the formula in the write-up.
ll n, K, Q;
// Linear-sieve state (pri, isnt_pri, cnt); v[i] = i^K; Sk[i] = sum_{t=1}^{i} t^K; f[j] = answer for n = j.
ll pri[N], isnt_pri[N], v[N], Sk[N], f[N], cnt;
// fact / inv: factorials and inverse factorials; inv2 is filled by init() but not read elsewhere here;
// le / re: prefix / suffix products used by the Lagrange interpolation in main().
ll fact[N], inv[N], inv2[N], le[N], re[N];
// powq[i] = Q^i, powq2[i] = (Q + 1)^(-i) (mod p).
ll powq[N], powq2[N];
// Computes x^p mod `mod` by binary exponentiation (p is assumed non-negative).
// The base is first normalised into [0, mod) so that x * x can never overflow
// and negative bases yield a non-negative result.
inline ll fastpow(ll x, ll p)
{
    x %= mod;
    if (x < 0) x += mod;
    ll res = 1;
    while (p)
    {
        if (p & 1) res = res * x % mod;
        x = x * x % mod;
        p >>= 1;
    }
    return res;
}
inline void init()
{
isnt_pri[1] = v[1] = fact[0] = 1;
powq[0] = powq2[0] = 1;
const ll factor_q2 = fastpow(Q + 1, mod - 2);
for (int i = 1; i < N; ++i)
{
powq[i] = powq[i - 1] * Q % mod, powq2[i] = powq2[i - 1] * factor_q2 % mod;
fact[i] = i * fact[i - 1] % mod;
if (isnt_pri[i] == false)
{
pri[++cnt] = i;
v[i] = fastpow(i, K);
}
for (int j = 1; j <= cnt && i * pri[j] < N; ++j)
{
isnt_pri[i * pri[j]] = true;
v[i * pri[j]] = v[i] * v[pri[j]] % mod;
if (i % pri[j] == 0) break;
}
}
inv[N - 1] = fastpow(fact[N - 1], mod - 2);
for (int i = 1; i < N; ++i) Sk[i] = (Sk[i - 1] + v[i]) % mod;
for (int i = N - 2; i >= 0; --i) inv[i] = (i + 1) * inv[i + 1] % mod, inv2[i] = inv[i] * fact[i - 1] % mod;
}
// Binomial coefficient C(n, m) mod p, read from the factorial tables built by init().
// Requires 0 <= m <= n < N; the m <= n part is asserted.
inline ll C(ll n, ll m)
{
    assert(n >= m);
    const ll inv_denominator = inv[m] * inv[n - m] % mod;
    return fact[n] * inv_denominator % mod;
}
// Reads n, K, Q. It evaluates the expectation exactly at n = 1..K+2 and, since the
// answer is a polynomial of degree K+1 in n, gets the value at larger n by Lagrange
// interpolation over the nodes 1..K+2.
int main()
{
    scanf("%lld%lld%lld", &n, &K, &Q);
    init();
    // f[j] = (q+1)^(-j) * sum_{i=0}^{j} C(j, i) q^i S_K(i): the exact answer for n = j.
    for (int j = 1; j <= K + 2; ++j)
    {
        for (int i = 0; i <= j; ++i)
        {
            f[j] = (f[j] + C(j, i) * powq[i] % mod * Sk[i]) % mod;
        }
        f[j] = f[j] * powq2[j] % mod;
    }
    if (n <= K + 2) printf("%lld", f[n]);
    else
    {
        ll ans = 0;
        // n may be far larger than mod, so each factor (n - i) is reduced before it is
        // multiplied. n > K + 2 >= i keeps it positive.
        le[0] = re[K + 3] = 1;
        for (int i = 1; i <= K + 2; ++i) le[i] = le[i - 1] * ((n - i) % mod) % mod;
        for (int i = K + 2; i >= 1; --i) re[i] = re[i + 1] * ((n - i) % mod) % mod;
        for (int i = 1; i <= K + 2; ++i)
        {
            // Basis i: numerator prod_{j != i} (n - j), denominator (i-1)! (K+2-i)! with sign (-1)^(K+2-i).
            ll cur = f[i] * le[i - 1] % mod * re[i + 1] % mod * inv[i - 1] % mod * inv[K + 2 - i] % mod;
            if ((K - i) % 2 == 0) ans += cur;
            else ans -= cur;
        }
        printf("%lld", (ans % mod + mod) % mod);
    }
    return 0;
}
T2(Source: 2020 Wannafly Camp Day 1):
// By Qingyu
#include <bits/stdc++.h>
const int N = 4e5 + 50;
//----Basic definitions----
int n, m;
// a: the input sequence (1-indexed); root[o]: root of the value (weight) segment tree attached to range-tree node o;
// tot: highest value-tree node index handed out so far.
int a[N], max[N << 2], sec[N << 2], max_cnt[N << 2], tag_min[N << 2], root[N << 2], tot;
// Indices of freed value-tree nodes, reused by node_new() before allocating new ones.
std::queue<int> memory_pool;
// Outer structure: Segment Tree Beats (Ji Driver segment tree) over positions.
// max, sec, max_cnt, tag_min: per range-tree node, the maximum, the strict second maximum,
// the number of occurrences of the maximum, and the pending "range chmin" tag.
// Left / right children of range-tree node o in the implicit heap layout (root = 1).
inline int lc(int o) { return 2 * o; }
inline int rc(int o) { return 2 * o + 1; }
// Node of a value (weight) segment tree: child indices (0 = none) and the total count stored in its value range.
struct tree
{
int lc, rc, sum;
} node[N << 6];
// Shared node pool for all value trees; index 0 is never allocated and serves as the empty node.
//----Value (weight) segment tree----
// Returns a free value-tree node index. Recycled indices from memory_pool are used first.
inline int node_new()
{
    if (memory_pool.empty()) return ++tot;
    const int recycled = memory_pool.front();
    memory_pool.pop();
    return recycled;
}
inline void node_delete(int o)
{
if (o)
{
node[o].lc = node[o].rc = node[o].sum = 0;
memory_pool.push(o);
}
}
// Frees the whole value tree rooted at o, children before parent. o == 0 is a no-op.
inline void node_kill(int o)
{
    if (!o) return;
    node_kill(node[o].lc);
    node_kill(node[o].rc);
    node_delete(o);
}
// Builds a brand-new value tree z = x + y and returns its root (0 when both are empty).
// Neither input tree is modified.
int merge1(int x, int y)
{
    if (x == 0 && y == 0) return 0;
    const int z = node_new();
    node[z].sum = node[x].sum + node[y].sum;
    node[z].lc = merge1(node[x].lc, node[y].lc);
    node[z].rc = merge1(node[x].rc, node[y].rc);
    return z;
}
// Destructively adds value tree y into x and returns the merged root.
// Every node of y that overlaps x is recycled.
int merge2(int x, int y)
{
    if (!x || !y) return x | y;
    node[x].sum += node[y].sum;
    node[x].lc = merge2(node[x].lc, node[y].lc);
    node[x].rc = merge2(node[x].rc, node[y].rc);
    node_delete(y);
    return x;
}
// Adds value tree y into x in place while keeping y intact; returns the new root of x.
// Missing nodes of x are created on demand. A node whose resulting sum is 0 is recycled and replaced by 0.
// NOTE(review): node_delete(x) frees only x itself; a child subtree of x that was not visited
// (because y has no matching child) stays allocated — confirm this cannot leak nodes.
int merge3(int x, int y)
{
if (y == 0) return x;
if (x == 0) x = node_new();
node[x].sum += node[y].sum;
node[x].lc = merge3(node[x].lc, node[y].lc);
node[x].rc = merge3(node[x].rc, node[y].rc);
if (node[x].sum != 0) return x;
else return node_delete(x), 0;
}
// Value-tree point update: adds v (may be negative) to the count of value p in the tree rooted
// at o that spans [l, r]. Nodes are created on demand; returns the (possibly new) root.
int insert(int o, int l, int r, int p, int v)
{
    if (!o) o = node_new();
    node[o].sum += v;
    if (l == r) return o;
    const int mid = (l + r) / 2;
    if (p <= mid) node[o].lc = insert(node[o].lc, l, mid, p, v);
    else node[o].rc = insert(node[o].rc, mid + 1, r, p, v);
    return o;
}
// ----区间线段树----
inline void push_up(int o)
{
int left = lc(o), right = rc(o);
if (max[left] > max[right])
{
max[o] = max[left];
max_cnt[o] = max_cnt[left];
sec[o] = std::max(sec[left], max[right]);
}
else if (max[left] < max[right])
{
max[o] = max[right];
max_cnt[o] = max_cnt[right];
sec[o] = std::max(max[left], sec[right]);
}
else
{
max[o] = max[left];
max_cnt[o] = max_cnt[left] + max_cnt[right];
sec[o] = std::max(sec[left], sec[right]);
}
}
// Pushes tag v onto range-tree node o, i.e. applies "chmin with v" to o's whole range.
// Caller guarantees sec[o] < min(tag, v), so only the maximum changes.
// Updates tag_min[o], o's value tree (max_cnt[o] counts move from the old max to v) and max[o].
inline void maintain(int o, int v)
{
if (max[o] > v)
{
v = tag_min[o] = std::min(v, tag_min[o]);
// In o's value tree, move the max_cnt[o] occurrences of the old maximum to the new maximum v.
root[o] = insert(root[o], 1, n, max[o], -max_cnt[o]);
root[o] = insert(root[o], 1, n, v, max_cnt[o]);
max[o] = v;
}
}
// Pushes node o's pending chmin tag (if any) down to both children, then clears it.
// Both children must satisfy sec[child] < min(tag_min[o], tag_min[child]).
inline void push_down(int o)
{
    const int tag = tag_min[o];
    if (tag > n) return; // any tag above n (e.g. the 0x3f3f3f3f sentinel) means "no pending tag"
    maintain(lc(o), tag);
    maintain(rc(o), tag);
    tag_min[o] = 0x3f3f3f3f;
}
// Applies a[i] = min(a[i], v) for i in [ql, qr] inside range-tree node o, which covers [l, r].
// Returns the root of a temporary value tree p holding the signed change in value counts inside
// o's subtree (0 if nothing changed). Each ancestor's modify merges p into its own value tree;
// main() finally frees it with node_kill.
int modify(int o, int l, int r, int ql, int qr, int v)
{
if (r >= ql && l <= qr && max[o] > v) // callers do not check how [l, r] relates to [ql, qr], so the overlap test happens here first
{
int p = 0;
if (l >= ql && r <= qr && sec[o] < v)
{
p = insert(p, 1, n, max[o], -max_cnt[o]);
p = insert(p, 1, n, v, max_cnt[o]);
// Every node on the path root -> o must remove max_cnt counts at value max and add them at value v; that delta is captured in value tree p.
// Then apply the chmin to o itself (including o's own value tree) via maintain.
maintain(o, v);
return p;
}
push_down(o); // a tag stored on o always satisfies sec[o] < tag[o], so it can be pushed down directly
const int mid = l + r >> 1;
p = merge2(modify(lc(o), l, mid, ql, qr, v), modify(rc(o), mid + 1, r, ql, qr, v)); // combine both children's deltas; the two pieces are no longer needed afterwards and merge2 recycles the second one
root[o] = merge3(root[o], p); // p must be kept (merge3 leaves it intact) because the total delta is returned to the caller
push_up(o);
return p;
}
else return 0;
}
// Builds range-tree node o over [l, r] from a[], with no pending tag. Each leaf gets a
// one-element value tree; each internal node gets a fresh copy of its children's merged value trees.
void build(int o, int l, int r)
{
    tag_min[o] = 0x3f3f3f3f;
    if (l == r)
    {
        max[o] = a[l];
        max_cnt[o] = 1;
        sec[o] = 0;
        root[o] = insert(0, 1, n, a[l], 1);
        return;
    }
    const int mid = (l + r) / 2;
    build(lc(o), l, mid);
    build(rc(o), mid + 1, r);
    root[o] = merge1(root[lc(o)], root[rc(o)]);
    push_up(o);
}
// Appends to tar the value-tree roots of the canonical range-tree nodes covering [ql, qr].
// Pending tags are pushed down along the visited path.
void select_key_nodes(int o, int l, int r, int ql, int qr, std::vector<int> &tar)
{
    if (l >= ql && r <= qr)
    {
        tar.push_back(root[o]);
        return;
    }
    push_down(o);
    const int mid = (l + r) / 2;
    if (ql <= mid) select_key_nodes(lc(o), l, mid, ql, qr, tar);
    if (mid < qr) select_key_nodes(rc(o), mid + 1, r, ql, qr, tar);
    push_up(o);
}
// Walks all value trees in src in lockstep over [l, r] and returns the k-th smallest value of
// their combined multiset. src is modified in place: on return it holds the final-level node indices.
int query(int l, int r, int k, std::vector<int> &src)
{
    while (l != r)
    {
        const int mid = (l + r) / 2;
        int left_total = 0;
        for (int id : src) left_total += node[node[id].lc].sum;
        const bool go_left = left_total >= k;
        for (int &id : src) id = go_left ? node[id].lc : node[id].rc;
        if (go_left)
        {
            r = mid;
        }
        else
        {
            l = mid + 1;
            k -= left_total;
        }
    }
    return l;
}
// Buffered single-character reader over stdin (1,000,000-byte chunks).
// Returns EOF (converted to char) once input is exhausted.
inline char nc()
{
    static char buffer[1000000], *head = buffer, *tail = buffer;
    if (head == tail)
    {
        head = buffer;
        tail = buffer + fread(buffer, 1, 1000000, stdin);
        if (head == tail) return EOF;
    }
    return *head++;
}
// Reads the next non-negative decimal integer from stdin via nc(), skipping any non-digit bytes before it.
// Returns 0 if input ends before a digit appears, instead of spinning forever on EOF.
// (A literal 0xFF byte in the input is treated as EOF, because nc() reports EOF as a char.)
inline int read()
{
    int res = 0;
    char ch;
    do
    {
        ch = nc();
        if (ch == static_cast<char>(EOF)) return res;
    } while (ch < '0' || ch > '9');
    do res = res * 10 + (ch - '0'), ch = nc(); while (ch >= '0' && ch <= '9');
    return res;
}
// Reads n, m and a[1..n], builds the tree, then processes m operations "op l r K":
//   op == 1 : a[i] = min(a[i], K) on [l, r] (skipped when K > n — presumably values never exceed n)
//   other   : print the K-th smallest value among a[l..r]
int main()
{
    n = read();
    m = read();
    for (int i = 1; i <= n; ++i) a[i] = read();
    build(1, 1, n);
    for (int step = 0; step < m; ++step)
    {
        const int op = read();
        const int l = read();
        const int r = read();
        const int K = read();
        if (op != 1)
        {
            std::vector<int> roots;
            select_key_nodes(1, 1, n, l, r, roots);
            printf("%d\n", query(1, n, K, roots));
        }
        else if (K <= n)
        {
            // The returned delta tree is only needed by the tree's ancestors; free it now.
            node_kill(modify(1, 1, n, l, r, K));
        }
    }
    return 0;
}
T3(Source: 忘了,但貌似是 XVIII Open Cup):
// By Qingyu
#include <bits/stdc++.h>
constexpr int N = 1e5 + 50;
// Bound on the component count: per-count tables are indexed by 1..n, so n must be below T.
constexpr int T = 50;
constexpr int mod = 1e9 + 7;
// Base of the polynomial hash that identifies a sorted list of component sizes.
constexpr int base = 53;
typedef long long ll;
typedef unsigned long long ull;
ll n, a[N], fact[N], inv[N];
ll vertex_count[T][N][T], edge_count[T][N]; // vertex_count[i][j][k]: number of vertices in the k-th component of state j among the states with i components; edge_count[i][j]: sum of C(size, 2) over all components of that state.
// status_cnt[i]: number of states with i components; block_size: scratch buffer used by get_id.
ll status_cnt[T], block_size[T];
// map[i]: hash of a sorted size list -> state id among the states with i components.
std::unordered_map<ull, int> map[T];
// Computes x^p mod `mod` by binary exponentiation (p is assumed non-negative).
// The base is first normalised into [0, mod) so that x * x can never overflow
// and negative bases yield a non-negative result.
inline ll fastpow(ll x, ll p)
{
    x %= mod;
    if (x < 0) x += mod;
    ll res = 1;
    while (p)
    {
        if (p & 1) res = res * x % mod;
        x = x * x % mod;
        p >>= 1;
    }
    return res;
}
// Fills fact[i] = i! and inv[i] = (i!)^(-1) (mod p) for every i in [0, N).
inline void init()
{
    fact[0] = 1;
    for (int i = 1; i < N; ++i) fact[i] = fact[i - 1] * i % mod;
    inv[N - 1] = fastpow(fact[N - 1], mod - 2);
    // (i-1)!^(-1) = i!^(-1) * i
    for (int i = N - 1; i > 0; --i) inv[i - 1] = inv[i] * i % mod;
}
// Recursively enumerates every split of rest_vertex vertices into total_block_cnt components with
// non-decreasing sizes. Sizes are written into block_size[current_block_cnt..total_block_cnt].
// Each complete split becomes a new state with total_block_cnt components: it gets the next id,
// its sizes go to vertex_count, the sum of C(size, 2) goes to edge_count, and its hash goes into map.
// Top-level call: get_id(k, n, 1, 1).
inline void get_id(int total_block_cnt, int rest_vertex, int current_block_cnt, int last_block_size)
{
int rest_block_cnt = total_block_cnt - current_block_cnt + 1;
// Prune: each remaining block needs at least last_block_size vertices (sizes are non-decreasing).
if (rest_block_cnt * last_block_size > rest_vertex) return;
else if (current_block_cnt == total_block_cnt + 1)
{
// All blocks chosen: register the state.
++status_cnt[total_block_cnt];
int id = status_cnt[total_block_cnt];
ull hash_value = 0;
for (int i = 1; i <= total_block_cnt; ++i)
{
vertex_count[total_block_cnt][id][i] = block_size[i];
edge_count[total_block_cnt][id] += block_size[i] * (block_size[i] - 1) >> 1;
hash_value = hash_value * base + block_size[i];
}
map[total_block_cnt][hash_value] = id;
}
else if (current_block_cnt == total_block_cnt)
{
// The last block takes every remaining vertex.
block_size[current_block_cnt] = rest_vertex;
get_id(total_block_cnt, 0, current_block_cnt + 1, rest_vertex);
}
else
{
// Try every size not smaller than the previous block.
for (int i = last_block_size; i <= rest_vertex; ++i)
{
block_size[current_block_cnt] = i;
get_id(total_block_cnt, rest_vertex - i, current_block_cnt + 1, i);
}
}
}
// Binomial coefficient C(n, m) mod p from the factorial tables. Returns 0 when m < 0 or m > n
// (which also covers n < 0); the m < 0 guard prevents the out-of-bounds read inv[m].
inline ll Comb(ll n, ll m)
{
    if (m < 0 || n < m) return 0;
    return fact[n] * inv[m] % mod * inv[n - m] % mod;
}
// Number of ordered selections of m items out of n, n! / (n - m)! mod p.
// Returns 0 when m < 0 or m > n (which also covers n < 0). Without the m < 0 guard, an invalid
// request such as a[i-1] - a[i] - 1 == -1 would yield a nonzero value from inv[n - m].
inline ll Perm(ll n, ll m)
{
    if (m < 0 || n < m) return 0;
    return fact[n] * inv[n - m] % mod;
}
// DP over component-size states, from n components (all singletons) down to 1 component.
// For each state with i components, it multiplies by the ordered placements (Perm) of the
// a[i-1] - a[i] - 1 edges that do not merge components. It then transitions, for every pair of
// components k < l, to the merged state with weight vc[k] * vc[l].
// Returns f[1][1] times (C(n, 2) - a[1])!.
inline ll dp()
{
static ll f[T][N], g[T][T]; // f[i][j]: number of ways to reach state j among the states with i components; g[x][y]: cached target state id (among i - 1 components) for merging components of sizes x and y, reset per state.
static ll vc[N];
for (int i = 1; i <= status_cnt[n]; ++i) f[n][i] = 1;
for (int i = n; i > 1; --i)
{
for (int j = 1; j <= status_cnt[i]; ++j)
{
// Ordered placements of a[i-1] - a[i] - 1 non-tree edges among edge_count[i][j] - a[i] intra-component slots
// (presumably all a[i] earlier edges already lie inside the current components — confirm against the problem).
ll connect_nontree_edge_factor = Perm(edge_count[i][j] - a[i], a[i - 1] - a[i] - 1);
memset(g, 0, sizeof g);
// Enumerate two components k, l and merge them
for (int x = 1; x <= i; ++x) vc[x] = vertex_count[i][j][x];
for (int k = 1; k <= i; ++k)
{
for (int l = k + 1; l <= i; ++l)
{
if (!g[vc[k]][vc[l]])
{
// Hash the sorted size list after the merge: the merged size is inserted before the first remaining size that is strictly larger.
ull new_hash_value = 0;
bool ok = true;
for (int w = 1; w <= i; ++w)
{
if (w != k && w != l)
{
if (ok && vc[w] > vc[k] + vc[l])
{
ok = false;
new_hash_value = new_hash_value * base + vc[k] + vc[l];
}
new_hash_value = new_hash_value * base + vc[w];
}
}
if (ok) new_hash_value = new_hash_value * base + vc[k] + vc[l];
g[vc[k]][vc[l]] = map[i - 1][new_hash_value];
}
// Transition into state g[vc[k]][vc[l]]; vc[k] * vc[l] is the number of edges joining the two components.
f[i - 1][g[vc[k]][vc[l]]] += f[i][j] * connect_nontree_edge_factor % mod * vc[k] * vc[l] % mod;
f[i - 1][g[vc[k]][vc[l]]] %= mod;
}
}
}
}
// The remaining C(n, 2) - a[1] edges can be placed in any order without affecting the answer.
return f[1][1] * fact[Comb(n, 2) - a[1]] % mod;
}
// Buffered single-character reader over stdin (1,000,000-byte chunks).
// Returns EOF (converted to char) once input is exhausted.
inline char nc()
{
    static char buffer[1000000], *head = buffer, *tail = buffer;
    if (head == tail)
    {
        head = buffer;
        tail = buffer + fread(buffer, 1, 1000000, stdin);
        if (head == tail) return EOF;
    }
    return *head++;
}
// Reads the next non-negative decimal integer from stdin via nc(), skipping any non-digit bytes before it.
// Returns 0 if input ends before a digit appears, instead of spinning forever on EOF.
// (A literal 0xFF byte in the input is treated as EOF, because nc() reports EOF as a char.)
inline int read()
{
    int res = 0;
    char ch;
    do
    {
        ch = nc();
        if (ch == static_cast<char>(EOF)) return res;
    } while (ch < '0' || ch > '9');
    do res = res * 10 + (ch - '0'), ch = nc(); while (ch >= '0' && ch <= '9');
    return res;
}
// Reads n and then n - 1 values, stored in reverse (the k-th value read goes to a[n - k]).
// Enumerates all component-size states for every component count and prints the dp result.
int main()
{
    init();
    n = read();
    for (int k = 1; k < n; ++k) a[n - k] = read();
    for (int blocks = 1; blocks <= n; ++blocks) get_id(blocks, n, 1, 1);
    printf("%lld", dp());
    return 0;
}
所以是一个套路题、一个休闲的数据结构题和一个提高组 dp。