[[赤黒木]]

赤黒木

概要

平衡探索二分木の一種で、以下を満たすように木を保つ。

  • 各ノードを赤と黒で(仮想的に)塗り分ける
  • 根は黒
  • 赤の子は黒(赤は2つ以上連続しない)
  • 根からそれぞれの葉までの経路上にある黒の個数は等しい

こんなことを決めて何がいいのかと、

  • これが満たされている木では、根から最短経路(黒黒黒……)と最長経路(黒赤黒赤……)の差は2倍を超えないので、ほどよい平衡が保たれることが保証される
  • データを追加したり削除したりするとき、親子で繋がった直近3~4ノードの色のパターンを見て塗り替えることで状態を維持できる
    • その中だけで状態を解消できないときは親に遡ることはあるが、再帰的に実装できる

ただ、パターン分けは結構種類があってややこしい。 詳細は参考の1つめのサイトが丁寧に場合分けを説明されている。

実装

再帰関数は極力無くした実装(1箇所だけ、多重再帰にはならない箇所で使用)。どうしても配列アクセスは多くなるので、Pythonだと遅い。PyPyなら……速いとは言えないが、まぁまぁ。

関数返値概要
insert(x)None$x$ を追加(実装によっては重複する値を許さないこともできる。以下の実装では許す)
delete(x)None$x$ を削除(無ければ何もしない)
upper_bound(x)$y,i$$x \lt y$ を満たす最小の $y$ と、その小さい方からのindex $i$ (0-indexed)を取得
lower_bound(x)$y,i$$x \le y$ を満たす最小の $y$ と、その小さい方からのindex $i$ (0-indexed)を取得
get_by_index(i)$x$小さい方から $i$ 番目の値を取得(0-indexed)
debug_print()None木を90°倒して左が根、上が小さい方となるように、“色,値,部分木以下の要素数”を描画

"""
赤黒木
"""


