Union-Find木
という前提で、
unite(x, y)
: 要素 $x$ の属するグループと、$y$ の属するグループを統合する
is_same(x, y)
: 要素 $x$ と $y$ が同じグループに属するか調べる
という2つのクエリを効率的に処理するのが、Union-Find木
Disjoint Set Union、略してDSUともいう。
アルゴリズム
各グループの要素から、“代表”を1つ決める
代表以外の要素は、“親”となる要素を1つずつ持つ
1 4 8 ←1,4,8が代表
↗↑ ↑
2 5 6 ←2や5の親は1, 3の親は2
↑ ↗ ↖
3 7 9
これをもとに、先ほどの2つの操作を言い換えると、以下のようになる。
unite(x, y)
: 一方の代表を、もう一方の代表につなぐ(親子にする)
is_same(x, y)
: $x$ と $y$ から親を辿って同じ代表に行き着くか調べる
また、そのために以下の操作も持つ。
これを根を見つける(Find)と捉え、2つの操作を合わせてUnion-Findとなる。
高速化の工夫
◎はやい ×おそい
1 1
↗↑↖ ↑
2 3 4 2
↑
3
↑
4
経路圧縮
ランク
この2つをuniteする場合| こうすると | 低い方に高い方をつなげると
Rank:2 Rank:4 | Rank:4のまま | Rank:5になっちゃう
1 5 | 5 | 1
↗↑↖ ↑ | ↑↖ | ↗↗↑↖
2 3 4 6 | 1 6 |23 4 5
↑ | ↗↗↑ ↑ | ↑
7 | 23 4 7 | 6
↑ | ↑ | ↑
8 | 8 | 7
| | ↑
| | 8
計算量
上記の2つの改善を行って、unite, is_same, root
1回にかかる計算量はいずれも $O(\alpha(N))$ となる(らしい)。
これはいずれもならし計算量で、何回もやると平均がこの値に近づいていく。
ただし $\alpha(N)$ はアッカーマン関数 $\rm{Ack}(N,N)$ の逆関数で、$N$ が少し大きくなると急速に小さくなるので、ほとんど定数と考えてよい。
実装
2020/05 rootを非再帰に変更(特にPyPyで高速化)
unite()
のd1,d2は実際の値を正負逆転させているので、大小関係に注意。
# Python3
class UnionFind:
def __init__(self, n):
# 負 : 根であることを示す。絶対値はランクを示す
# 非負: 根でないことを示す。値は親を示す
self.table = [-1] * n
def root(self, x):
stack = []
tbl = self.table
while tbl[x] >= 0:
stack.append(x)
x = tbl[x]
for y in stack:
tbl[y] = x
return x
def is_same(self, x, y):
return self.root(x) == self.root(y)
def unite(self, x, y):
r1 = self.root(x)
r2 = self.root(y)
if r1 == r2:
return
# ランクの取得
d1 = self.table[r1]
d2 = self.table[r2]
if d1 <= d2:
self.table[r2] = r1
if d1 == d2:
self.table[r1] -= 1
else:
self.table[r1] = r2
if __name__ == '__main__':
ins = UnionFind(4)
ins.unite(0, 1)
ins.unite(2, 3)
print(ins.is_same(0, 3)) # False
ins.unite(1, 2)
print(ins.is_same(0, 3)) # True
亜種
ランクを要素数とする
ランクは木の深さでなく、要素数で考えることもある。
例: 要素数100,最大深さ2のグループAと、要素数5,最大深さ4のグループB
AにBを繋ぐとBの5個の要素の深さが1ずつ増え、最大深さは5
BにAを繋ぐとAの100個の要素の深さが1ずつ増え、最大深さは4
局所的に最大深さが深くなろうが、100個の要素の深さが1ずつ増える方が効率悪くない?という考え方か
A.V.Aho, John E. Hopcroft, Jeffrey D. Ulman, 『データ構造とアルゴリズム』, 培風館, 1987
実際、どちらを選んでも計算量は変わらない(経路圧縮を行う場合で $O(\alpha(N))$、行わない場合で $O(\log{N})$)し、速度は体感できるほど変わらない。
木の深さってあまり使い道が無いが、「$x$ が含まれるグループの要素数」は任意の時点で取得できると嬉しい場合はある。
よって、ランクは要素数としておいた方が、何かと便利かも知れない。
実装
unite()
の中身が少し変化する。また、get_size()
を追加。
Python3
class UnionFind:
def __init__(self, n):
self.table = [-1] * n
def root(self, x):
stack = []
tbl = self.table
while tbl[x] >= 0:
stack.append(x)
x = tbl[x]
for y in stack:
tbl[y] = x
return x
def is_same(self, x, y):
return self.root(x) == self.root(y)
def unite(self, x, y):
r1 = self.root(x)
r2 = self.root(y)
if r1 == r2:
return
d1 = self.table[r1]
d2 = self.table[r2]
if d1 <= d2:
self.table[r2] = r1
self.table[r1] += d2
else:
self.table[r1] = r2
self.table[r2] += d1
def get_size(self, x):
return -self.table[self.root(x)]
グループの要素も保持
上記の例では、グループの要素数はわかっても、具体的に現時点でのグループにはどの要素があるのか、というのはわからない。
グループの要素も管理したい場合、tableとは別にもう1つグループを管理する配列も用意して、実際に統合していく。
unite()
部分のならし計算量は、$O(\alpha(N))$ から $O(\alpha(N) \log{N})$ になる。
実装
Python3
class UnionFindWithGrouping:
def __init__(self, n):
self.table = [-1] * n
self.group = [{i} for i in range(n)]
def root(self, x):
stack = []
tbl = self.table
while tbl[x] >= 0:
stack.append(x)
x = tbl[x]
for y in stack:
tbl[y] = x
return x
def is_same(self, x, y):
return self.root(x) == self.root(y)
def unite(self, x, y):
r1 = self.root(x)
r2 = self.root(y)
if r1 == r2:
return
d1 = self.table[r1]
d2 = self.table[r2]
if d1 <= d2:
self.table[r2] = r1
self.table[r1] += d2
self.group[r1].update(self.group[r2])
else:
self.table[r1] = r2
self.table[r2] += d1
self.group[r2].update(self.group[r1])
def get_size(self, x):
return -self.table[self.root(x)]
def get_group(self, x):
return self.group[self.root(x)]
Rollback付きUnion-Find
Undo付きUnion-Findなどとも呼ばれる。履歴を持ち、直前で結合したグループを戻せる。
root(x)
: $x$ の根を取得
unite(x, y)
: $x$ と $y$ の所属グループを結合
is_same(x, y)
: $x$ と $y$ が同じグループか判定
rollback()
: 直前のuniteを取り消す(繰り返し実行した場合は更にその前の操作も順に取り消す)
通常のUnion-Findに、rollbackの機能が加わったもの。
計算量は $O(\alpha(N))$ から 若干悪化して $O(\log{N})$ となるが、それでも十分高速。
Union-Findにおいてtableの変更が行われるのは、
root(x)
: 代表を取得するときの経路圧縮
unite(x,y)
: 結合
このうち、経路圧縮をおこなうのは以下の理由からやや難しい。
一方、結合時は変更が確実に「unite()
した時のみ」であり(これは「Undo」という操作の直感にも従う)、変更箇所も少ない。
「変更した要素 $x,y$、変更前の値 $table[x],table[y]$」の4つを、履歴を管理するためのスタックに積んでおけば復元できる。
実装
細かな実装の違いとして、uniteしようとした結果、既に同グループだった場合、履歴に残すかどうかがあるが、
有効だったかどうかを呼び出し側が管理しないで済むので、基本的には残した方がよさそう。
Python3
class RollbackUnionFind:
"""
履歴を持ち、直前の操作を(順に)取り消せるUnionFind
"""
def __init__(self, n: int):
self.table = [-1] * n
self.history = []
# (継続リーダー番号1, 被統合リーダー番号2, 1統合前table値, 2統合前table値)
# 無効なuniteの場合、被統合リーダー番号を"-1"とする
def root(self, x: int):
tbl = self.table
while tbl[x] >= 0:
x = tbl[x]
return x
def is_same(self, x: int, y: int):
return self.root(x) == self.root(y)
def unite(self, x: int, y: int):
r1 = self.root(x)
r2 = self.root(y)
if r1 == r2:
self.history.append((r1, -1, -1, -1))
return False
d1 = self.table[r1]
d2 = self.table[r2]
if d1 <= d2:
self.table[r2] = r1
self.table[r1] += d2
self.history.append((r1, r2, d1, d2))
else:
self.table[r1] = r2
self.table[r2] += d1
self.history.append((r2, r1, d2, d1))
return True
def rollback(self):
if not self.history:
return False
r1, r2, d1, d2 = self.history.pop()
if r2 == -1:
return False
self.table[r1] = d1
self.table[r2] = d2
return True
def get_size(self, x: int):
return -self.table[self.root(x)]
部分永続Union-Find
過去の状態に遡って is_same
できる。英語では Partially Persistent Union-Find。
root(x, t)
: 時刻 $t$ における $x$ の根を取得
unite(x, y)
: $x$ と $y$ の所属グループを結合し、内部時刻を1進める
is_same(x, y, t)
: 時刻 $t$ において、$x$ と $y$ が同じグループか判定
更新は最新の状態に対してのみ行える。
(過去の状態からも更新可能にしたものは、「完全永続」という)
基本的に、クエリが先読みできるならis_same()
を適切なタイミングに割り込ませることで
通常のUnion-Findで対応可能なので、有用な場面は限られる?
こちらもRollback Union-Findと同様、経路圧縮は難しいので、計算量はいずれの操作も $O(\log{N})$ となる。
実装
通常のUnionFindではtableは現在値のみを持つところ、「(更新時刻, 更新値)のタプル」を基本単位として持つ。
機能(※)の有無により、必要な実装が若干異なる。
uniteとis_sameしかできない実装を①、get_sizeなども可能な実装を②とする。
②はわずかに必要メモリが増加する。(追加で $O(クエリ数)$ くらい。まぁ微差)
通常の初期化: table = [-1, -1, -1, ...]
値のリスト
部分永続①の初期化: table = [(0, -1), (0, -1), (0, -1), ...]
(更新時刻, 更新値)のリスト
部分永続②の初期化: table = [[(0, -1)], [(0, -1)], [(0, -1)], ...]
(更新時刻, 更新値)のリストのリスト
(他に記録したい値があるならそれもタプルに含める)
部分永続①の更新: table[x] = (t, value) 上書き
部分永続②の更新: table[x].append( (t, value) ) リストにappendし、過去の値を残す
部分永続が通常のUnionFindと異なるのは、主に根を特定する処理 root(x, t)
。
少しテクニカルな考え方によって、「最終更新時刻と更新値」のみから根が特定できる。
なので、過去の履歴がなくても、最終の更新時刻と更新値があればよい。
また、部分永続②において get_size(x, t)
を求める場合は、以下でできる。(過去の履歴はここで必要になる)
Python3
部分永続②の方。未検証。
from bisect import bisect
class PartialPersistentUnionFind:
"""
時刻 t は、「t番目のuniteクエリ完了時点」を意味する。初期状態は t=0。
root(x [, t]): 時刻 t の x の根(t 省略時は最新)
find(x, y [, t]): 時刻 t に x と y が同一グループか判定
unite(x, y): x と y を結合
get_size(x [, t]): 時刻 t の x の属するグループの要素数
"""
def __init__(self, n: int):
self.t = 0
self.timetable = [[0] for _ in range(n)]
self.valuetable = [[-1] for _ in range(n)]
def _getvalue(self, x: int, t: int):
i = bisect(self.timetable[x], t) - 1
return self.valuetable[x][i]
def root(self, x: int, t: int = -1):
# 更新が行われるのは、その時点の根に対してだけなので、
# x の最終更新時刻 > t なら、t 時点では確実に x は根。
# そうでない場合は最終更新時の値の正負で、通常のUnionFindと同様、根か、親を持っているか判別。
if t == -1:
t = self.t
while self.timetable[x][-1] <= t and self.valuetable[x][-1] >= 0:
x = self.valuetable[x][-1]
return x
def is_same(self, x: int, y: int, t: int = -1):
return self.root(x, t) == self.root(y, t)
def unite(self, x: int, y: int):
self.t += 1
r1 = self.root(x, self.t)
r2 = self.root(y, self.t)
if r1 == r2:
return False
self.timetable[r1].append(self.t)
self.timetable[r2].append(self.t)
d1 = self.valuetable[r1][-1]
d2 = self.valuetable[r2][-1]
if d1 <= d2:
self.valuetable[r1].append(d1 + d2)
self.valuetable[r2].append(r1)
else:
self.valuetable[r1].append(r2)
self.valuetable[r2].append(d1 + d2)
return True
def get_size(self, x: int, t: int = -1):
if t == -1:
t = self.t
r = self.root(x, t)
return -self._getvalue(r, t)
完全永続Union-Find
全永続~~、あるいは単に永続~~ともいう。英語では Fully Persistent Union-Find。
root(x, t)
: バージョン $t$ の $x$ の根を取得
unite(x, y, t)
: バージョン $t$ の状態から、$x$ と $y$ を統合し、最新バージョンとする
is_same(x, y, t)
: バージョン $t$ で、$x$ と $y$ が同じグループか判定
部分永続とは異なり、unite
も $t$ を引数に取るようになりどこからでも更新が可能になった。
($t$ は、時刻というよりバージョン、
あるいはもっと直接的に「何回目のunite
操作後の状態を示すか」と捉えた方がわかりやすいかもしれない。
$t$ は、変更元の $t$ に依らずに回数でナンバリングされていくので、
$t-1$ や $t+1$ との間に必ずしも関連性があるわけではないが、「時刻」だとあるかのように錯覚しかねないので)
部分永続Union-Findは、根しか更新対象にならないというUnion-Find特有の性質を上手く使っていたが、
完全永続Union-Findは、Union-Find特有と言うよりは、より汎用的な「永続配列」を使って実装する。
元のUnion-Findからの主な変更点は、
「状態を保持する配列 table を、永続配列に置き換える」ことくらいなので、
あまり完全永続Union-Find独自の特長として語ることはないか。
一応、Union-Findは一度のuniteでtable
配列を2箇所、書き換えるが、
永続配列を「どこか1箇所書き換えるたびにバージョンが1進む」ようになっていると、
Union-Findと永続配列でバージョンの進みがずれてきてしまうので、
「同時に複数箇所書き換えられる」ようにした永続配列を使えれば少し混乱しにくい。
実装
未検証。
Python3
from typing import TypeVar, Generic, List, Optional, Sequence
T = TypeVar('T')
class PersistentArrayNode(Generic[T]):
def __init__(self, m: int, t: int, value: Optional[T] = None):
self.children: List[Optional['PersistentArrayNode']] = [None] * m
self.created: int = t
self.value: Optional[T] = value
def copy(self, t: int):
res = PersistentArrayNode(0, t, self.value)
res.children = self.children.copy() # shallow copy
return res
class PersistentArray(Generic[T]):
"""
永続配列
"""
def __init__(self, array: Sequence[T]):
n = len(array)
log_n = (n - 1).bit_length()
root = PersistentArrayNode(log_n, 0, array[0])
self.n = n
self.m = log_n # M分木を構築
self.roots = [root]
# 初期化
q = [(0, 1, root)] # (そのノードが示すi, 隣り合う子の差分, ノードインスタンス)
m = self.m
while q:
i, d, node = q.pop()
nd = d * m
# children[0] だけはleading-zero問題があるので、別処理
# (現在の i=xxxx(M進数) としたとき、0xxxx に進むので、そのノード自身の値を持たない)
ci = i + nd
if ci < n:
child = PersistentArrayNode(log_n, 0)
node.children[0] = child
q.append((i, nd, child))
for j in range(1, m):
ci = i + d * j
if ci < n:
child = PersistentArrayNode(log_n, 0, array[ci])
node.children[j] = child
q.append((ci, nd, child))
else:
break
def get_node(self, i: int, t: int) -> PersistentArrayNode:
# a[i] を示すノードを取得
# https://37zigen.com/persistent-array/
# (データの持ち方: 2番目の方法, M進数にして下位の桁から見る)
node = self.roots[t]
m = self.m
while i > 0:
node = node.children[i % m]
i //= m
return node
def get_value(self, i: int, t: int) -> T:
return self.get_node(i, t).value
def update(self, i: int, x: T, t: int) -> int:
t_new = len(self.roots)
node = self.roots[t].copy(t_new)
self.roots.append(node)
m = self.m
while i > 0:
child = node.children[i % m].copy(t_new)
node.children[i % m] = child
node = child
i //= m
node.value = x
return t_new
def update_newest(self, i: int, x: T):
"""
最新バージョンに対し、新規バージョンを作らないで更新する。
複数箇所を一括に更新することがあり、一括更新の途中状態は不要という場合、
最初だけ update し、2番目以降の更新をこちらにすることで、tの進み方が呼び出し元とずれることを防ぐ。
"""
# ノードに自身か作成された時刻を持たせ、作成時刻が最新バージョンより古ければコピー、同じなら使い回す
t_new = len(self.roots) - 1
node = self.roots[-1]
m = self.m
while i > 0:
child = node.children[i % m]
if child.created < t_new:
child = child.copy(t_new)
node.children[i % m] = child
node = child
i //= m
node.value = x
def debug_print(self):
for t in range(len(self.roots)):
res = []
for i in range(self.n):
res.append(self.get_value(i, t))
print(t, res)
class PersistentUnionFind(Generic[T]):
def __init__(self, n):
self.table = PersistentArray([-1] * n)
def root(self, x, t):
tbl = self.table
p = tbl.get_value(x, t)
while p >= 0:
x = p
p = tbl.get_value(x, t)
return x
def find(self, x, y, t):
return self.root(x, t) == self.root(y, t)
def unite(self, x, y, t):
r1 = self.root(x, t)
r2 = self.root(y, t)
if r1 == r2:
return False
d1 = self.table.get_value(r1, t)
d2 = self.table.get_value(r2, t)
if d1 <= d2:
self.table.update(r2, r1, t)
self.table.update_newest(r1, d1 + d2)
else:
self.table.update(r1, r2, t)
self.table.update_newest(r2, d1 + d2)
return True
def get_size(self, x, t):
return -self.table.get_value(self.root(x, t), t)
ポテンシャル付きUnion-Find
重み付きUnion-Findともいう。
頂点間の「距離」も管理できるUnion-Find。
「各頂点 $i$ に何らかの値 $W_i$ が決まっているが直接は観測できなくて、頂点間の差分だけがわかる」というときに、
矛盾するかしないか、しない場合に指定した2頂点間の差分はいくつになるか、を求める。
① ④
-2↗↖5 ↑3
② ③ ⑥
6↑ ↑1
⑤ ⑦
② - ① = -2
③ - ① = 5
⑤ - ① = 4
③ - ② = 7
⑦ - ⑤ = 2
⑥ - ② = 不定(同じ連結成分では無いため)
以下のような特徴を持ったデータ構造である。
root(x)
: $x$ の根を取得
unite(x, y, d)
: $x$ と $y$ の所属グループを、$W_y-W_x=d$ となるように統合(あるいは矛盾を検知)
is_same(x, y)
: $x$ と $y$ が同じグループか判定
diff(x, y)
: $x$ と $y$ が同じグループか判定、同じなら値の差分($W_y-W_x$)を取得
通常のUnion-Findに加え、親を $0$ とした時の各頂点の値を記録していく。
根(代表)まで辿っていけば、「根を $0$ とした時の自身の値」がわかり、同一グループ内なら比較ができるようになる。
値の更新に気をつければ経路圧縮もできるので、各計算量は $O(\alpha(N))$ となる。
実装
Python3
from typing import Callable, Generic, TypeVar, List
T = TypeVar('T')
class UnionFindWithPotential(Generic[T]):
def __init__(self,
n: int,
init: Callable[[], T],
func: Callable[[T, T], T],
rev_func: Callable[[T, T], T]):
"""
:param n:
:param init: 単位元の生成関数
:param func: 2項間加算関数(add)
:param rev_func: 逆関数(sub)
"""
self.table: List[int] = [-1] * n
self.values: List[T] = [init() for _ in range(n)]
self.init: Callable[[], T] = init
self.func: Callable[[T, T], T] = func
self.rev_func: Callable[[T, T], T] = rev_func
def root(self, x: int) -> int:
stack = []
tbl = self.table
vals = self.values
while tbl[x] >= 0:
stack.append(x)
x = tbl[x]
if stack:
val = self.init()
while stack:
y = stack.pop()
val = self.func(val, vals[y])
vals[y] = val
tbl[y] = x
return x
def is_same(self, x: int, y: int) -> bool:
return self.root(x) == self.root(y)
def diff(self, x: int, y: int) -> T:
"""
x と y の差(y - x)を取得。同じグループに属さない場合は None。
"""
if not self.is_same(x, y):
return None
vx = self.values[x]
vy = self.values[y]
return self.rev_func(vy, vx)
def unite(self, x: int, y: int, d: T) -> bool:
"""
x と y のグループを、y - x = d となるように統合。
既に x と y が同グループで、矛盾する場合は AssertionError。矛盾しない場合はFalse。
同グループで無く、新たな統合が発生した場合はTrue。
"""
rx = self.root(x)
ry = self.root(y)
vx = self.values[x]
vy = self.values[y]
if rx == ry:
assert self.rev_func(vy, vx) == d
return False
rd = self.rev_func(self.func(vx, d), vy)
self.table[rx] += self.table[ry]
self.table[ry] = rx
self.values[ry] = rd
return True
def get_size(self, x: int) -> int:
return -self.table[self.root(x)]
Retroactive Union-Find