Qingyu的博客

博客

你OJ278的详细题解

2020-02-15 03:51:43 By Qingyu

简单题。

直接无脑二维 dp $O(nq)$ 可以通过 Subtask 1 得到 5 分的好成绩,但有点没救

考虑这个期望直接 dp 有点难,不妨设随机变量 $X$ 表示有多少个 $r_i > 0$,不难看出:

$$ \begin{align} \mathbb E[S] &= \sum_{i=0}^n \Pr(X = i) \sum^i_{j=1} j^k\\ &= \sum_{i=0}^n \binom{n}{i} \left(\frac{q}{q+1}\right)^i \left(1 - \frac{q}{q+1}\right)^{n-i} \sum^i_{j=1} j^k\\ &= \frac{1}{(q+1)^n} \sum_{i=0}^n \binom{n}{i} q^i \sum^i_{j=1} j^k \end{align} $$

暴力 $O(n^2)$ 就通过了 Subtask 1-2 得到 15 分,可能算一算 $k$ 次幂和就轻松过了 Subtask 1-3,就水到了 30 分暴力。

记 $$ S_k(n)=\sum^n_{i=1}i^k $$

设其 Newton 级数为:

$$ S_k(n) = \sum^{k+1}_{i=0} \Delta^i S_k(0)\binom{n}{i} $$

不妨设 $\delta_i = \Delta^i S_k(0)$,则原式化为

$$ \begin{align} \mathbb E[S] &= \frac{1}{(q+1)^n} \sum_{i=0}^n \binom{n}{i} q^i \sum^{k+1}_{j=0} \Delta^j S_k(0) \binom{i}{j}\\ &= \frac{1}{(q+1)^n} \sum_{i=0}^n \sum_{j=0}^{k+1} \binom{n}{i} \binom{i}{j} q^i \delta_j\\ &= \frac{1}{(q+1)^n} \sum_{j=0}^{k+1}\delta_j \sum_{i=j}^n \binom{n}{i} \binom{i}{j} q^i\\ &= \frac{1}{(q+1)^n} \sum^{k+1}_{j=0} \delta_j \sum^n_{i=j} \binom{n}{j} \binom{n-j}{i-j} q^i\\ &=\frac{1}{(q+1)^n} \sum^{k+1}_{j=0} \delta_j \binom{n}{j} \sum^n_{i=j} \binom{n-j}{i-j} q^i\\ &=\frac{1}{(q+1)^n} \sum^{k+1}_{j=0} \delta_j \binom{n}{j} \sum^{n-j}_{i=0} \binom{n-j}{i} q^{i+j}\\ &=\frac{1}{(q+1)^n} \sum^{k+1}_{j=0} \delta_j q^j \binom{n}{j} \sum^{n-j}_{i=0} \binom{n-j}{i} q^i 1^{n-j-i}\\ &=\frac{1}{(q+1)^n} \sum^{k+1}_{j=0} \delta_j q^j \binom{n}{j} (q+1)^{n-j}\\ &=\sum^{k+1}_{j=0} \delta_j q^j \binom{n}{j} (q+1)^{-j} \end{align} $$

注意到答案是关于 $n$ 的 $k+1$ 次多项式,于是先直接算出 $n = 1, 2, \dots, k+2$ 时的答案,再用拉格朗日插值求出 $n$ 处的值,就做完了~

顺便贴一下 2.14 Contest 搬的几个原题的代码(题解见题目页面)

T1(Source: Concrete Mathematics):

// By Qingyu
#include <bits/stdc++.h>

typedef long long ll;
constexpr ll mod = 1e9 + 7;
constexpr ll N = 1e6 + 50; // size of every table below

ll n, K, Q; // input values, read by main
// Linear-sieve data: pri[1..cnt] are the primes found, isnt_pri marks non-primes,
// v[i] = i^K mod `mod`, Sk[i] = 1^K + ... + i^K, f[j] = expected value for n = j (j <= K + 2).
ll pri[N], isnt_pri[N], v[N], Sk[N], f[N], cnt;
// fact[i] = i!, inv[i] = (i!)^{-1}, inv2[i] = i^{-1} (for i >= 1);
// le / re: prefix / suffix products of (n - i) used by the Lagrange interpolation in main.
ll fact[N], inv[N], inv2[N], le[N], re[N];
// powq[i] = Q^i, powq2[i] = (Q + 1)^{-i}  (all mod `mod`).
ll powq[N], powq2[N];

