UNIQUE VISION Programming Contest 2024 Summer (AtCoder Beginner Contest 359) G問題メモ
G - Sum of Tree Distance
問題
$N$ 頂点の木があり、頂点 $v$ には色 $A_v$ が塗られている
$d(u,v)$ を以下で定義する
$\displaystyle \sum_{u=1}^{N-1} \sum_{v=u+1}^{N} d(u,v)$ を求めよ
$2 \le N \le 2 \times 10^5$
解法1 - 平方根で場合分け
同じ色の頂点が何個あるかで色分けする。
閾値 $B$ を決める。
$B$ 個未満の色
同じ色が $B$ 個未満の場合、愚直に全ペアの距離を1つ1つ求める。
オイラーツアー+区間最小値が取得できるSparseTable を前計算しておくことにより、2頂点間の距離は $O(1)$ で求められる。
(言語によってはbit_lengthを求める際に軽い $O(\log{N})$ がかかるが)
1つの色につき最大 $O(B^2)$、全体では $O(B^2 \dfrac{N}{B})$ かかる。
区間最小値をセグ木でやるとlogがつき、若干TLEが厳しくなる。
$B$ 個以上の色
同じ色が $B$ 個以上の頂点は、木DPする。
1つの色につき $O(N)$、全体では $O(N \dfrac{N}{B})$ かかる。
この2つのバランスが取れるのは、$B = \sqrt{N}$ の時で、全体 $O(N \sqrt{N})$ となる。
PyPyだと数個のテストケースでTLEが取れなかった。
Numbaに書き直すと余裕を持って通った。(約2000ms)
Python3
import os
import sys
import numpy as np
def solve(inp):
def bit_length(n):
x = 0
while n > 0:
n >>= 1
x += 1
return x
n = inp[0]
uuu = inp[1:n * 2 - 1:2] - 1
vvv = inp[2:n * 2 - 1:2] - 1
aaa = inp[n * 2 - 1:]
int_list = [n]
int_list.clear()
links = [int_list.copy() for _ in range(n)]
for i in range(n - 1):
links[uuu[i]].append(vvv[i])
links[vvv[i]].append(uuu[i])
# 行きがけ順を決定
preorder = np.zeros(n, np.int64)
euler_tour = np.zeros(n * 2 - 1, np.int64)
euler_tour_first_indices = np.zeros(n, np.int64)
poi = 0
eti = 0
depth = np.zeros(n, np.int64)
parents = np.full(n, -1, np.int64)
progress = np.zeros(n, np.int64)
q = [0]
while q:
u = q[-1]
if progress[u] == 0:
preorder[poi] = u
poi += 1
euler_tour_first_indices[u] = eti
euler_tour[eti] = u
eti += 1
if progress[u] >= len(links[u]):
q.pop()
continue
v = links[u][progress[u]]
progress[u] += 1
parents[v] = u
links[v].remove(u) # 何回もDPするので、親への辺は除いておく
depth[v] = depth[u] + 1
q.append(v)
# LCAを求めるためのSparseTableを構築
depth_by_euler_tour = np.zeros_like(euler_tour, np.int64)
m = n * 2 - 1
for i in range(m):
depth_by_euler_tour[i] = depth[euler_tour[i]]
log_m = bit_length(m - 1) + 1
lca_sparce_table = np.zeros((log_m, m), np.int64)
lca_sparce_table[0] = depth_by_euler_tour
for i in range(1, log_m):
width = 1 << (i - 1)
for j in range(m - width * 2 + 1):
lca_sparce_table[i, j] = min(lca_sparce_table[i - 1, j], lca_sparce_table[i - 1, j + width])
dp_order = preorder[::-1]
# a ごとに出現位置を整理
aaa_counter = {}
aaa_indices = []
for i in range(n):
a = aaa[i]
if a in aaa_counter:
idx = aaa_counter[a]
aaa_indices[idx].append(i)
else:
aaa_counter[a] = len(aaa_indices)
aaa_indices.append([i])
def solve_by_dp(a):
ans = 0
dp1 = np.zeros(n, np.int64) # 延べ距離
dp2 = np.zeros(n, np.int64) # 個数
for i in range(n):
u = dp_order[i]
x1 = 0
x2 = 0
for v in links[u]:
y2 = dp2[v]
y1 = dp1[v] + y2
ans += y1 * x2 + y2 * x1
x1 += y1
x2 += y2
if aaa[u] == a:
ans += x1
x2 += 1
dp1[u] = x1
dp2[u] = x2
return ans
def solve_by_lca(u, v):
l = euler_tour_first_indices[u]
r = euler_tour_first_indices[v]
if l > r:
l, r = r, l
r += 1
d = r - l
# if d == 1:
# lca_depth = lca_sparce_table[0, l]
k = bit_length(d - 1) - 1
k2 = 1 << k
lca_depth = min(lca_sparce_table[k, l], lca_sparce_table[k, r - k2])
result = depth[u] + depth[v] - 2 * lca_depth
return result
thr = 450
ans = 0
for a, idx in aaa_counter.items():
lst = aaa_indices[idx]
m = len(lst)
if m == 1:
continue
elif m > thr:
ans += solve_by_dp(a)
else:
for i in range(m):
u = lst[i]
for j in range(i + 1, m):
v = lst[j]
ans += solve_by_lca(u, v)
return ans
SIGNATURE = '(i8[:],)'
if sys.argv[-1] == 'ONLINE_JUDGE':
from numba.pycc import CC
cc = CC('my_module')
cc.export('solve', SIGNATURE)(solve)
cc.compile()
exit()
if os.name == 'posix':
# noinspection PyUnresolvedReferences
from my_module import solve
else:
from numba import njit
solve = njit(SIGNATURE, cache=True)(solve)
print('compiled', file=sys.stderr)
inp = np.fromstring(sys.stdin.read(), dtype=np.int64, sep=' ')
ans = solve(inp)
print(ans)
解法2 - 木DP+マージテク
距離計算の部分にちょっとした工夫が必要になるが、こちらの方が実装量は少ない。
ただ、ライングラフ+色の種類が多いテストケースに対して $O(N^2)$ の空間計算量が必要になるのでMLEの危険性がある?
(まぁ、定数倍が大きいわけではないし、MLEの制限はTLEと比べると緩いので大丈夫か)
1回の木DPで、全ての色について求める。
例えば解法1のDPでは「頂点 $v$ までの距離の総和」を値として持ったため、頂点を遡る毎に更新が必要になった。
しかし、マージテクをする場合は、小さい方を大きい方にマージする分、大きい方は更新が発生してはいけない。
「深さ」を値として持つことで、距離を求めることを可能にしつつ、更新が発生しないようにできる。
「子孫 $u$ から祖先 $v$ までの距離」=「$u$ の深さ - $v$ の深さ」
これで、$a$ の種類数が少ない方を多い方にマージしていくことで、時間計算量は $O(N \log{N})$ となる。
DPの値は、配列で持つと明確に $O(N^2)$ の空間を要するため、基本的には辞書で持つことになる。これに log がかかる言語では、$O(N \log^2{N})$ となる。
Python
def solve(n, links, aaa):
ans = 0
parents = [-1] * n
depths = [0] * n
counts = [{} for _ in range(n)]
depth_sum = [{} for _ in range(n)]
state = [0] * n
q = [0]
while q:
u = q[-1]
if state[u] == 0:
state[u] = 1
for v in links[u]:
if parents[u] == v:
continue
parents[v] = u
depths[v] = depths[u] + 1
q.append(v)
continue
q.pop()
du = depths[u]
cnt_u = counts[u]
sum_u = depth_sum[u]
for v in links[u]:
cnt_v = counts[v]
sum_v = depth_sum[v]
if len(cnt_u) < len(cnt_v):
cnt_u, cnt_v = cnt_v, cnt_u
sum_u, sum_v = sum_v, sum_u
for a, cv in cnt_v.items():
if a in cnt_u:
sv = sum_v[a]
cu = cnt_u[a]
su = sum_u[a]
ans += (sv - du * cv) * cu
ans += (su - du * cu) * cv
cnt_u[a] += cv
sum_u[a] += sv
else:
cnt_u[a] = cv
sum_u[a] = sum_v[a]
a = aaa[u]
if a in cnt_u:
ans += sum_u[a] - du * cnt_u[a]
cnt_u[a] += 1
sum_u[a] += du
else:
cnt_u[a] = 1
sum_u[a] = du
counts[u] = cnt_u
depth_sum[u] = sum_u
return ans
n = int(input())
links = [set() for _ in range(n)]
for _ in range(n - 1):
u, v = map(int, input().split())
u -= 1
v -= 1
links[u].add(v)
links[v].add(u)
aaa = list(map(int, input().split()))
ans = solve(n, links, aaa)
print(ans)
その他の解法