赤黒木
概要
平衡探索二分木の一種で、以下を満たすように木を保つ。
- 各ノードを赤と黒で(仮想的に)塗り分ける
- 根は黒
- 赤の子は黒(赤は2つ以上連続しない)
- 根からそれぞれの葉までの経路上にある黒の個数は等しい
こんなことを決めて何がいいのかと、
- これが満たされている木では、根から最短経路(黒黒黒……)と最長経路(黒赤黒赤……)の差は2倍を超えないので、ほどよい平衡が保たれることが保証される
- データを追加したり削除したりするとき、親子で繋がった直近3~4ノードの色のパターンを見て塗り替えることで状態を維持できる
- その中だけで状態を解消できないときは親に遡ることはあるが、再帰的に実装できる
ただ、パターン分けは結構種類があってややこしい。 詳細は参考の1つめのサイトが丁寧に場合分けを説明されている。
- 参考
実装
再帰関数は極力無くした実装(1箇所だけ、多重再帰にはならない箇所で使用)。どうしても配列アクセスは多くなるので、Pythonだと遅い。PyPyなら……速いとは言えないが、まぁまぁ。
関数 | 返値 | 概要 |
---|---|---|
insert(x) | None | $x$ を追加(実装によっては重複する値を許さないこともできる。以下の実装では許す) |
delete(x) | None | $x$ を削除(無ければ何もしない) |
upper_bound(x) | $y,i$ | $x \lt y$ を満たす最小の $y$ と、その小さい方からのindex $i$ (0-indexed)を取得 |
lower_bound(x) | $y,i$ | $x \le y$ を満たす最小の $y$ と、その小さい方からのindex $i$ (0-indexed)を取得 |
get_by_index(i) | $x$ | 小さい方から $i$ 番目の値を取得(0-indexed) |
debug_print() | None | 木を90°倒して左が根、上が小さい方となるように、“色,値,部分木以下の要素数”を描画 |
- Verification: AtCoder Beginner Contest 140 E - Second Sum(PyPy3, 949ms)
""" 赤黒木 """ class RedBlackTree: # For performance, nodes are implemented as lists. # [left, right, data, color, count] # color: 1:RED 0:BLACK # # It is useful to set the left/right index to 0/1 for boolean access. # The end nodes are BLACK and of which value is EOT (End Of Tree). # # Usage example: # rbt = RedBlackTree() # rbt.insert(x) # rbt.delete(x) # rbt.upper_bound(x) # rbt.lower_bound(x) # rbt.get_by_index(i) # rbt.debug_print() def __init__(self, EOT=-1): self.EOT = EOT self.root = self._leaf() def _leaf(self): return [None, None, self.EOT, 0, 0] def _rotate(self, node, r): child = node[r ^ 1] node[r ^ 1] = child[r] child[r] = node child[3] = node[3] node[3] = 1 child_count = child[4] child[4] = node[4] node[4] -= child_count - node[r ^ 1][4] return child def insert(self, x): stack = [] node = self.root while node[2] != self.EOT: # Variation: If the same value is not allowed, return when found. to_right = x >= node[2] stack.append((node, to_right)) node = node[to_right] # Insert values to the end node node[0] = self._leaf() node[1] = self._leaf() node[2] = x node[3] = 1 node[4] = 1 # Increase count for parent, _ in stack: parent[4] += 1 # Validate tree and rotate if needed. while stack: parent, r = stack.pop() if parent[3] == 1: node = parent continue parent[r] = node node, flag = self._balance_insert(parent, r) if stack and flag == True: parent, r = stack.pop() parent[r] = node break else: # If the root might have changed, update its color. self.root = node self.root[3] = 0 def _balance_insert(self, node, r): flag = True if node[r][r ^ 1][3] == 1: node[r] = self._rotate(node[r], r) if node[r][r][3] == 1: if node[r ^ 1][3] == 1: node[3] = 1 node[r ^ 1][3] = node[r][3] = 0 flag = False else: node = self._rotate(node, r ^ 1) return node, flag def delete(self, x): node = self.root stack = [] while node[2] != x: r = node[2] < x stack.append((node, r)) node = node[r] # Not Found if node[2] == self.EOT: return # Node has 2 children: swap min node in right tree and delete min node if node[0][2] != self.EOT and node[1][2] != self.EOT: stack.append((node, 1)) min_node = self.get_min(node[1], stack) node[2] = min_node[2] node = min_node # Delete node is root if not stack: if node[0][2] == self.EOT: self.root = node[1] else: self.root = node[0] self.root[3] = 0 return # Decrease count for parent, _ in stack: parent[4] -= 1 # Node has 0/1 child parent, r = stack[-1] if node[0][2] == self.EOT: parent[r] = node[1] node[1][3] = 0 # Balance is only needed if both children are EOT and self color is black. if node[1][2] != self.EOT or node[3] == 1: return elif node[1][2] == self.EOT: parent[r] = node[0] node[0][3] = 0 return # Validate tree and rotate if needed. while stack: parent, r = stack.pop() node, flag = self._balance_delete(parent, r) if stack and flag == True: parent, r = stack.pop() parent[r] = node break else: # If the root might have changed, update its color. self.root = node self.root[3] = 0 def get_min(self, node, stack): while node[0][2] != self.EOT: stack.append((node, 0)) node = node[0] return node def _balance_delete(self, node, r): if node[r ^ 1][r][3] == 0 and node[r ^ 1][r ^ 1][3] == 0: if node[r ^ 1][3] == 0: node[r ^ 1][3] = 1 if node[3] == 0: return node, False node[3] = 0 else: node = self._rotate(node, r) node[r], _ = self._balance_delete(node[r], r) else: if node[r ^ 1][r][3] == 1: node[r ^ 1] = self._rotate(node[r ^ 1], r ^ 1) node = self._rotate(node, r) node[r][3] = 0 node[r ^ 1][3] = 0 return node, True def upper_bound(self, x): """ :return Smallest y satisfying x < y and its leftmost index. If not exists, (EOT, length) will be returned. """ node = self.root y = self.EOT i = node[4] j = 0 while node[2] != self.EOT: if x < node[2]: y = node[2] i = j + node[0][4] node = node[0] else: j += node[0][4] + 1 node = node[1] return y, i def lower_bound(self, x): """ :return Smallest y satisfying x <= y and its leftmost index. If not exists, (EOT, length) will be returned. """ node = self.root y = self.EOT i = node[4] j = 0 while node[2] != self.EOT: if x <= node[2]: y = node[2] i = j + node[0][4] node = node[0] else: j += node[0][4] + 1 node = node[1] return y, i def get_by_index(self, i): """ :return (0-indexed) i-th smallest item. If i is greater than length, EOT will be returned. """ node = self.root if node[4] <= i: return self.EOT j = i while node[2] != self.EOT: left_count = node[0][4] if left_count == j: return node[2] elif left_count > j: node = node[0] else: j -= left_count + 1 node = node[1] def debug_print(self): self._debug_print(self.root, 0) def _debug_print(self, node, depth): if node[2] != self.EOT: self._debug_print(node[0], depth + 1) print(' ' * depth, 'BR'[node[3]], node[2], node[4]) self._debug_print(node[1], depth + 1) def check(): rbt = RedBlackTree() for x in [0, 1, 2, 3, 4, 5, 6, 3, 3, 3, 5, 5, 5, 0, 0, 0]: print('------ INSERT', x, '------') rbt.insert(x) rbt.debug_print() for x in [0, 1, 2, 3, 4, 5, 6, 7]: print('------ UPPER', x, '=', rbt.upper_bound(x)) print('------ LOWER', x, '=', rbt.lower_bound(x)) for x in [5, 0, 3, 3, 0, 5]: print('------ DELETE', x, '------') rbt.delete(x) rbt.debug_print() for i in range(11): print('------ INDEX', i, '=', rbt.get_by_index(i)) if __name__ == '__main__': check()