// Returns x^p mod `mod` by binary exponentiation (square-and-multiply).
inline ll fastpow(ll x, ll p)
{
    ll acc = 1;
    for (ll cur = x; p != 0; p >>= 1, cur = cur * cur % mod)
    {
        if (p & 1) acc = acc * cur % mod;
    }
    return acc;
}

inline void init()
{
    isnt_pri[1] = v[1] = fact[0] = 1;
    powq[0] = powq2[0] = 1;
    const ll factor_q2 = fastpow(Q + 1, mod - 2);
    for (int i = 1; i < N; ++i)
    {
        powq[i] = powq[i - 1] * Q % mod, powq2[i] = powq2[i - 1] * factor_q2 % mod;
        fact[i] = i * fact[i - 1] % mod;
        if (isnt_pri[i] == false)
        {
            pri[++cnt] = i;
            v[i] = fastpow(i, K);
        }
        for (int j = 1; j <= cnt && i * pri[j] < N; ++j)
        {
            isnt_pri[i * pri[j]] = true;
            v[i * pri[j]] = v[i] * v[pri[j]] % mod;
            if (i % pri[j] == 0) break;
        }
    }
    inv[N - 1] = fastpow(fact[N - 1], mod - 2);
    for (int i = 1; i < N; ++i) Sk[i] = (Sk[i - 1] + v[i]) % mod;
    for (int i = N - 2; i >= 0; --i) inv[i] = (i + 1) * inv[i + 1] % mod, inv2[i] = inv[i] * fact[i - 1] % mod;
}

// Binomial coefficient C(n, m) mod `mod`, taken from the factorial tables; requires n >= m.
inline ll C(ll n, ll m)
{
    assert(n >= m);
    const ll inv_denominator = inv[m] * inv[n - m] % mod;
    return fact[n] * inv_denominator % mod;
}

int main()
{
    scanf("%lld%lld%lld", &n, &K, &Q);
    init();
    for (int j = 1; j <= K + 2; ++j) 
    {
        for (int i = 0; i <= j; ++i)
        {
            f[j] = (f[j] + C(j, i) * powq[i] % mod * Sk[i]) % mod;
        }
        f[j] = f[j] * powq2[j] % mod;
    }
    if (n <= K + 2) printf("%lld", f[n]);
    else
    {
        ll ans = 0;
        le[0] = re[K + 3] = 1;
        for (int i = 1; i <= K + 2; ++i) le[i] = le[i - 1] * (n - i) % mod;
        for (int i = K + 2; i >= 1; --i) re[i] = re[i + 1] * (n - i) % mod;
        for (int i = 1; i <= K + 2; ++i)
        {
            ll cur = f[i] * le[i - 1] % mod * re[i + 1] % mod * inv[i - 1] % mod * inv[K + 2 - i] % mod;
            if ((K - i) % 2 == 0) ans += cur;
            else ans -= cur;
        }
        printf("%lld", (ans % mod + mod) % mod);
    }
    return 0;
}

T2(Source: 2020 Wannafly Camp Day 1):

// By Qingyu
#include <bits/stdc++.h>

const int N = 4e5 + 50;

//---- Basic definitions ----

int n, m; // array length and number of operations (read in main)
// a: input array (1-based). For each position-segment-tree node o (heap indexing, root 1):
// max / sec / max_cnt / tag_min are described below; root[o] is the root of o's value segment
// tree (the multiset of values in o's range). tot: number of value-tree node indices handed out.
int a[N], max[N << 2], sec[N << 2], max_cnt[N << 2], tag_min[N << 2], root[N << 2], tot;
std::queue<int> memory_pool; // freed value-tree node indices, reused by node_new

// Outer "Ji driver" segment tree (Segment Tree Beats).
// max, sec, max_cnt, tag_min: for the range of each position-tree node, the maximum, the strict second maximum, the number of occurrences of the maximum, and the pending range-chmin tag.

