题解 P5666 【树的重心【民间数据】】

Kinandra

2019-11-17 22:34:13

Solution

本做法洛谷上可以通过, 发篇题解攒攒人品. 标签:主席树, dfs序, 树状数组. 容易想到对每个点分别求贡献(即求每个点割去那些边会成为重心), 我们选择任意一个点(就是 $1$号节点辣!)将其作为根节点. 模仿题目中的描述, 定义 $u$ 的 **分裂子树** 表示在树中删去 $u$ 及与它关联的边后, 分裂出的子树, 那么一条边 $e$ 对一个点 $u$ 有贡献, 当且仅当割去 $e$ 后, 考虑 $u$ 所在的一部分(记这部分的大小为 $tmp$, 记另一部分为**割去的部分**), $u$ 最大的**分裂子树** 的大小不大于 $\lfloor \frac{tmp}2\rfloor$. #### Part 1 首先我们先考虑**根结点**的贡献, 我们先求出**根节点**的每一棵儿子的子树的大小 $size$, 取 $size$ 最大的儿子记为 $mx$. 显然如果我们割掉的边不在 $mx$ 的子树内, $u$ 最大的**分裂子树** 的大小不会改变, 那么这条边有贡献需要满足$size_{mx}\leqslant \lfloor\frac{n-t}{2}\rfloor$, 其中 $t$ 表示**割去的部分** 的大小, 稍加变形后得到 $t\leqslant n-2\times size_{mx}$. 考虑割掉的边在 $mx$ 的子树内, $u$ 的最大 **分裂子树** 可能会变成原来次大的(记为 $mx'$ ), 那么我们需要满足 $size_{mx}-t\leqslant \lfloor\frac{n-t}{2}\rfloor, size_{mx'}\leqslant \lfloor\frac{n-t}{2}\rfloor$, 稍加变形后得到 $2\times size_{mx}-n\leqslant t\leqslant n-2\times size_{mx'}$. 所以我们需要知道 $t$ 取每个值的方案数, 显然 $t$ 的一个取值对应了一个节点子树的大小, 我们对每个**根**的儿子求出其子树内 $t$ 的取值情况, 在分别求出合法的 $t$ 的方案就好了. 复杂度 $\mathcal O(n)$, 事实上枚举每个点做根求一遍答案可以做到 $\mathcal O(n^2)$. #### Part 2 接下来考虑其他点的贡献, 发现其他点与根节点唯一的区别就是具有 **外子树** (指断掉与父亲节点的连边的那棵子树), 那么该如何维护外子树的 $t$ 的取值呢? 发现我们可以很容易地在 $\mathcal O(n\log n)$ 的时间内通过 线段树合并, 主席树+dfs序 等方法来维护 **内子树** (与**外子树**相对)的 $t$ 的取值, 这里不赘述, 那么是否可以维护**整棵树**对于某个点的 $t$ 的取值呢? 这样减去**内子树**的部分, 就可以得到**外子树**的 $t$ 的取值了. 事实上这是可以维护的, 首先整棵树对于**根**结点的 $t$ 取值就是处根节点外所有点的子树大小. 考虑我们已知整棵树对点 $u$ 的情况, 如何求 $u$ 的一个儿子 $v$ 的情况呢? 发现 $u, v$ 断去某条边的 $t$ 不同当且仅当这条边是 $(u,v)$, 所以 $v$ 的 $t$ 的取值比 $u$ 多了一个 $n-sz_v$, 少了一个 $sz_v$. 于是我们可以边dfs边维护整棵树对于某个点的 $t$ 的取值了, 这里要用到树状数组/线段树, 复杂度是 $\mathcal O(n\log n)$ 的. 时间复杂度 $\mathcal O(n\log n)$ , 空间复杂度 $\mathcal O(n\log n)$ . ~~代码不敢贴, 怕被hack~~ upd, 把代码补上, 不然可能看不太懂哈. ```cpp #include <bits/stdc++.h> using namespace std; int read(); int n; int hd[300005], nx[600005], to[600005], cnt; void add(int f, int t) { nx[++cnt] = hd[f], hd[f] = cnt, to[cnt] = t; } struct Bit { int ts[300005]; void add(int p, int v) { for (; p <= n; p += p & (-p)) ts[p] += v; } int qry(int l, int r) { if (l > r) return 0; l--; int rt = 0; while (l != r) { if (l > r) rt -= ts[l], l -= l & (-l); if (r > l) rt += ts[r], r -= r & (-r); } return rt; } } bit; int sz[300005], mx[300005][2]; struct P { int u, l, r; } w[600005]; int root[300005]; struct Pseg { int idcnt; int ls[10000007], rs[10000007], ts[10000007]; void add(int &x, int k, int l, int r, int p) { ts[x = ++idcnt] = ts[k] + 1, ls[x] = ls[k], rs[x] = rs[k]; if (l == r) return; int mid = l + r >> 1; (p <= mid) ? add(ls[x], ls[k], l, mid, p) : add(rs[x], rs[k], mid + 1, r, p); } int qry(int k1, int k2, int l, int r, int st, int en) { if (st > r || en < l) return 0; if (st <= l && en >= r) return ts[k2] - ts[k1]; int mid = l + r >> 1; return qry(ls[k1], ls[k2], l, mid, st, en) + qry(rs[k1], rs[k2], mid + 1, r, st, en); } } seg; long long res; int l[300005], r[300005], L[300005], R[300005]; int pre[300005], pst[300005], frt[300005], dfn; int fa[300005]; void dfs1(int u) { sz[u] = 1, root[pre[u] = ++dfn] = 0, frt[dfn] = u, mx[u][0] = mx[u][1] = 0; int tmx = 0, tci = 0; for (int i = hd[u], v; i; i = nx[i]) { if ((v = to[i]) == fa[u]) continue; fa[v] = u, dfs1(v), sz[u] += sz[v]; if (sz[mx[u][1]] < sz[v]) mx[u][1] = v; if (sz[mx[u][0]] < sz[mx[u][1]]) swap(mx[u][0], mx[u][1]); } if (fa[u]) bit.add(sz[u], 1); tmx = sz[mx[u][0]], tci = sz[mx[u][1]]; if (sz[mx[u][1]] < n - sz[u]) { mx[u][1] = fa[u], tci = n - sz[u]; if (sz[mx[u][0]] < n - sz[u]) swap(mx[u][0], mx[u][1]), swap(tmx, tci); } L[u] = 1, R[u] = n - tmx * 2; l[u] = max(1, 2 * tmx - n), r[u] = n - 2 * tci; pst[u] = dfn; } void dfs2(int u) { res += 1ll * u * ((mx[u][0] == fa[u]) ? bit.qry(l[u], r[u]) : bit.qry(L[u], R[u])); for (int i = hd[u]; i; i = nx[i]) if (to[i] != fa[u]) { bit.add(sz[to[i]], -1), bit.add(n - sz[to[i]], 1); dfs2(to[i]); bit.add(sz[to[i]], 1), bit.add(n - sz[to[i]], -1); } } int main() { int T = read(); while (T--) { n = read(), cnt = 0, res = 0, dfn = 0, seg.idcnt = 0; for (int i = 1; i <= n; ++i) hd[i] = 0; for (int i = 1, u, v; i < n; ++i) u = read(), v = read(), add(u, v), add(v, u); for (int i = 1; i <= n; ++i) bit.ts[i] = 0; dfs1(1), dfs2(1); for (int i = 1; i <= n; ++i) seg.add(root[i], root[i - 1], 1, n, sz[frt[i]]); for (int u = 1; u <= n; ++u) { if (mx[u][0] != fa[u]) { for (int i = hd[u], v; i; i = nx[i]) { if ((v = to[i]) == mx[u][0]) { res -= 1ll * u * seg.qry(root[pre[v] - 1], root[pst[v]], 1, n, L[u], R[u]); res += 1ll * u * seg.qry(root[pre[v] - 1], root[pst[v]], 1, n, l[u], r[u]); } } } else { res -= 1ll * u * seg.qry(root[pre[u]], root[pst[u]], 1, n, l[u], r[u]); res += 1ll * u * seg.qry(root[pre[u]], root[pst[u]], 1, n, L[u], R[u]); } } printf("%lld\n", res); } return 0; } int read() { int x = 0, f = 1; char c = getchar(); while (!isdigit(c)) f = (c == '-') ? -1 : f, c = getchar(); while (isdigit(c)) x = x * 10 + c - '0', c = getchar(); return x * f; } ```