题解 P5666 【树的重心【民间数据】】
Kinandra
2019-11-17 22:34:13
本做法洛谷上可以通过, 发篇题解攒攒人品.
标签:主席树, dfs序, 树状数组.
容易想到对每个点分别求贡献(即求每个点割去那些边会成为重心), 我们选择任意一个点(就是 $1$号节点辣!)将其作为根节点.
模仿题目中的描述, 定义 $u$ 的 **分裂子树** 表示在树中删去 $u$ 及与它关联的边后, 分裂出的子树, 那么一条边 $e$ 对一个点 $u$ 有贡献, 当且仅当割去 $e$ 后, 考虑 $u$ 所在的一部分(记这部分的大小为 $tmp$, 记另一部分为**割去的部分**), $u$ 最大的**分裂子树** 的大小不大于 $\lfloor
\frac{tmp}2\rfloor$.
#### Part 1
首先我们先考虑**根结点**的贡献, 我们先求出**根节点**的每一棵儿子的子树的大小 $size$, 取 $size$ 最大的儿子记为 $mx$.
显然如果我们割掉的边不在 $mx$ 的子树内, $u$ 最大的**分裂子树** 的大小不会改变, 那么这条边有贡献需要满足$size_{mx}\leqslant \lfloor\frac{n-t}{2}\rfloor$, 其中 $t$ 表示**割去的部分** 的大小, 稍加变形后得到 $t\leqslant n-2\times size_{mx}$.
考虑割掉的边在 $mx$ 的子树内, $u$ 的最大 **分裂子树** 可能会变成原来次大的(记为 $mx'$ ), 那么我们需要满足 $size_{mx}-t\leqslant \lfloor\frac{n-t}{2}\rfloor, size_{mx'}\leqslant \lfloor\frac{n-t}{2}\rfloor$, 稍加变形后得到 $2\times size_{mx}-n\leqslant t\leqslant n-2\times size_{mx'}$.
所以我们需要知道 $t$ 取每个值的方案数, 显然 $t$ 的一个取值对应了一个节点子树的大小, 我们对每个**根**的儿子求出其子树内 $t$ 的取值情况, 在分别求出合法的 $t$ 的方案就好了.
复杂度 $\mathcal O(n)$, 事实上枚举每个点做根求一遍答案可以做到 $\mathcal O(n^2)$.
#### Part 2
接下来考虑其他点的贡献, 发现其他点与根节点唯一的区别就是具有 **外子树** (指断掉与父亲节点的连边的那棵子树), 那么该如何维护外子树的 $t$ 的取值呢?
发现我们可以很容易地在 $\mathcal O(n\log n)$ 的时间内通过 线段树合并, 主席树+dfs序 等方法来维护 **内子树** (与**外子树**相对)的 $t$ 的取值, 这里不赘述, 那么是否可以维护**整棵树**对于某个点的 $t$ 的取值呢? 这样减去**内子树**的部分, 就可以得到**外子树**的 $t$ 的取值了.
事实上这是可以维护的, 首先整棵树对于**根**结点的 $t$ 取值就是处根节点外所有点的子树大小. 考虑我们已知整棵树对点 $u$ 的情况, 如何求 $u$ 的一个儿子 $v$ 的情况呢? 发现 $u, v$ 断去某条边的 $t$ 不同当且仅当这条边是 $(u,v)$, 所以 $v$ 的 $t$ 的取值比 $u$ 多了一个 $n-sz_v$, 少了一个 $sz_v$. 于是我们可以边dfs边维护整棵树对于某个点的 $t$ 的取值了, 这里要用到树状数组/线段树, 复杂度是 $\mathcal O(n\log n)$ 的.
时间复杂度 $\mathcal O(n\log n)$ , 空间复杂度 $\mathcal O(n\log n)$ .
~~代码不敢贴, 怕被hack~~
upd, 把代码补上, 不然可能看不太懂哈.
```cpp
#include <bits/stdc++.h>
using namespace std;
int read();
int n;
int hd[300005], nx[600005], to[600005], cnt;
void add(int f, int t) { nx[++cnt] = hd[f], hd[f] = cnt, to[cnt] = t; }
struct Bit {
int ts[300005];
void add(int p, int v) {
for (; p <= n; p += p & (-p)) ts[p] += v;
}
int qry(int l, int r) {
if (l > r) return 0;
l--;
int rt = 0;
while (l != r) {
if (l > r) rt -= ts[l], l -= l & (-l);
if (r > l) rt += ts[r], r -= r & (-r);
}
return rt;
}
} bit;
int sz[300005], mx[300005][2];
struct P {
int u, l, r;
} w[600005];
int root[300005];
struct Pseg {
int idcnt;
int ls[10000007], rs[10000007], ts[10000007];
void add(int &x, int k, int l, int r, int p) {
ts[x = ++idcnt] = ts[k] + 1, ls[x] = ls[k], rs[x] = rs[k];
if (l == r) return;
int mid = l + r >> 1;
(p <= mid) ? add(ls[x], ls[k], l, mid, p)
: add(rs[x], rs[k], mid + 1, r, p);
}
int qry(int k1, int k2, int l, int r, int st, int en) {
if (st > r || en < l) return 0;
if (st <= l && en >= r) return ts[k2] - ts[k1];
int mid = l + r >> 1;
return qry(ls[k1], ls[k2], l, mid, st, en) + qry(rs[k1], rs[k2], mid + 1, r, st, en);
}
} seg;
long long res;
int l[300005], r[300005], L[300005], R[300005];
int pre[300005], pst[300005], frt[300005], dfn;
int fa[300005];
void dfs1(int u) {
sz[u] = 1, root[pre[u] = ++dfn] = 0, frt[dfn] = u, mx[u][0] = mx[u][1] = 0;
int tmx = 0, tci = 0;
for (int i = hd[u], v; i; i = nx[i]) {
if ((v = to[i]) == fa[u]) continue;
fa[v] = u, dfs1(v), sz[u] += sz[v];
if (sz[mx[u][1]] < sz[v]) mx[u][1] = v;
if (sz[mx[u][0]] < sz[mx[u][1]]) swap(mx[u][0], mx[u][1]);
}
if (fa[u]) bit.add(sz[u], 1);
tmx = sz[mx[u][0]], tci = sz[mx[u][1]];
if (sz[mx[u][1]] < n - sz[u]) {
mx[u][1] = fa[u], tci = n - sz[u];
if (sz[mx[u][0]] < n - sz[u])
swap(mx[u][0], mx[u][1]), swap(tmx, tci);
}
L[u] = 1, R[u] = n - tmx * 2;
l[u] = max(1, 2 * tmx - n), r[u] = n - 2 * tci;
pst[u] = dfn;
}
void dfs2(int u) {
res += 1ll * u * ((mx[u][0] == fa[u]) ?
bit.qry(l[u], r[u]) : bit.qry(L[u], R[u]));
for (int i = hd[u]; i; i = nx[i])
if (to[i] != fa[u]) {
bit.add(sz[to[i]], -1), bit.add(n - sz[to[i]], 1);
dfs2(to[i]);
bit.add(sz[to[i]], 1), bit.add(n - sz[to[i]], -1);
}
}
int main() {
int T = read();
while (T--) {
n = read(), cnt = 0, res = 0, dfn = 0, seg.idcnt = 0;
for (int i = 1; i <= n; ++i) hd[i] = 0;
for (int i = 1, u, v; i < n; ++i)
u = read(), v = read(), add(u, v), add(v, u);
for (int i = 1; i <= n; ++i) bit.ts[i] = 0;
dfs1(1), dfs2(1);
for (int i = 1; i <= n; ++i)
seg.add(root[i], root[i - 1], 1, n, sz[frt[i]]);
for (int u = 1; u <= n; ++u) {
if (mx[u][0] != fa[u]) {
for (int i = hd[u], v; i; i = nx[i]) {
if ((v = to[i]) == mx[u][0]) {
res -= 1ll * u *
seg.qry(root[pre[v] - 1], root[pst[v]], 1, n, L[u], R[u]);
res += 1ll * u *
seg.qry(root[pre[v] - 1], root[pst[v]], 1, n, l[u], r[u]);
}
}
} else {
res -= 1ll * u *
seg.qry(root[pre[u]], root[pst[u]], 1, n, l[u], r[u]);
res += 1ll * u *
seg.qry(root[pre[u]], root[pst[u]], 1, n, L[u], R[u]);
}
}
printf("%lld\n", res);
}
return 0;
}
int read() {
int x = 0, f = 1;
char c = getchar();
while (!isdigit(c)) f = (c == '-') ? -1 : f, c = getchar();
while (isdigit(c)) x = x * 10 + c - '0', c = getchar();
return x * f;
}
```