// Left / right child of node o in the heap-indexed position segment tree (root is 1).
inline int lc(int o) { return 2 * o; }
inline int rc(int o) { return 2 * o + 1; }

// Node storage shared by every value (weight) segment tree; index 0 is the null node.
// lc / rc: child indices (0 = absent); sum: count of values in the node's value range
// (may be negative in the temporary "delta" trees built by modify).
// NOTE(review): N << 6 = 25,603,200 nodes, roughly 307 MB of static storage with 12-byte nodes.
struct tree
{
    int lc, rc, sum;
} node[N << 6];
// Value segment tree node data

//---- Value (weight) segment trees ----

// Returns a node index for a value segment tree. A freed index from memory_pool is reused
// when one is available (node_delete already zeroed it); otherwise the next fresh index
// (++tot) is taken.
inline int node_new()
{
    if (memory_pool.empty()) return ++tot;
    const int recycled = memory_pool.front();
    memory_pool.pop();
    return recycled;
}

// Zeroes node o and returns its index to memory_pool; the null node 0 is ignored.
inline void node_delete(int o)
{
    if (!o) return;
    node[o].lc = 0;
    node[o].rc = 0;
    node[o].sum = 0;
    memory_pool.push(o);
}

// Frees node o and its whole subtree in post-order (left subtree, right subtree, then o);
// o == 0 is a no-op.
inline void node_kill(int o)
{
    if (o == 0) return;
    node_kill(node[o].lc);
    node_kill(node[o].rc);
    node_delete(o);
}

// Builds a brand-new value segment tree z whose counts are the sum of the trees rooted at
// x and y, and returns z. Neither input is modified. Returns 0 if both inputs are empty.
int merge1(int x, int y)
{
    if (x == 0 && y == 0) return 0;
    const int z = node_new();
    node[z].sum = node[x].sum + node[y].sum;
    const int left = merge1(node[x].lc, node[y].lc);
    node[z].lc = left;
    const int right = merge1(node[x].rc, node[y].rc);
    node[z].rc = right;
    return z;
}

// Destructively merges tree y into tree x: x's nodes are updated in place and y's nodes
// are freed. If either side is empty, the other root is returned unchanged (and reused).
int merge2(int x, int y)
{
    if (x == 0 || y == 0) return x | y;
    node[x].sum += node[y].sum;
    node[x].lc = merge2(node[x].lc, node[y].lc);
    node[x].rc = merge2(node[x].rc, node[y].rc);
    node_delete(y);
    return x;
}

// 将以 y 为根的权值线段树合并到 x 上,保留 y 
int merge3(int x, int y)
{
    if (y == 0) return x;
    if (x == 0) x = node_new();
    node[x].sum += node[y].sum;
    node[x].lc = merge3(node[x].lc, node[y].lc);
    node[x].rc = merge3(node[x].rc, node[y].rc);
    if (node[x].sum != 0) return x;
    else return node_delete(x), 0; 
}

// Adds v to the count of value p in the value segment tree rooted at o over [l, r],
// creating any missing nodes along the path. Returns the (possibly newly created) root.
int insert(int o, int l, int r, int p, int v)
{
    if (!o) o = node_new();
    node[o].sum += v;
    if (l == r) return o;
    const int mid = (l + r) >> 1;
    if (p > mid) node[o].rc = insert(node[o].rc, mid + 1, r, p, v);
    else node[o].lc = insert(node[o].lc, l, mid, p, v);
    return o;
}

// ----区间线段树---- 

inline void push_up(int o)
{
    int left = lc(o), right = rc(o);
    if (max[left] > max[right])
    {
        max[o] = max[left];
        max_cnt[o] = max_cnt[left];
        sec[o] = std::max(sec[left], max[right]);
    }
    else if (max[left] < max[right])
    {
        max[o] = max[right];
        max_cnt[o] = max_cnt[right];
        sec[o] = std::max(max[left], sec[right]);
    }
    else
    {
        max[o] = max[left];
        max_cnt[o] = max_cnt[left] + max_cnt[right];
        sec[o] = std::max(sec[left], sec[right]);
    }
}

