import os
import sys
import numpy as np
def solve(inp):
    """Sum the tree distance over every pair of vertices carrying the same value.

    ``inp`` is a flat integer array laid out as
    ``[n, u_1, v_1, ..., u_{n-1}, v_{n-1}, a_1, ..., a_n]``: the vertex count,
    then ``n - 1`` edges (endpoints are shifted down by one to become 0-based
    indices), then one value per vertex.

    Values that occur on at most ``thr`` vertices are handled pair by pair
    with an O(1) LCA query on an Euler-tour sparse table; values that occur
    more often are handled with one O(n) tree DP each.

    The helpers are nested closures so the whole routine stays one
    self-contained function; the file compiles ``solve`` with numba.

    Returns the total distance as an integer.
    """

    def bit_length(x):
        # Number of binary digits of x (0 for x <= 0).
        bits = 0
        while x > 0:
            x >>= 1
            bits += 1
        return bits

    n = inp[0]
    uuu = inp[1:n * 2 - 1:2] - 1
    vvv = inp[2:n * 2 - 1:2] - 1
    aaa = inp[n * 2 - 1:]

    # Seed a list with an integer and clear it: this yields an empty list
    # whose element type is already known (presumably for numba's typing).
    int_list = [n]
    int_list.clear()
    links = [int_list.copy() for _ in range(n)]
    for i in range(n - 1):
        links[uuu[i]].append(vvv[i])
        links[vvv[i]].append(uuu[i])

    # Iterative DFS from vertex 0: records the preorder, the Euler tour,
    # each vertex's first position in the tour, and each vertex's depth.
    preorder = np.zeros(n, np.int64)
    euler_tour = np.zeros(n * 2 - 1, np.int64)
    euler_tour_first_indices = np.zeros(n, np.int64)
    poi = 0
    eti = 0
    depth = np.zeros(n, np.int64)
    progress = np.zeros(n, np.int64)  # next child index to explore, per vertex
    q = [0]
    while q:
        u = q[-1]
        if progress[u] == 0:
            # First arrival at u.
            preorder[poi] = u
            poi += 1
            euler_tour_first_indices[u] = eti
        # Every arrival (first visit or return from a child) is logged.
        euler_tour[eti] = u
        eti += 1
        if progress[u] >= len(links[u]):
            q.pop()
            continue
        v = links[u][progress[u]]
        progress[u] += 1
        # The DP below runs many times, so drop the edge back to the parent
        # once here; afterwards links[x] lists the children of x only.
        links[v].remove(u)
        depth[v] = depth[u] + 1
        q.append(v)

    # Sparse table of range-minimum depth over the Euler tour, used for LCA.
    tour_len = n * 2 - 1
    depth_by_euler_tour = np.zeros_like(euler_tour, np.int64)
    for i in range(tour_len):
        depth_by_euler_tour[i] = depth[euler_tour[i]]
    log_m = bit_length(tour_len - 1) + 1
    lca_sparse_table = np.zeros((log_m, tour_len), np.int64)
    lca_sparse_table[0] = depth_by_euler_tour
    for i in range(1, log_m):
        width = 1 << (i - 1)
        # Row i, column j holds the minimum over tour[j : j + 2 * width].
        for j in range(tour_len - width * 2 + 1):
            lca_sparse_table[i, j] = min(lca_sparse_table[i - 1, j],
                                         lca_sparse_table[i - 1, j + width])

    # Reverse preorder visits every child before its parent.
    dp_order = preorder[::-1]

    # Group vertex indices by value: aaa_counter maps a value to its slot in
    # aaa_indices, which holds the list of vertices carrying that value.
    aaa_counter = {}
    aaa_indices = []
    for i in range(n):
        a = aaa[i]
        if a in aaa_counter:
            idx = aaa_counter[a]
            aaa_indices[idx].append(i)
        else:
            aaa_counter[a] = len(aaa_indices)
            aaa_indices.append([i])

    def solve_by_dp(a):
        # Sum of pairwise distances among vertices whose value equals a,
        # computed with a single bottom-up pass over the rooted tree.
        total = 0
        dp1 = np.zeros(n, np.int64)  # sum of distances from u to matches in its subtree
        dp2 = np.zeros(n, np.int64)  # number of matches in the subtree of u
        for i in range(n):
            u = dp_order[i]
            x1 = 0
            x2 = 0
            for v in links[u]:
                y2 = dp2[v]
                y1 = dp1[v] + y2  # one extra edge (u, v) per match below v
                # Pairs with one endpoint below v and one below an earlier child.
                total += y1 * x2 + y2 * x1
                x1 += y1
                x2 += y2
            if aaa[u] == a:
                # Pair u itself with every match below it.
                total += x1
                x2 += 1
            dp1[u] = x1
            dp2[u] = x2
        return total

    def solve_by_lca(u, v):
        # Distance between u and v from the depth of their lowest common
        # ancestor, i.e. the minimum depth on the tour between their first
        # occurrences.
        lo = euler_tour_first_indices[u]
        hi = euler_tour_first_indices[v]
        if lo > hi:
            lo, hi = hi, lo
        hi += 1  # half-open range [lo, hi)
        # Largest k with 2**k <= hi - lo; two blocks of that size cover the
        # range, and the formula is valid for a length-1 range (u == v) too.
        k = bit_length(hi - lo) - 1
        lca_depth = min(lca_sparse_table[k, lo],
                        lca_sparse_table[k, hi - (1 << k)])
        return depth[u] + depth[v] - 2 * lca_depth

    # Groups larger than thr get the O(n) DP; smaller ones enumerate pairs.
    thr = 450
    ans = 0
    for a, idx in aaa_counter.items():
        lst = aaa_indices[idx]
        cnt = len(lst)
        if cnt == 1:
            continue
        elif cnt > thr:
            ans += solve_by_dp(a)
        else:
            for i in range(cnt):
                u = lst[i]
                for j in range(i + 1, cnt):
                    v = lst[j]
                    ans += solve_by_lca(u, v)
    return ans
# numba signature string passed to both the AOT export and the JIT fallback:
# a single argument that is a 1-D int64 array.
SIGNATURE = '(i8[:],)'

if sys.argv[-1] == 'ONLINE_JUDGE':
    # Build step: compile solve ahead of time into the extension module
    # `my_module`, then stop without reading any input.
    from numba.pycc import CC

    cc = CC('my_module')
    cc.export('solve', SIGNATURE)(solve)
    cc.compile()
    # sys.exit rather than the bare exit(): the latter is injected by the
    # `site` module for interactive use and is not guaranteed to exist.
    sys.exit()

if os.name == 'posix':
    # Use the module produced by the build step above.
    # noinspection PyUnresolvedReferences
    from my_module import solve
else:
    # No prebuilt module on this platform: JIT-compile (with on-disk cache).
    from numba import njit

    solve = njit(SIGNATURE, cache=True)(solve)
    print('compiled', file=sys.stderr)

# Read all whitespace-separated integers from stdin as one int64 array.
inp = np.fromstring(sys.stdin.read(), dtype=np.int64, sep=' ')
ans = solve(inp)
print(ans)