赤黒木
概要
平衡探索二分木の一種で、以下を満たすように木を保つ。
- 各ノードを赤と黒で(仮想的に)塗り分ける
- 根は黒
- 赤の子は黒(赤は2つ以上連続しない)
- 根からそれぞれの葉までの経路上にある黒の個数は等しい
こんなことを決めて何がいいのかと、
- これが満たされている木では、根から最短経路(黒黒黒……)と最長経路(黒赤黒赤……)の差は2倍を超えないので、ほどよい平衡が保たれることが保証される
- データを追加したり削除したりするとき、親子で繋がった直近3~4ノードの色のパターンを見て塗り替えることで状態を維持できる
- その中だけで状態を解消できないときは親に遡ることはあるが、再帰的に実装できる
ただ、パターン分けは結構種類があってややこしい。 詳細は参考の1つめのサイトが丁寧に場合分けを説明されている。
- 参考
実装
再帰関数は極力無くした実装(1箇所だけ、多重再帰にはならない箇所で使用)。どうしても配列アクセスは多くなるので、Pythonだと遅い。PyPyなら……速いとは言えないが、まぁまぁ。
| 関数 | 返値 | 概要 |
|---|---|---|
| insert(x) | None | $x$ を追加(実装によっては重複する値を許さないこともできる。以下の実装では許す) |
| delete(x) | None | $x$ を削除(無ければ何もしない) |
| upper_bound(x) | $y,i$ | $x \lt y$ を満たす最小の $y$ と、その小さい方からのindex $i$ (0-indexed)を取得 |
| lower_bound(x) | $y,i$ | $x \le y$ を満たす最小の $y$ と、その小さい方からのindex $i$ (0-indexed)を取得 |
| get_by_index(i) | $x$ | 小さい方から $i$ 番目の値を取得(0-indexed) |
| debug_print() | None | 木を90°倒して左が根、上が小さい方となるように、“色,値,部分木以下の要素数”を描画 |
- Verification: AtCoder Beginner Contest 140 E - Second Sum(PyPy3, 949ms)
"""
赤黒木
"""
class RedBlackTree:
# For performance, nodes are implemented as lists.
# [left, right, data, color, count]
# color: 1:RED 0:BLACK
#
# It is useful to set the left/right index to 0/1 for boolean access.
# The end nodes are BLACK and of which value is EOT (End Of Tree).
#
# Usage example:
# rbt = RedBlackTree()
# rbt.insert(x)
# rbt.delete(x)
# rbt.upper_bound(x)
# rbt.lower_bound(x)
# rbt.get_by_index(i)
# rbt.debug_print()
def __init__(self, EOT=-1):
self.EOT = EOT
self.root = self._leaf()
def _leaf(self):
return [None, None, self.EOT, 0, 0]
def _rotate(self, node, r):
child = node[r ^ 1]
node[r ^ 1] = child[r]
child[r] = node
child[3] = node[3]
node[3] = 1
child_count = child[4]
child[4] = node[4]
node[4] -= child_count - node[r ^ 1][4]
return child
def insert(self, x):
stack = []
node = self.root
while node[2] != self.EOT:
# Variation: If the same value is not allowed, return when found.
to_right = x >= node[2]
stack.append((node, to_right))
node = node[to_right]
# Insert values to the end node
node[0] = self._leaf()
node[1] = self._leaf()
node[2] = x
node[3] = 1
node[4] = 1
# Increase count
for parent, _ in stack:
parent[4] += 1
# Validate tree and rotate if needed.
while stack:
parent, r = stack.pop()
if parent[3] == 1:
node = parent
continue
parent[r] = node
node, flag = self._balance_insert(parent, r)
if stack and flag == True:
parent, r = stack.pop()
parent[r] = node
break
else:
# If the root might have changed, update its color.
self.root = node
self.root[3] = 0
def _balance_insert(self, node, r):
flag = True
if node[r][r ^ 1][3] == 1:
node[r] = self._rotate(node[r], r)
if node[r][r][3] == 1:
if node[r ^ 1][3] == 1:
node[3] = 1
node[r ^ 1][3] = node[r][3] = 0
flag = False
else:
node = self._rotate(node, r ^ 1)
return node, flag
def delete(self, x):
node = self.root
stack = []
while node[2] != x:
r = node[2] < x
stack.append((node, r))
node = node[r]
# Not Found
if node[2] == self.EOT:
return
# Node has 2 children: swap min node in right tree and delete min node
if node[0][2] != self.EOT and node[1][2] != self.EOT:
stack.append((node, 1))
min_node = self.get_min(node[1], stack)
node[2] = min_node[2]
node = min_node
# Delete node is root
if not stack:
if node[0][2] == self.EOT:
self.root = node[1]
else:
self.root = node[0]
self.root[3] = 0
return
# Decrease count
for parent, _ in stack:
parent[4] -= 1
# Node has 0/1 child
parent, r = stack[-1]
if node[0][2] == self.EOT:
parent[r] = node[1]
node[1][3] = 0
# Balance is only needed if both children are EOT and self color is black.
if node[1][2] != self.EOT or node[3] == 1:
return
elif node[1][2] == self.EOT:
parent[r] = node[0]
node[0][3] = 0
return
# Validate tree and rotate if needed.
while stack:
parent, r = stack.pop()
node, flag = self._balance_delete(parent, r)
if stack and flag == True:
parent, r = stack.pop()
parent[r] = node
break
else:
# If the root might have changed, update its color.
self.root = node
self.root[3] = 0
def get_min(self, node, stack):
while node[0][2] != self.EOT:
stack.append((node, 0))
node = node[0]
return node
def _balance_delete(self, node, r):
if node[r ^ 1][r][3] == 0 and node[r ^ 1][r ^ 1][3] == 0:
if node[r ^ 1][3] == 0:
node[r ^ 1][3] = 1
if node[3] == 0:
return node, False
node[3] = 0
else:
node = self._rotate(node, r)
node[r], _ = self._balance_delete(node[r], r)
else:
if node[r ^ 1][r][3] == 1:
node[r ^ 1] = self._rotate(node[r ^ 1], r ^ 1)
node = self._rotate(node, r)
node[r][3] = 0
node[r ^ 1][3] = 0
return node, True
def upper_bound(self, x):
"""
:return Smallest y satisfying x < y and its leftmost index.
If not exists, (EOT, length) will be returned.
"""
node = self.root
y = self.EOT
i = node[4]
j = 0
while node[2] != self.EOT:
if x < node[2]:
y = node[2]
i = j + node[0][4]
node = node[0]
else:
j += node[0][4] + 1
node = node[1]
return y, i
def lower_bound(self, x):
"""
:return Smallest y satisfying x <= y and its leftmost index.
If not exists, (EOT, length) will be returned.
"""
node = self.root
y = self.EOT
i = node[4]
j = 0
while node[2] != self.EOT:
if x <= node[2]:
y = node[2]
i = j + node[0][4]
node = node[0]
else:
j += node[0][4] + 1
node = node[1]
return y, i
def get_by_index(self, i):
"""
:return (0-indexed) i-th smallest item.
If i is greater than length, EOT will be returned.
"""
node = self.root
if node[4] <= i:
return self.EOT
j = i
while node[2] != self.EOT:
left_count = node[0][4]
if left_count == j:
return node[2]
elif left_count > j:
node = node[0]
else:
j -= left_count + 1
node = node[1]
def debug_print(self):
self._debug_print(self.root, 0)
def _debug_print(self, node, depth):
if node[2] != self.EOT:
self._debug_print(node[0], depth + 1)
print(' ' * depth, 'BR'[node[3]], node[2], node[4])
self._debug_print(node[1], depth + 1)
def check():
rbt = RedBlackTree()
for x in [0, 1, 2, 3, 4, 5, 6, 3, 3, 3, 5, 5, 5, 0, 0, 0]:
print('------ INSERT', x, '------')
rbt.insert(x)
rbt.debug_print()
for x in [0, 1, 2, 3, 4, 5, 6, 7]:
print('------ UPPER', x, '=', rbt.upper_bound(x))
print('------ LOWER', x, '=', rbt.lower_bound(x))
for x in [5, 0, 3, 3, 0, 5]:
print('------ DELETE', x, '------')
rbt.delete(x)
rbt.debug_print()
for i in range(11):
print('------ INDEX', i, '=', rbt.get_by_index(i))
if __name__ == '__main__':
check()