// Applies "chmin with v" to node o, i.e. lowers its maximum to v. Requires
// sec[o] < min(tag, v), so only the maximum is affected.
// In o's value tree root[o], max_cnt[o] occurrences move from max[o] to the new value.
// No-op when max[o] <= v.
inline void maintain(int o, int v)
{
    if (max[o] <= v) return;
    tag_min[o] = std::min(v, tag_min[o]);
    const int target = tag_min[o];
    root[o] = insert(root[o], 1, n, max[o], -max_cnt[o]);
    root[o] = insert(root[o], 1, n, target, max_cnt[o]);
    max[o] = target;
}

// Pushes the pending chmin tag of node o, if any, down to both children and clears it.
// Requires both children to satisfy sec < min(tag[o], tag[child]).
// The value trees cover [1, n], so a tag above n is treated as "no tag"
// (the sentinel is 0x3f3f3f3f).
inline void push_down(int o)
{
    if (tag_min[o] > n) return;
    const int tag = tag_min[o];
    maintain(lc(o), tag);
    maintain(rc(o), tag);
    tag_min[o] = 0x3f3f3f3f;
}

// Applies a[i] = min(a[i], v) for every i in [ql, qr] within node o (covering [l, r]).
// Returns the root of a temporary value tree p describing how the value counts in o's subtree
// changed: -cnt at each lowered maximum and +cnt at v. The caller owns p (main frees it with
// node_kill). Returns 0 when nothing changes.
int modify(int o, int l, int r, int ql, int qr, int v)
{
    if (r >= ql && l <= qr && max[o] > v) // callers do not check how the query range relates to o, so that check is done here before any data is touched
    {
        int p = 0;
        if (l >= ql && r <= qr && sec[o] < v)
        {
            p = insert(p, 1, n, max[o], -max_cnt[o]);
            p = insert(p, 1, n, v, max_cnt[o]);
            // Every node on the path root -> o must move max_cnt[o] occurrences from max[o] to v;
            // that change is captured as the value tree p.
            // Finally apply the chmin to o itself (including its own value tree) via maintain.
            maintain(o, v);
            return p;
        }
        push_down(o); // any tag on o satisfies sec[o] < tag[o], so it can be pushed down directly
        const int mid = l + r >> 1;
        p = merge2(modify(lc(o), l, mid, ql, qr, v), modify(rc(o), mid + 1, r, ql, qr, v)); // combine both children's changes; merge2 consumes the two trees, so they need no separate freeing
        root[o] = merge3(root[o], p); // p is returned to the caller afterwards, so merge3 (which leaves p intact) is used
        push_up(o);
        return p;
    }
    else return 0;
}

void build(int o, int l, int r)
{
    tag_min[o] = 0x3f3f3f3f;
    if (l == r)
    {
        max[o] = a[l], max_cnt[o] = 1, sec[o] = 0;
        root[o] = insert(0, 1, n, a[l], 1);
    }
    else
    {
        const int mid = l + r >> 1;
        build(lc(o), l, mid);
        build(rc(o), mid + 1, r);
        root[o] = merge1(root[lc(o)], root[rc(o)]);
        push_up(o);
    }
}

// Appends to tar the value-tree roots of the canonical nodes that exactly cover [ql, qr].
// Pending tags are pushed down (and nodes re-pulled) along the way.
void select_key_nodes(int o, int l, int r, int ql, int qr, std::vector<int> &tar)
{
    if (l >= ql && r <= qr)
    {
        tar.push_back(root[o]);
        return;
    }
    push_down(o);
    const int mid = (l + r) >> 1;
    if (ql <= mid) select_key_nodes(lc(o), l, mid, ql, qr, tar);
    if (mid < qr) select_key_nodes(rc(o), mid + 1, r, ql, qr, tar);
    push_up(o);
}

