import os
import sys

import numpy as np
def solve(inp):
    """Return the sum of tree distances over all pairs of same-label vertices.

    ``inp`` is the whole input as a flat int64 array: ``n``, then ``n - 1``
    edges given as 1-indexed ``u v`` pairs, then the ``n`` vertex labels.
    The edges are assumed to form a tree; vertex 0 (input vertex 1) is used
    as the DFS root.

    Vertices are grouped by label.  A label carried by more than ``thr``
    vertices is handled by a single O(n) bottom-up DP; every other label sums
    O(1) LCA-based distances over all of its vertex pairs (sparse-table range
    minimum over the Euler tour).  Small labels therefore cost at most
    ``thr * n`` queries in total, and there are at most ``n / thr`` large
    labels, each costing O(n).

    Kept within the Python subset numba compiles: an iterative DFS,
    homogeneous lists, and inner functions that are only called locally.
    """

    def bit_length(n):
        # Number of bits needed to represent a non-negative n (0 -> 0).
        x = 0
        while n > 0:
            n >>= 1
            x += 1
        return x

    n = inp[0]
    # Edge endpoints, converted to 0-indexed vertex ids.
    uuu = inp[1:n * 2 - 1:2] - 1
    vvv = inp[2:n * 2 - 1:2] - 1
    aaa = inp[n * 2 - 1:]

    # An empty list seeded from an int64 value, presumably so numba can infer
    # the element type of the (initially empty) adjacency lists.
    int_list = [n]
    int_list.clear()
    links = [int_list.copy() for _ in range(n)]
    for i in range(n - 1):
        links[uuu[i]].append(vvv[i])
        links[vvv[i]].append(uuu[i])

    # Iterative DFS from vertex 0.  A vertex is written to the Euler tour
    # every time it is on top of the stack -- on entry and again after each
    # child returns -- so the tour has exactly 2n - 1 entries and the
    # shallowest vertex between two first occurrences is their LCA.
    preorder = np.zeros(n, np.int64)
    euler_tour = np.zeros(n * 2 - 1, np.int64)
    euler_tour_first_indices = np.zeros(n, np.int64)
    depth = np.zeros(n, np.int64)
    progress = np.zeros(n, np.int64)  # index of the next child to descend into
    poi = 0
    eti = 0
    stack = [0]
    while stack:
        u = stack[-1]
        if progress[u] == 0:
            preorder[poi] = u
            poi += 1
            euler_tour_first_indices[u] = eti
        euler_tour[eti] = u
        eti += 1
        if progress[u] >= len(links[u]):
            stack.pop()
            continue
        v = links[u][progress[u]]
        progress[u] += 1
        # Drop the back edge so that links[v] ends up holding only v's children.
        links[v].remove(u)
        depth[v] = depth[u] + 1
        stack.append(v)

    # Sparse table over tour depths: row k holds the minimum depth of each
    # window of 2**k consecutive tour entries starting at column j.
    tour_len = n * 2 - 1
    log_m = bit_length(tour_len - 1) + 1
    sparse_table = np.zeros((log_m, tour_len), np.int64)
    sparse_table[0] = depth[euler_tour]
    for k in range(1, log_m):
        half = 1 << (k - 1)
        for j in range(tour_len - half * 2 + 1):
            sparse_table[k, j] = min(sparse_table[k - 1, j],
                                     sparse_table[k - 1, j + half])

    # Reverse preorder visits every child before its parent.
    dp_order = preorder[::-1]

    # Group vertices by label: aaa_indices[aaa_counter[label]] lists the
    # vertices carrying that label, in increasing order.
    aaa_counter = {}
    aaa_indices = []
    for i in range(n):
        a = aaa[i]
        if a in aaa_counter:
            aaa_indices[aaa_counter[a]].append(i)
        else:
            aaa_counter[a] = len(aaa_indices)
            aaa_indices.append([i])

    def solve_by_dp(a):
        # Sum of distances over all pairs of vertices labelled `a`, in one
        # bottom-up pass.  For every vertex u:
        #   sub_cnt[u]  = number of `a`-vertices in u's subtree,
        #   sub_dist[u] = sum of their distances to u.
        ans = 0
        sub_dist = np.zeros(n, np.int64)
        sub_cnt = np.zeros(n, np.int64)
        for u in dp_order:
            x1 = 0  # distance sum from u to `a`-vertices in children seen so far
            x2 = 0  # number of those vertices
            for v in links[u]:
                y2 = sub_cnt[v]
                y1 = sub_dist[v] + y2  # re-measured from u: one edge further
                # Pairs with one end under v and the other under an earlier
                # child meet at u: their distance is (dist to u) + (dist to u).
                ans += y1 * x2 + y2 * x1
                x1 += y1
                x2 += y2
            if aaa[u] == a:
                ans += x1  # pairs (u, w) for every `a`-vertex w below u
                x2 += 1
            sub_dist[u] = x1
            sub_cnt[u] = x2
        return ans

    def solve_by_lca(u, v):
        # Distance between distinct vertices u and v:
        # depth[u] + depth[v] - 2 * depth(lca), with the LCA depth taken as
        # the minimum tour depth between their first occurrences.
        l = euler_tour_first_indices[u]
        r = euler_tour_first_indices[v]
        if l > r:
            l, r = r, l
        r += 1  # half-open window [l, r)
        d = r - l  # at least 2 because u != v
        # 2**k < d <= 2**(k+1), so two windows of 2**k entries cover [l, r).
        k = bit_length(d - 1) - 1
        k2 = 1 << k
        lca_depth = min(sparse_table[k, l], sparse_table[k, r - k2])
        return depth[u] + depth[v] - 2 * lca_depth

    # Labels with more vertices than this use the O(n) DP; the rest use
    # pairwise LCA queries.
    thr = 450
    ans = 0
    for a, idx in aaa_counter.items():
        lst = aaa_indices[idx]
        size = len(lst)
        if size == 1:
            continue
        if size > thr:
            ans += solve_by_dp(a)
        else:
            for i in range(size):
                u = lst[i]
                for j in range(i + 1, size):
                    ans += solve_by_lca(u, lst[j])
    return ans
# numba type signature of `solve`: a single 1-D int64 array argument.
SIGNATURE = '(i8[:],)'

# When the last command-line argument is 'ONLINE_JUDGE', compile `solve`
# ahead of time into the extension module `my_module` and stop.
if sys.argv[-1] == 'ONLINE_JUDGE':
    from numba.pycc import CC
    cc = CC('my_module')
    cc.export('solve', SIGNATURE)(solve)
    cc.compile()
    # sys.exit() instead of the site-injected exit(), which does not exist
    # when Python runs without the site module (-S) or is embedded.
    sys.exit()

if os.name == 'posix':
    # Presumably the module built by the 'ONLINE_JUDGE' run above — assumes
    # that step has already been executed in this environment.
    from my_module import solve
else:
    # Elsewhere, JIT-compile eagerly for the explicit signature, caching the
    # compiled code on disk.
    from numba import njit
    solve = njit(SIGNATURE, cache=True)(solve)
    print('compiled', file=sys.stderr)

# Read all of stdin as whitespace-separated int64 values.
inp = np.fromstring(sys.stdin.read(), dtype=np.int64, sep=' ')
ans = solve(inp)
print(ans)