AtCoder Beginner Contest 451 F,G問題メモ
F - Make Bipartite 3
問題文
頂点に $1$ から $N$ までの番号が付いた、$N$ 頂点 $0$ 辺の無向グラフ $G$ があります。
$Q$ 個のクエリが与えられます。$i$ 番目のクエリでは、グラフ $G$ に頂点 $u_i$ と頂点 $v_i$ を結ぶ辺を追加します。
それぞれのクエリを処理した直後のグラフ $G$ において、以下の条件を満たすように $G$ の各頂点を白または黒のどちらか一色で塗ることができるか判定し、可能な場合は黒に塗る頂点の個数としてありうる最小値を求めてください。
制約
$2 \le N \le 2 \times 10^5$
$1 \le Q \le 2 \times 10^5$
$1 \le u_i \lt v_i \le N$
$(u_i, v_i) \ne (u_j, v_j) \ (i \ne j)$
入力される値は全て整数
解法
重み付き Union-Find を使ったけど、よく考えると普通の Union-Find でも set で状態管理できるのか。。。
黒を最小化したいので、はじめ、全ての頂点は白で塗られているとする。
クエリによって発生することは以下である。
同じ連結成分にある、違う色で塗られている頂点同士が結ばれる。
同じ連結成分にある、同じ色で塗られている頂点同士が結ばれる。
違う連結成分にある、違う色で塗られている頂点同士が結ばれる。
違う連結成分にある、同じ色で塗られている頂点同士が結ばれる。
難しいのは、通常の Union-Find では「同じ連結成分にあるか」は分かっても、
それが「同じ色で塗られているか」は分からない点。
そこで、重み付き Union-Find を使えば、同じ連結成分内にある相対的な値の差を把握できるので、
各リーダーの色さえ管理しておけば任意の頂点の色を $O(\alpha(N))$ で取得できる。
さらに各リーダーに、現在の自身の連結成分にある白と黒の個数を持たせておけば、マージ時に適切な方を反転させられる。
Python3
from operator import xor
from typing import TypeVar, Generic, Callable
T = TypeVar('T', int, float, tuple) # データ構造に載せそうな型のうち、Immutable なもの(漏れてたら追加可)
class UnionFindWithPotential(Generic[T]):
"""ポテンシャル付きUnion-Find
以下の2つのクエリを処理する。
* unite(x,y,d): 2要素間の差分に対する制約 (x - y = d) を決める。
* diff(x,y): 2要素間の差を返す。未定義(x と y が連結でない)なら None を返す。
載せられるのはアーベル群に限る(?)。
つまり2項間の「加算演算」「減算演算」が定義でき、また「加算演算が可換」である場合に限る。
Attributes:
table (list[int]): 各頂点の、親の頂点番号を表すList。または親である場合は(-要素数)
values (list[T]): 各頂点の、親との差分を表すList
"""
def __init__(self,
n: int,
init: T,
func: Callable[[T, T], T],
rev_func: Callable[[T, T], T]):
"""コンストラクタ
Args:
n: 要素数
init: 単位元
func: 2項間加算演算 (add)
rev_func: 2項間減算演算 (sub)
"""
self.table = [-1] * n
self.values = [init] * n
self.init = init
self.func = func
self.rev_func = rev_func
def root(self, x: int) -> int:
stack = []
tbl = self.table
vals = self.values
while tbl[x] >= 0:
stack.append(x)
x = tbl[x]
if stack:
val = self.init
while stack:
y = stack.pop()
val = self.func(val, vals[y])
vals[y] = val
tbl[y] = x
return x
def is_same(self, x: int, y: int) -> bool:
return self.root(x) == self.root(y)
def diff(self, x: int, y: int) -> T | None:
""" x と y の差(y - x)を取得。同じグループに属さない場合は None。
"""
if not self.is_same(x, y):
return None
vx = self.values[x]
vy = self.values[y]
return self.rev_func(vy, vx)
def unite(self, x: int, y: int, d: T) -> bool:
""" x と y のグループを、y - x = d となるように統合(順序に注意)
既に x と y が同グループで、矛盾する場合は AssertionError。矛盾しない場合はFalse。
同グループで無く、新たな統合が発生した場合はTrue。
"""
rx = self.root(x)
ry = self.root(y)
vx = self.values[x]
vy = self.values[y]
if rx == ry:
assert self.rev_func(vy, vx) == d
return False
rd = self.rev_func(self.func(vx, d), vy)
self.table[rx] += self.table[ry]
self.table[ry] = rx
self.values[ry] = rd
return True
def get_size(self, x: int) -> int:
return -self.table[self.root(x)]
def solve(n, q, queries):
uft = UnionFindWithPotential(n, 0, xor, xor)
colors = [0] * n
whites = [1] * n
blacks = [0] * n
ans = 0
buf = []
for i in range(q):
if ans == -1:
buf.append(-1)
continue
u, v = queries[i]
u -= 1
v -= 1
check = uft.diff(u, v)
if check is not None:
if check == 0:
ans = -1
buf.append(-1)
else:
buf.append(ans)
continue
ru = uft.root(u)
rv = uft.root(v)
cu = colors[ru] ^ uft.diff(ru, u)
cv = colors[rv] ^ uft.diff(rv, v)
if cu != cv:
uft.unite(u, v, 1)
buf.append(ans)
r1 = uft.root(ru)
r2 = rv if r1 == ru else ru
whites[r1] += whites[r2]
blacks[r1] += blacks[r2]
continue
wu = whites[ru]
bu = blacks[ru]
wv = whites[rv]
bv = blacks[rv]
if bu + wv <= wu + bv:
# v が属する方を反転
uft.unite(u, v, 1)
r = uft.root(ru)
whites[r] = wu + bv
blacks[r] = bu + wv
if r == rv:
colors[r] ^= 1
ans += wv - bv
buf.append(ans)
else:
uft.unite(u, v, 1)
r = uft.root(ru)
whites[r] = bu + wv
blacks[r] = wu + bv
if r == ru:
colors[r] ^= 1
ans += wu - bu
buf.append(ans)
return buf
n, q = list(map(int, input().split()))
queries = [list(map(int, input().split())) for _ in range(q)]
ans = solve(n, q, queries)
print('\n'.join(map(str, ans)))
G - Minimum XOR Walk
問題文
$N$ 頂点 $M$ 辺からなる単純連結無向グラフが与えられます。
頂点には $1$ から $N$ までの、辺には $1$ から $M$ までの番号がそれぞれ付けられており、辺 $i$ は頂点 $U_i$ と頂点 $V_i$ を結ぶ重み $W_i$ の無向辺です。
$2$ 頂点を結ぶウォークの重みを、そのウォークが含む辺の重みの総 XOR とします。
非負整数 $K$ が与えられます。整数の組 $(x,y)$ $(1\leq x\lt y\leq N)$ であって、頂点 $x,y$ を結ぶウォークの重みの最小値が $K$ 以下であるものの個数を求めてください。
制約
$1 \leq T \leq 10^5$
$2 \leq N \leq 2 \times 10^5$
$N-1 \leq M \leq 2 \times 10^5$
$0 \leq K \lt 2^{30}$
$1 \leq U_i \lt V_i \leq N$
$0 \leq W_i \lt 2^{30}$
与えられるグラフは単純かつ連結
入力はすべて整数
$1$ つの入力に含まれるテストケースについて、$N$ の総和は $2 \times 10^5$ 以下
$1$ つの入力に含まれるテストケースについて、$M$ の総和は $2 \times 10^5$ 以下
解法
$O(N^2)$ すら無理な制約でそんなことができるの? と思ってしまう。
でもできるんだねえ。複数のステップを経る面白い問題。
サイクル基底を求めた後、XOR基底を求め、分割統治する。
実現可能なウォークの重み
ウォークの重みは、偶数回通った辺は打ち消し合うので要は「奇数回通った辺の集合」の重みのXORとなる。
「適当に取った全域木での $i→j$ のパス」に
「サイクル基底内のサイクルからいくつか選んで辺をXOR」して作れるものが、
「$i→j$ のウォークで奇数回通る辺の集合」としてあり得るものとなる。
なので、まずグラフのサイクル基底を構築する。
サイクル基底
適当な全域木を構築し、辺を全域木に使われるものと、使われないものに分ける。
使われない辺1本を全域木に加えるとサイクルが1つできる。
使われなかった $k$ 本の辺を(1本ずつ別々に)全域木に加えてできるサイクルの集合を、サイクル基底とする。
このグラフに含まれる全てのサイクルは、サイクル基底に含まれるいくつかのサイクル同士の辺のXORにより表現できる。
ここで、「サイクル基底からのサイクルの選び方」自体は $2^{M-N+1}$ 通りあって多すぎるのだが、
「サイクルの重み同士をいくつかXORすることで作れる値」なら、
XOR基底を使うことで $\log_{2}{A_{\max}} = 30$ 個以下の整数に情報を縮約できる。
つまり、「XOR基底の $30$ 個からいくつか選んでXORして作れる値」と、
「サイクル基底からサイクルをいくつか選んで、その重みをXORして作れる値」が、一致する。
まとめると、以下を前計算すればよい。
ペア数の数え方
頂点 $1$ から各頂点 $v$ に対する、全域木上のパスの重みを求め、これを $d(v)$ とする。
各 $d(v)$ にXOR基底を適用することで「頂点 $1$ から $v$ に行くウォークで、その重みが最も小さくなるもの」が求まる。
これを $e(v)$ とする。
一般に、木において、パスの重みが通過した辺の総XORで計算されるとき、
$i→j$ へのパスの重みは、$d(i) \oplus d(j)$ で求められる。($d(i)$ は頂点 $1→i$ へのパスの重み)
今回の場合も、任意の2点 $u,v$ を結ぶウォークの重みのうち最小のものは $e(u) \oplus e(v)$ で表せる。
略証
$d(u) \oplus d(v)$ は、$u→v$ へのウォークとしてあり得る値の1つである。
$e$ は、$d$ にXOR基底で作れる値の1つをXORしたものである。
一般に、XOR基底で作れる値同士のXORもまたXOR基底で作れる。よって、XOR基底で作れる値 $x,y,z$ を使って、
と表せるので、$e(u) \oplus e(v)$ もまた、$u→v$ へのウォークとしてあり得る値の1つである。
後は、これがあり得る値の中で最小となることを示せばよい。
ところで、XOR基底の要素は全て最大bitが相異なるように構築する。
$e(i)$ は、上位の桁 $k=29,28,...,0$ から順に、
「$d(i)$ で $k$ 桁目が立っていて、$k$ 桁目に対応する基底 $x_k$ があれば、$d(i)←d(i) \oplus x_k$ とする」
ことで求められる。
つまり、各 $e(u)$ や $e(v)$ でbitが立っている箇所は、その bit に対応するXOR基底が無いといえる。
XOR基底が存在するbitは、$e(u),e(v)$ でともに $0$ になっているはずのため、$e(u) \oplus e(v)$ でも $0$ である。
$e(u) \oplus e(v)$ で $1$ のbitは、XOR基底が存在しないので、そこを $0$ にすることはできない。
$e(u) \oplus e(v)$ にXOR基底を適用して作成可能な整数のうち、
$e(u) \oplus e(v)$ 自身こそ、そのXOR基底で $0$ にできるbitは既に $0$ になっている状態なので、これ以上小さくすることはできない。
「$N$ 個の非負整数 $E=(e(1),...,e(N))$ から、2個選んでXORして、$K$ 以下になる組の個数を求めよ」という問題になった。
上の桁 $d=29,28,...,0$ から考えて、
$K$ で $d$ 桁目が立っている場合
$K$ で $d$ 桁目が立っていない場合
こうしていくと、1つの要素は高々 $O(\log{A_{\max}})$ 回しか評価されないため、$O(N \log{A_{\max}})$ でペア数を求めることができる。
Python3
from atcoder.dsu import DSU
def naive(n, m, k, edges):
matrix = [[set() for _ in range(n)] for _ in range(n)]
for i in range(n):
matrix[i][i].add(0)
d = 0
for u, v, w in edges:
u -= 1
v -= 1
matrix[u][v].add(w)
matrix[v][u].add(w)
d = max(d, w.bit_length())
for k_ in range(n):
for i in range(n):
for j in range(n):
update = set()
for x in matrix[i][k_]:
for y in matrix[k_][j]:
update.add(x ^ y)
matrix[i][j].update(update)
for i in range(n):
for j in range(i, n):
can = [f'{c:0{d}b}' for c in sorted(matrix[i][j])]
print(f'{i=} {j=} {can=}')
def solve(n, m, k, edges):
uft = DSU(n)
links = [[] for _ in range(n)]
ex_edges = []
for u, v, w in edges:
u -= 1
v -= 1
if uft.same(u, v):
ex_edges.append((u, v, w))
continue
uft.merge(u, v)
links[u].append((v, w))
links[v].append((u, w))
score_from_0 = [-1] * n
score_from_0[0] = 0
q = [0]
while q:
u = q.pop()
for v, w in links[u]:
if score_from_0[v] != -1:
continue
score_from_0[v] = score_from_0[u] ^ w
q.append(v)
# print(ex_edges)
# print(score_from_0)
BASIS_SIZE = 30
cycle_bases = [0] * BASIS_SIZE
for u, v, w in ex_edges:
x = score_from_0[u] ^ score_from_0[v] ^ w
for i in range(BASIS_SIZE - 1, -1, -1):
if (x >> i) & 1:
if cycle_bases[i] == 0:
cycle_bases[i] = x
break
x ^= cycle_bases[i]
# print(cycle_bases)
min_score = []
for i in range(n):
x = score_from_0[i]
for j in range(BASIS_SIZE - 1, -1, -1):
x = min(x, x ^ cycle_bases[j])
min_score.append(x)
# print(min_score)
def check_same(d, aaa):
if d < 0:
l = len(aaa)
return l * (l - 1) // 2
if len(aaa) <= 10:
res = 0
for i in range(len(aaa)):
for j in range(i + 1, len(aaa)):
if (aaa[i] ^ aaa[j]) <= k:
res += 1
return res
zero = []
one = []
bit = 1 << d
for a in aaa:
if a & bit:
one.append(a)
else:
zero.append(a)
res = 0
lo = len(one)
lz = len(zero)
if k & bit:
res += lo * (lo - 1) // 2
res += lz * (lz - 1) // 2
if lo > 0 and lz > 0:
res += check_diff(d - 1, one, zero)
else:
if lo > 0:
res += check_same(d - 1, one)
if lz > 0:
res += check_same(d - 1, zero)
return res
def check_diff(d, aaa, bbb):
if d < 0:
return len(aaa) * len(bbb)
if len(aaa) * len(bbb) < 100:
res = 0
for a in aaa:
for b in bbb:
if (a ^ b) <= k:
res += 1
return res
a_zero = []
a_one = []
b_zero = []
b_one = []
bit = 1 << d
for a in aaa:
if a & bit:
a_one.append(a)
else:
a_zero.append(a)
for b in bbb:
if b & bit:
b_one.append(b)
else:
b_zero.append(b)
res = 0
if k & bit:
res += len(a_one) * len(b_one)
res += len(a_zero) * len(b_zero)
if len(a_one) > 0 and len(b_zero):
res += check_diff(d - 1, a_one, b_zero)
if len(a_zero) > 0 and len(b_one):
res += check_diff(d - 1, a_zero, b_one)
else:
if len(a_one) > 0 and len(b_one):
res += check_diff(d - 1, a_one, b_one)
if len(a_zero) > 0 and len(b_zero):
res += check_diff(d - 1, a_zero, b_zero)
return res
ans = check_same(29, min_score)
return ans
t = int(input())
for _ in range(t):
n, m, k = list(map(int, input().split()))
edges = [list(map(int, input().split())) for _ in range(m)]
ans = solve(n, m, k, edges)
print(ans)