// Descends all value trees in src simultaneously, treating their counts as one summed
// multiset over the value range [l, r], and returns its k-th SMALLEST value.
// The left child (smaller values) is taken whenever it holds at least k elements.
// src is overwritten in place with the node indices reached at each level.
int query(int l, int r, int k, std::vector<int> &src)
{
    while (l != r)
    {
        const int mid = (l + r) >> 1;
        int left_total = 0;
        for (const int v : src) left_total += node[node[v].lc].sum;
        const bool go_left = left_total >= k;
        for (int &v : src) v = go_left ? node[v].lc : node[v].rc;
        if (go_left)
        {
            r = mid;
        }
        else
        {
            l = mid + 1;
            k -= left_total;
        }
    }
    // (The recursive version ended with an unreachable bare `throw;`, which would have
    // called std::terminate if ever reached; the loop needs no such fallback.)
    return l;
}

// Buffered single-character reader over stdin, refilled in 1,000,000-byte chunks;
// returns EOF once input is exhausted.
inline char nc()
{
    static char chunk[1000000], *cur = chunk, *tail = chunk;
    if (cur == tail)
    {
        cur = chunk;
        tail = chunk + fread(chunk, 1, 1000000, stdin);
        if (cur == tail) return EOF;
    }
    return *cur++;
}

// Reads the next non-negative decimal integer via nc(), first skipping every non-digit
// character (a leading '-' is skipped, not honored).
inline int read()
{
    char ch = nc();
    while (ch < '0' || ch > '9') ch = nc();
    int value = 0;
    while (ch >= '0' && ch <= '9')
    {
        value = value * 10 + ch - '0';
        ch = nc();
    }
    return value;
}

// Reads n, m and a[1..n], builds the structure, then processes m operations "op l r K":
//   op == 1 : a[i] = min(a[i], K) for i in [l, r]. Skipped when K > n, presumably because
//             all values lie in [1, n] (the value trees' range), so nothing would change.
//   otherwise: print the K-th smallest value among a[l..r].
int main()
{
    n = read();
    m = read();
    for (int i = 1; i <= n; ++i) a[i] = read();
    build(1, 1, n);
    for (int t = 0; t < m; ++t)
    {
        const int op = read();
        const int l = read();
        const int r = read();
        const int K = read();
        if (op != 1)
        {
            std::vector<int> roots;
            select_key_nodes(1, 1, n, l, r, roots);
            printf("%d\n", query(1, n, K, roots));
        }
        else if (K <= n)
        {
            // The change tree returned by modify is no longer needed; free it.
            node_kill(modify(1, 1, n, l, r, K));
        }
    }
    return 0;
}

T3(Source: 忘了,但貌似是 XVIII OpenCup?):

// By Qingyu
// By Qingyu
#include <bits/stdc++.h>

constexpr int N = 1e5 + 50; // upper bound on the number of states per component count, and size of fact / inv
constexpr int T = 50; // arrays are indexed by component count (up to n), so n must stay below T (presumably guaranteed by the constraints)
constexpr int mod = 1e9 + 7;
constexpr int base = 53; // base of the polynomial hash of a state's sorted component-size sequence
typedef long long ll;
typedef unsigned long long ull;

ll n, a[N], fact[N], inv[N]; // a: values read in main, stored in reverse (the t-th value goes to a[n - t]); fact / inv: factorials and inverse factorials
ll vertex_count[T][N][T], edge_count[T][N]; // vertex_count[i][j][k]: number of vertices of the k-th component in state j among the states with i components; edge_count[i][j]: sum of size*(size-1)/2 over that state's components, i.e. the number of vertex pairs inside components. NOTE(review): vertex_count alone is T*N*T*8 bytes (about 2 GB) of static storage.
ll status_cnt[T], block_size[T]; // status_cnt[i]: number of states with i components; block_size: scratch buffer used by get_id
std::unordered_map<ull, int> map[T]; // map[i]: hash of a sorted size sequence -> state id among the states with i components

// Returns x^p mod `mod` by binary exponentiation (square-and-multiply).
inline ll fastpow(ll x, ll p)
{
    ll acc = 1;
    for (ll cur = x; p != 0; p >>= 1, cur = cur * cur % mod)
    {
        if (p & 1) acc = acc * cur % mod;
    }
    return acc;
}

