2つめは、黒/白それぞれで「自身の Light-Child の列で、先頭の色が黒/白であるようなものの、先頭から同じ色が続く限りの $V_i$ の合計値」を管理する。
列 $a$ のどこかを更新する際は、更新前の先頭の色 $C_{a0}^{before}$ と合計 $L_a^{before}$ を保持しておき、
更新後の先頭の色 $C_{a0}^{after}$ と合計 $L_a^{after}$ と比較する。
before と after が変化していたら親にその影響を伝えなければならないが、それは $C,L$ のbefore/afterを元に親の $X_i,Y_i$ を差分更新できる。
その更新によって、親の $V_i$ が変わりうり、親の列 $b$ の $L_b$ も変化しうる。
変更は連鎖的になるが、HLDのため最大 $O(\log{N})$ 回で収まる。
変化がなくなったら伝播を止めていい。
計算量は $O(N \log{N} + Q (\log{N})^2)$ で重くはあるが、
$(\log{N})^2$ の部分の各 $\log{N}$ は、一方はHLDの列数、もう一方は1列当たりの要素数のlog を表すため、
両方が同時に最大となることはなく、何分の1かの定数倍がかかることで、なんとか間に合う。
実際には、木は長さ $1~2$ の細々とした大量のパスに分かれる可能性もあり、
それらに対し全て セグ木や SortedSet を用意するのは却って遅くなる可能性がある。
短いパスは純粋に配列で管理して、イテレートで愚直に取得した方が速いと思い、実装を分けている。
from sortedcontainers import SortedList
class SegmentTreeInjectable:
"""
単位元生成関数 identity_factory と二項演算関数 func を外部注入するセグメント木
[生成]
SegmentTreeInjectable(n, identity_factory, func)
SegmentTreeInjectable.from_array(array, identity_factory, func) # 既存の配列より作成
[関数]
add(i, x) Aiにxを加算
update(i, x) Aiをxに書き換え
get_range(a, b) [a, b) の集約値を得る
get_all() 全ての集約値を得る
get_point(i) Aiを得る
leftmost(a, b, x, ev) [a, b) の範囲で、ev(x, Ai)=True となる最も左の i を得る(前提条件あり)
rightmost(a, b, x, ev) [a, b) の範囲で、ev(x, Ai)=True となる最も右の i を得る(前提条件あり)
debug_print() 深さ毎に整形して出力する
"""
def __init__(self, n, identity_factory, func):
n2 = 1 << (n - 1).bit_length()
# log = 64 - (n - 1).__ctlz__()
self.offset = n2
self.tree = [identity_factory() for _ in range(n2 << 1)]
self.func = func
self.idf = identity_factory
@classmethod
def from_array(cls, arr, identity_factory, func):
""" 既存の配列から生成 """
ins = cls(len(arr), identity_factory, func)
ins.tree[ins.offset:ins.offset + len(arr)] = arr
for i in range(ins.offset - 1, 0, -1):
l = i << 1
r = l + 1
ins.tree[i] = func(ins.tree[l], ins.tree[r])
return ins
def add(self, i, x):
"""
Aiにxを加算
:param i: index (0-indexed)
:param x: add value
"""
i += self.offset
self.tree[i] = self.func(self.tree[i], x)
self.__upstream(i)
def update(self, i, x):
"""
Aiの値をxに更新
:param i: index(0-indexed)
:param x: update value
"""
i += self.offset
self.tree[i] = x
self.__upstream(i)
def __upstream(self, i):
tree = self.tree
func = self.func
while i > 1:
i >>= 1
lch = i << 1
rch = lch | 1
tree[i] = func(tree[lch], tree[rch])
def get_range(self, a, b):
"""
[a, b)の値を得る
:param a: index(0-indexed)
:param b: index(0-indexed)
"""
tree = self.tree
func = self.func
result_l = self.idf()
result_r = self.idf()
l = a + self.offset
r = b + self.offset
while l < r:
if r & 1:
result_r = func(tree[r - 1], result_r)
if l & 1:
result_l = func(result_l, tree[l])
l += 1
l >>= 1
r >>= 1
return func(result_l, result_r)
def get_all(self):
return self.tree[1]
def get_point(self, i):
return self.tree[i + self.offset]
def leftmost(self, a, b, x, ev):
"""
[a, b) の範囲で、ev(x, 値) = True となる最初の index を得る。存在しない場合は-1。
使用できる条件:
[l, r) の集約値を y としたとき、ev(x, y)=True となることが、
l <= i < r 内に ev(x, Ai)=True となる要素があることと等しい。((func, ev) = (min,ge), (max,le) など)
"""
tree = self.tree
l = a + self.offset
r = b + self.offset
r_found = -1
while l < r:
if l & 1:
if ev(x, tree[l]):
return self._leftmost_sub(l, x, ev)
l += 1
if r & 1:
if ev(x, tree[r - 1]):
r_found = r - 1
l >>= 1
r >>= 1
if r_found == -1:
return -1
return self._leftmost_sub(r_found, x, ev)
def _leftmost_sub(self, i, x, ev):
"""
tree-index i が示す範囲で、ev(x, Aj)=True となる最も左のarray-index j を得る
(tree[i] が示す範囲には条件を満たすものが必ず存在する前提とする)
"""
tree = self.tree
while i < self.offset:
l = i << 1
if ev(x, tree[l]):
i = l
else:
i = l + 1
return i - self.offset
def rightmost(self, a, b, x, ev):
"""
[a, b) の範囲で、ev(x, 値) = True となる最後の index を得る。存在しない場合は-1。
使用できる条件:
[l, r) の集約値を y としたとき、ev(x, y)=True となることが、
l <= i < r 内に ev(x, Ai)=True となる要素があることと等しい。((func, ev) = (min,ge), (max,le) など)
"""
tree = self.tree
l = a + self.offset
r = b + self.offset
l_found = -1
while l < r:
if r & 1:
if ev(x, tree[r - 1]):
return self._rightmost_sub(r - 1, x, ev)
if l & 1:
if ev(x, tree[l]):
l_found = l
l += 1
l >>= 1
r >>= 1
if l_found == -1:
return -1
return self._rightmost_sub(l_found, x, ev)
def _rightmost_sub(self, i, x, ev):
"""
tree-index i が示す範囲で、ev(x, Aj)=True となる最も右のarray-index j を得る
(tree[i] が示す範囲には条件を満たすものが必ず存在する前提とする)
"""
tree = self.tree
while i < self.offset:
l = i << 1
if ev(x, tree[l + 1]):
i = l + 1
else:
i = l
return i - self.offset
def debug_print(self):
i = 1
while i <= self.offset:
print(self.tree[i:i * 2])
i <<= 1
def heavy_light_decomposition(
n: int,
links: list[list[int]],
root: int = 0
) -> tuple[list[list[int]], list[int], list[int]]:
"""
①--②--④--⑦--⑨ → [[1,2,4,7,9], [10], [3,5,8], [6]]
`--③--⑤-⑧ `-⑩
`--⑥
Heavy-path のリストに分解する。
Heavy-path の並びは、(最も重い子を最初に訪れるような)オイラーツアーの訪問順となる。
"""
parents = [-1] * n
weights = [-1] * n
q = [root]
while q:
u = q[-1]
if weights[u] == -1:
weights[u] = -2
for v in links[u]:
if parents[u] == v:
continue
parents[v] = u
q.append(v)
else:
q.pop()
weights[u] = 1 + sum(weights[v] for v in links[u] if v != parents[u])
q = [root]
progress = [0] * n
progress2 = [0] * n
current_path = []
result = [current_path]
while q:
u = q[-1]
if progress[u] == 0:
links[u].sort(key=weights.__getitem__, reverse=True)
current_path.append(u)
if progress[u] >= len(links[u]):
q.pop()
continue
v = links[u][progress[u]]
progress[u] += 1
if v == parents[u]:
continue
progress2[u] += 1
if progress2[u] >= 2:
current_path = []
result.append(current_path)
q.append(v)
return result, weights, parents
class Path:
"""
www: 各ノードの重み配列
ccc: 各ノードの色配列({0, 1} を取る)
のペアを管理し、以下の操作を提供するインターフェイス。
"""
def add(self, i, x):
"""
www[i] に x を加算。
親への伝播に必要な (old_color, old_sum_root, new_color, new_sum_root) を返す。
"""
raise NotImplementedError
def add_child(self, i, x_black, x_white):
"""
black[i] に x_black、white[i] に x_white を加算。
(old_color, old_sum_root, new_color, new_sum_root) を返す。
"""
raise NotImplementedError
def is_same(self, i):
""" ccc[0]~ccc[i] が全て同じ値か """
raise NotImplementedError
def sum_root(self):
""" ccc[0] と同じ色が連続する prefix の combined の総和 """
raise NotImplementedError
def sum_point(self, i):
""" i を含む同色連続区間の combined の総和 """
raise NotImplementedError
def switch(self, i):
"""
ccc[i] の値を反転。
(old_color, old_sum_root, new_color, new_sum_root) を返す。
"""
raise NotImplementedError
class SegTreePath(Path):
def __init__(self, www, ccc):
self.n = len(www)
self.www = list(www)
self.ccc = list(ccc)
self.black = [0] * self.n
self.white = [0] * self.n
# seg は「www[i] + child寄与」を保持する。
# 初期は black/white が 0 なので www と同じ。
self.seg = SegmentTreeInjectable.from_array(
list(www), lambda: 0, lambda a, b: a + b
)
# ccc[i] != ccc[i+1] となる境界 i を管理
self.switches = SortedList(
i for i in range(self.n - 1) if self.ccc[i] != self.ccc[i + 1]
)
def add(self, i, x):
old_c = self.ccc[0]
old_s = self.sum_root()
self.www[i] += x
self.seg.add(i, x)
return old_c, old_s, self.ccc[0], self.sum_root()
def add_child(self, i, x_black, x_white):
old_c = self.ccc[0]
old_s = self.sum_root()
self.black[i] += x_black
self.white[i] += x_white
# seg は active 側 (ccc[i]==1 なら black、0 なら white) のみ反映
delta = x_black if self.ccc[i] == 1 else x_white
if delta:
self.seg.add(i, delta)
return old_c, old_s, self.ccc[0], self.sum_root()
def is_same(self, i):
# [0, i) に境界が存在しなければ ccc[0..i] は全て同じ
return self.switches.bisect_left(i) == 0
def sum_root(self):
if not self.switches:
return self.seg.get_all()
j = self.switches[0]
return self.seg.get_range(0, j + 1)
def sum_point(self, i):
idx = self.switches.bisect_left(i)
l = 0 if idx == 0 else self.switches[idx - 1] + 1
r = self.n - 1 if idx == len(self.switches) else self.switches[idx]
return self.seg.get_range(l, r + 1)
def switch(self, i):
old_c = self.ccc[0]
old_s = self.sum_root()
self.ccc[i] ^= 1
ccc = self.ccc
# 色が変わったので seg[i] を再構築
# ccc[i]==0 → www[i] + white[i]、ccc[i]==1 → www[i] + black[i]
if ccc[i] == 0:
self.seg.update(i, self.www[i] + self.white[i])
else:
self.seg.update(i, self.www[i] + self.black[i])
sw = self.switches
for j in (i - 1, i):
if 0 <= j < self.n - 1:
if ccc[j] != ccc[j + 1]:
if j not in sw:
sw.add(j)
else:
if j in sw:
sw.remove(j)
return old_c, old_s, self.ccc[0], self.sum_root()
class ListPath(Path):
def __init__(self, www, ccc):
self.www = list(www)
self.ccc = list(ccc)
self.n = len(www)
self.black = [0] * self.n
self.white = [0] * self.n
# combined[i] = www[i] + (white[i] if ccc[i]==0 else black[i])
self.combined = list(www)
def _combined_at(self, i):
return self.www[i] + (self.white[i] if self.ccc[i] == 0 else self.black[i])
def add(self, i, x):
old_c = self.ccc[0]
old_s = self.sum_root()
self.www[i] += x
self.combined[i] += x
return old_c, old_s, self.ccc[0], self.sum_root()
def add_child(self, i, x_black, x_white):
old_c = self.ccc[0]
old_s = self.sum_root()
self.black[i] += x_black
self.white[i] += x_white
self.combined[i] += x_black if self.ccc[i] == 1 else x_white
return old_c, old_s, self.ccc[0], self.sum_root()
def is_same(self, i):
ccc = self.ccc
c = ccc[0]
for j in range(1, i + 1):
if ccc[j] != c:
return False
return True
def sum_root(self):
ccc = self.ccc
combined = self.combined
c = ccc[0]
s = 0
for j in range(self.n):
if ccc[j] != c:
break
s += combined[j]
return s
def sum_point(self, i):
ccc = self.ccc
combined = self.combined
c = ccc[i]
s = combined[i]
j = i - 1
while j >= 0 and ccc[j] == c:
s += combined[j]
j -= 1
j = i + 1
while j < self.n and ccc[j] == c:
s += combined[j]
j += 1
return s
def switch(self, i):
old_c = self.ccc[0]
old_s = self.sum_root()
self.ccc[i] ^= 1
if self.ccc[i] == 0:
self.combined[i] = self.www[i] + self.white[i]
else:
self.combined[i] = self.www[i] + self.black[i]
return old_c, old_s, self.ccc[0], self.sum_root()
def solve(n, q, www, ccc, edges, queries):
links = [[] for _ in range(n)]
for a, b in edges:
a -= 1
b -= 1
links[a].append(b)
links[b].append(a)
hld, weights, parents = heavy_light_decomposition(n, links)
SIZE_LIMIT = 100
m = len(hld)
paths: list[Path] = []
pos = [None] * n
for i in range(m):
path = hld[i]
path_www = []
path_ccc = []
for j, v in enumerate(path):
pos[v] = (i, j)
path_www.append(www[v])
path_ccc.append(ccc[v])
if len(path) > SIZE_LIMIT:
path_ins = SegTreePath(path_www, path_ccc)
else:
path_ins = ListPath(path_www, path_ccc)
paths.append(path_ins)
# 初期構築: 葉側のパスから順に sum_root を親の add_child に伝える
for i in range(m - 1, 0, -1):
v0 = hld[i][0]
s = paths[i].sum_root()
c = ccc[v0]
pi, pj = pos[parents[v0]]
if c == 1:
paths[pi].add_child(pj, s, 0)
else:
paths[pi].add_child(pj, 0, s)
# ボトムアップ順なので戻り値は無視してよい (paths[pi] は後で順に処理される)
def propagate(path_idx, diff):
""" Path 操作の戻り値 diff = (old_c, old_s, new_c, new_s) を親方向に伝播 """
while path_idx != 0:
old_c, old_s, new_c, new_s = diff
if old_c == new_c and old_s == new_s:
return
v0 = hld[path_idx][0]
pi, pj = pos[parents[v0]]
x_black = (new_s if new_c == 1 else 0) - (old_s if old_c == 1 else 0)
x_white = (new_s if new_c == 0 else 0) - (old_s if old_c == 0 else 0)
diff = paths[pi].add_child(pj, x_black, x_white)
path_idx = pi
def debug_print():
for path in paths:
print(f' {path.www=} {path.ccc=} {path.combined=}')
buf = []
for query in queries:
if query[0] == 1:
_, v = query
v -= 1
i, j = pos[v]
diff = paths[i].switch(j)
propagate(i, diff)
elif query[0] == 2:
_, v, x = query
v -= 1
i, j = pos[v]
diff = paths[i].add(j, x)
propagate(i, diff)
else:
_, v = query
v -= 1
i, j = pos[v]
while i != 0:
if not paths[i].is_same(j):
ans = paths[i].sum_point(j)
buf.append(ans)
break
pi, pj = pos[parents[hld[i][0]]]
if paths[i].ccc[j] != paths[pi].ccc[pj]:
ans = paths[i].sum_point(j)
buf.append(ans)
break
i = pi
j = pj
else:
ans = paths[i].sum_point(j)
buf.append(ans)
# debug_print()
return buf
n, q = list(map(int, input().split()))
www = list(map(int, input().split()))
ccc = list(map(int, input().split()))
edges = [list(map(int, input().split())) for _ in range(n - 1)]
queries = [list(map(int, input().split())) for _ in range(q)]
ans = solve(n, q, www, ccc, edges, queries)
print('\n'.join(map(str, ans)))