class RedBlackTree:
    # For performance, nodes are implemented as lists.
    #   [left, right, data, color, count]
    #   color: 1:RED  0:BLACK
    #
    # It is useful to set the left/right index to 0/1 for boolean access.
    # The end nodes are BLACK and of which value is EOT (End Of Tree).
    # 
    # Usage example:
    # rbt = RedBlackTree()
    # rbt.insert(x)
    # rbt.delete(x)
    # rbt.upper_bound(x)
    # rbt.lower_bound(x)
    # rbt.get_by_index(i)
    # rbt.debug_print()
    

    def __init__(self, EOT=-1):
        self.EOT = EOT
        self.root = self._leaf()

    def _leaf(self):
        return [None, None, self.EOT, 0, 0]

    def _rotate(self, node, r):
        child = node[r ^ 1]
        node[r ^ 1] = child[r]
        child[r] = node
        child[3] = node[3]
        node[3] = 1
        child_count = child[4]
        child[4] = node[4]
        node[4] -= child_count - node[r ^ 1][4]
        return child

    def insert(self, x):
        stack = []
        node = self.root
        while node[2] != self.EOT:
            # Variation: If the same value is not allowed, return when found.

            to_right = x >= node[2]
            stack.append((node, to_right))
            node = node[to_right]

        # Insert values to the end node
        node[0] = self._leaf()
        node[1] = self._leaf()
        node[2] = x
        node[3] = 1
        node[4] = 1

        # Increase count
        for parent, _ in stack:
            parent[4] += 1

        # Validate tree and rotate if needed.
        while stack:
            parent, r = stack.pop()
            if parent[3] == 1:
                node = parent
                continue
            parent[r] = node
            node, flag = self._balance_insert(parent, r)

            if stack and flag == True:
                parent, r = stack.pop()
                parent[r] = node
                break

        else:
            # If the root might have changed, update its color.
            self.root = node
            self.root[3] = 0

    def _balance_insert(self, node, r):
        flag = True
        if node[r][r ^ 1][3] == 1:
            node[r] = self._rotate(node[r], r)
        if node[r][r][3] == 1:
            if node[r ^ 1][3] == 1:
                node[3] = 1
                node[r ^ 1][3] = node[r][3] = 0
                flag = False
            else:
                node = self._rotate(node, r ^ 1)
        return node, flag

    def delete(self, x):
        node = self.root
        stack = []
        while node[2] != x:
            r = node[2] < x
            stack.append((node, r))
            node = node[r]

        # Not Found
        if node[2] == self.EOT:
            return

        # Node has 2 children: swap min node in right tree and delete min node
        if node[0][2] != self.EOT and node[1][2] != self.EOT:
            stack.append((node, 1))
            min_node = self.get_min(node[1], stack)
            node[2] = min_node[2]
            node = min_node

        # Delete node is root
        if not stack:
            if node[0][2] == self.EOT:
                self.root = node[1]
            else:
                self.root = node[0]
            self.root[3] = 0
            return

        # Decrease count
        for parent, _ in stack:
            parent[4] -= 1

        # Node has 0/1 child
        parent, r = stack[-1]
        if node[0][2] == self.EOT:
            parent[r] = node[1]
            node[1][3] = 0
            # Balance is only needed if both children are EOT and self color is black.
            if node[1][2] != self.EOT or node[3] == 1:
                return
        elif node[1][2] == self.EOT:
            parent[r] = node[0]
            node[0][3] = 0
            return

        # Validate tree and rotate if needed.
        while stack:
            parent, r = stack.pop()
            node, flag = self._balance_delete(parent, r)

            if stack and flag == True:
                parent, r = stack.pop()
                parent[r] = node
                break
        else:
            # If the root might have changed, update its color.
            self.root = node
            self.root[3] = 0

    def get_min(self, node, stack):
        while node[0][2] != self.EOT:
            stack.append((node, 0))
            node = node[0]
        return node

    def _balance_delete(self, node, r):
        if node[r ^ 1][r][3] == 0 and node[r ^ 1][r ^ 1][3] == 0:
            if node[r ^ 1][3] == 0:
                node[r ^ 1][3] = 1
                if node[3] == 0:
                    return node, False
                node[3] = 0
            else:
                node = self._rotate(node, r)
                node[r], _ = self._balance_delete(node[r], r)
        else:
            if node[r ^ 1][r][3] == 1:
                node[r ^ 1] = self._rotate(node[r ^ 1], r ^ 1)
            node = self._rotate(node, r)
            node[r][3] = 0
            node[r ^ 1][3] = 0
        return node, True

    def upper_bound(self, x):
        """
        :return Smallest y satisfying x < y and its leftmost index.
                If not exists, (EOT, length) will be returned.
        """
        node = self.root
        y = self.EOT
        i = node[4]
        j = 0
        while node[2] != self.EOT:
            if x < node[2]:
                y = node[2]
                i = j + node[0][4]
                node = node[0]
            else:
                j += node[0][4] + 1
                node = node[1]
        return y, i

    def lower_bound(self, x):
        """
        :return Smallest y satisfying x <= y and its leftmost index.
                If not exists, (EOT, length) will be returned.
        """
        node = self.root
        y = self.EOT
        i = node[4]
        j = 0
        while node[2] != self.EOT:
            if x <= node[2]:
                y = node[2]
                i = j + node[0][4]
                node = node[0]
            else:
                j += node[0][4] + 1
                node = node[1]
        return y, i

    def get_by_index(self, i):
        """
        :return (0-indexed) i-th smallest item.
                If i is greater than length, EOT will be returned.
        """
        node = self.root
        if node[4] <= i:
            return self.EOT
        j = i
        while node[2] != self.EOT:
            left_count = node[0][4]
            if left_count == j:
                return node[2]
            elif left_count > j:
                node = node[0]
            else:
                j -= left_count + 1
                node = node[1]

    def debug_print(self):
        self._debug_print(self.root, 0)

    def _debug_print(self, node, depth):
        if node[2] != self.EOT:
            self._debug_print(node[0], depth + 1)
            print('      ' * depth, 'BR'[node[3]], node[2], node[4])
            self._debug_print(node[1], depth + 1)


def check():
    rbt = RedBlackTree()
    for x in [0, 1, 2, 3, 4, 5, 6, 3, 3, 3, 5, 5, 5, 0, 0, 0]:
        print('------ INSERT', x, '------')
        rbt.insert(x)
        rbt.debug_print()

    for x in [0, 1, 2, 3, 4, 5, 6, 7]:
        print('------ UPPER', x, '=', rbt.upper_bound(x))
        print('------ LOWER', x, '=', rbt.lower_bound(x))

    for x in [5, 0, 3, 3, 0, 5]:
        print('------ DELETE', x, '------')
        rbt.delete(x)
        rbt.debug_print()

    for i in range(11):
        print('------ INDEX', i, '=', rbt.get_by_index(i))


if __name__ == '__main__':
    check()


programming_algorithm/data_structure/balancing_binary_search_tree/redblacktree.txt · 最終更新: 2019/11/28 by ikatakos
CC Attribution 4.0 International
Driven by DokuWiki Recent changes RSS feed Valid CSS Valid XHTML 1.0