// Fills fact[i] = i! and inv[i] = (i!)^{-1} (mod `mod`) for 0 <= i < N.
inline void init()
{
    fact[0] = 1;
    for (int i = 1; i < N; ++i) fact[i] = fact[i - 1] * i % mod;
    inv[N - 1] = fastpow(fact[N - 1], mod - 2);
    // (i-1)!^{-1} = i!^{-1} * i
    for (int i = N - 1; i > 0; --i) inv[i - 1] = inv[i] * i % mod;
}

// Recursively enumerates every non-decreasing sequence block_size[1..total_block_cnt] of
// integers >= last_block_size that sums to the initial rest_vertex. main starts it as
// get_id(i, n, 1, 1), so this lists every partition of n into exactly i parts.
// For each complete sequence it assigns the next id (status_cnt[total_block_cnt]) and
// records: the sizes, the number of vertex pairs inside blocks, and a base-`base` hash of
// the sequence -> id in map[total_block_cnt].
//   rest_vertex:       vertices not yet assigned to a block
//   current_block_cnt: 1-based index of the block being sized now
//   last_block_size:   size of the previous block (lower bound for the current one)
// NOTE(review): the hash is a wrapping 64-bit polynomial hash (base^T exceeds 2^64), so
// collisions are unlikely but not impossible.
inline void get_id(int total_block_cnt, int rest_vertex, int current_block_cnt, int last_block_size)
{
    int rest_block_cnt = total_block_cnt - current_block_cnt + 1;
    // Prune: each of the rest_block_cnt blocks still to be sized needs at least last_block_size vertices.
    if (rest_block_cnt * last_block_size > rest_vertex) return;
    else if (current_block_cnt == total_block_cnt + 1)
    {
        // Every block has been sized: register this state under the next id.
        ++status_cnt[total_block_cnt];
        int id = status_cnt[total_block_cnt];
        ull hash_value = 0;
        for (int i = 1; i <= total_block_cnt; ++i)
        {
            vertex_count[total_block_cnt][id][i] = block_size[i];
            edge_count[total_block_cnt][id] += block_size[i] * (block_size[i] - 1) >> 1;
            hash_value = hash_value * base + block_size[i];
        }
        map[total_block_cnt][hash_value] = id;
    }
    else if (current_block_cnt == total_block_cnt)
    {
        // Last block takes every remaining vertex; the prune above guarantees
        // rest_vertex >= last_block_size, keeping the sequence non-decreasing.
        block_size[current_block_cnt] = rest_vertex;
        get_id(total_block_cnt, 0, current_block_cnt + 1, rest_vertex);
    }
    else
    {
        // Try every size for the current block that keeps the sequence non-decreasing.
        for (int i = last_block_size; i <= rest_vertex; ++i)
        {
            block_size[current_block_cnt] = i;
            get_id(total_block_cnt, rest_vertex - i, current_block_cnt + 1, i);
        }
    }
}

// Binomial coefficient C(n, m) mod `mod`.
// Returns 0 when the arguments are out of range (n < 0, m < 0 or m > n).
inline ll Comb(ll n, ll m)
{
    // m < 0 must be rejected as well: otherwise inv[m] indexes before the start of the table.
    if (n < 0 || m < 0 || n < m) return 0;
    return fact[n] * inv[m] % mod * inv[n - m] % mod;
}

// Falling factorial n! / (n - m)! = n (n-1) ... (n-m+1) mod `mod`, i.e. the number of
// ordered selections of m items out of n.
// Returns 0 when the arguments are out of range (n < 0, m < 0 or m > n).
inline ll Perm(ll n, ll m)
{
    // m < 0 must be rejected as well: otherwise inv[n - m] is taken past index n (possibly
    // past the table) and a meaningless nonzero value is returned.
    if (n < 0 || m < 0 || n < m) return 0;
    return fact[n] * inv[n - m] % mod;
}

// Counts edge orderings by merging components one at a time, going from the single state
// with n singleton components (f[n][1] = 1) down to the single state with one component.
// NOTE(review): from main and the formulas below, a[i] appears to be the index of the edge at
// which the component count drops to i (with a[n] = 0) — confirm against the problem statement.
// Returns f[1][1] times the number of orders of the edges left after edge a[1].
inline ll dp()
{
    static ll f[T][N], g[T][T]; // f[i][j]: number of ways to reach state j among the states with i components; g[s][t] caches the id of the (i-1)-component state obtained by merging two components of sizes s <= t (0 = not computed yet)
    static ll vc[N];
    for (int i = 1; i <= status_cnt[n]; ++i) f[n][i] = 1;
    for (int i = n; i > 1; --i)
    {
        for (int j = 1; j <= status_cnt[i]; ++j)
        {
            // The a[i-1] - a[i] - 1 edges strictly between edges a[i] and a[i-1] must not merge
            // components, so they are an ordered choice among the edge_count[i][j] - a[i] vertex
            // pairs inside components that are presumably still unused.
            ll connect_nontree_edge_factor = Perm(edge_count[i][j] - a[i], a[i - 1] - a[i] - 1);
            memset(g, 0, sizeof g);

            // Enumerate two components k < l and merge them.
            for (int x = 1; x <= i; ++x) vc[x] = vertex_count[i][j][x];
            for (int k = 1; k <= i; ++k)
            {
                for (int l = k + 1; l <= i; ++l)
                {
                    if (!g[vc[k]][vc[l]])
                    {
                        // Hash the merged state's size sequence. The merged size is inserted before
                        // the first remaining size strictly larger than it (or appended at the end
                        // while `ok` is still true), matching get_id's non-decreasing order.
                        ull new_hash_value = 0;
                        bool ok = true;
                        for (int w = 1; w <= i; ++w)
                        {
                            if (w != k && w != l)
                            {
                                if (ok && vc[w] > vc[k] + vc[l])
                                {
                                    ok = false;
                                    new_hash_value = new_hash_value * base + vc[k] + vc[l];
                                }
                                new_hash_value = new_hash_value * base + vc[w];
                            }
                        }
                        if (ok) new_hash_value = new_hash_value * base + vc[k] + vc[l];
                        g[vc[k]][vc[l]] = map[i - 1][new_hash_value];
                    }
                    // Transition into state g[vc[k]][vc[l]]; the merging edge can join any of the vc[k] * vc[l] vertex pairs.
                    f[i - 1][g[vc[k]][vc[l]]] += f[i][j] * connect_nontree_edge_factor % mod * vc[k] * vc[l] % mod;
                    f[i - 1][g[vc[k]][vc[l]]] %= mod; 
                }
            }
        }
    }
    // The remaining C(n, 2) - a[1] edges can be added in any order without affecting the answer.
    // NOTE(review): Comb(n, 2) is reduced mod `mod` but used as an array index; this is correct
    // only because n < T keeps it far below `mod`.
    return f[1][1] * fact[Comb(n, 2) - a[1]] % mod;
}

// Buffered single-character reader over stdin, refilled in 1,000,000-byte chunks;
// returns EOF once input is exhausted.
inline char nc()
{
    static char chunk[1000000], *cur = chunk, *tail = chunk;
    if (cur == tail)
    {
        cur = chunk;
        tail = chunk + fread(chunk, 1, 1000000, stdin);
        if (cur == tail) return EOF;
    }
    return *cur++;
}

// Reads the next non-negative decimal integer via nc(), first skipping every non-digit
// character (a leading '-' is skipped, not honored).
inline int read()
{
    char ch = nc();
    while (ch < '0' || ch > '9') ch = nc();
    int value = 0;
    while (ch >= '0' && ch <= '9')
    {
        value = value * 10 + ch - '0';
        ch = nc();
    }
    return value;
}

int main()
{
    init();
    n = read();
    for (int i = 1; i < n; ++i) a[n - i] = read();
    for (int i = 1; i <= n; ++i) get_id(i, n, 1, 1);
    printf("%lld", dp());
    return 0;
}

所以是一个套路题、一个休闲的数据结构题和一个提高组 dp。

评论

DoorKickers
@mike
  • 2020-02-16 09:06:14
  • Reply

发表评论

可以用@mike来提到mike这个用户,mike会被高亮显示。如果你真的想打“@”这个字符,请用“@@”。