[[赤黒木]]

赤黒木

概要

平衡探索二分木の一種で、以下を満たすように木を保つ。

  • 各ノードを赤と黒で(仮想的に)塗り分ける
  • 根は黒
  • 赤の子は黒(赤は2つ以上連続しない)
  • 根からそれぞれの葉までの経路上にある黒の個数は等しい

こんなことを決めて何がいいのかと、

  • これが満たされている木では、根から最短経路(黒黒黒……)と最長経路(黒赤黒赤……)の差は2倍を超えないので、ほどよい平衡が保たれることが保証される
  • データを追加したり削除したりするとき、親子で繋がった直近3~4ノードの色のパターンを見て塗り替えることで状態を維持できる
    • その中だけで状態を解消できないときは親に遡ることはあるが、再帰的に実装できる

ただ、パターン分けは結構種類があってややこしい。 詳細は参考の1つめのサイトが丁寧に場合分けを説明されている。

実装

再帰関数は極力無くした実装(1箇所だけ、多重再帰にはならない箇所で使用)。どうしても配列アクセスは多くなるので、Pythonだと遅い。PyPyなら……速いとは言えないが、まぁまぁ。

関数返値概要
insert(x)None$x$ を追加(実装によっては重複する値を許さないこともできる。以下の実装では許す)
delete(x)None$x$ を削除(無ければ何もしない)
upper_bound(x)$y,i$$x \lt y$ を満たす最小の $y$ と、その小さい方からのindex $i$ (0-indexed)を取得
lower_bound(x)$y,i$$x \le y$ を満たす最小の $y$ と、その小さい方からのindex $i$ (0-indexed)を取得
get_by_index(i)$x$小さい方から $i$ 番目の値を取得(0-indexed)
debug_print()None木を90°倒して左が根、上が小さい方となるように、“色,値,部分木以下の要素数”を描画

"""
赤黒木
"""


class RedBlackTree:
    # For performance, nodes are implemented as lists.
    #   [left, right, data, color, count]
    #   color: 1:RED  0:BLACK
    #
    # It is useful to set the left/right index to 0/1 for boolean access.
    # The end nodes are BLACK and of which value is EOT (End Of Tree).
    # 
    # Usage example:
    # rbt = RedBlackTree()
    # rbt.insert(x)
    # rbt.delete(x)
    # rbt.upper_bound(x)
    # rbt.lower_bound(x)
    # rbt.get_by_index(i)
    # rbt.debug_print()
    

    def __init__(self, EOT=-1):
        self.EOT = EOT
        self.root = self._leaf()

    def _leaf(self):
        return [None, None, self.EOT, 0, 0]

    def _rotate(self, node, r):
        child = node[r ^ 1]
        node[r ^ 1] = child[r]
        child[r] = node
        child[3] = node[3]
        node[3] = 1
        child_count = child[4]
        child[4] = node[4]
        node[4] -= child_count - node[r ^ 1][4]
        return child

    def insert(self, x):
        stack = []
        node = self.root
        while node[2] != self.EOT:
            # Variation: If the same value is not allowed, return when found.

            to_right = x >= node[2]
            stack.append((node, to_right))
            node = node[to_right]

        # Insert values to the end node
        node[0] = self._leaf()
        node[1] = self._leaf()
        node[2] = x
        node[3] = 1
        node[4] = 1

        # Increase count
        for parent, _ in stack:
            parent[4] += 1

        # Validate tree and rotate if needed.
        while stack:
            parent, r = stack.pop()
            if parent[3] == 1:
                node = parent
                continue
            parent[r] = node
            node, flag = self._balance_insert(parent, r)

            if stack and flag == True:
                parent, r = stack.pop()
                parent[r] = node
                break

        else:
            # If the root might have changed, update its color.
            self.root = node
            self.root[3] = 0

    def _balance_insert(self, node, r):
        flag = True
        if node[r][r ^ 1][3] == 1:
            node[r] = self._rotate(node[r], r)
        if node[r][r][3] == 1:
            if node[r ^ 1][3] == 1:
                node[3] = 1
                node[r ^ 1][3] = node[r][3] = 0
                flag = False
            else:
                node = self._rotate(node, r ^ 1)
        return node, flag

    def delete(self, x):
        node = self.root
        stack = []
        while node[2] != x:
            r = node[2] < x
            stack.append((node, r))
            node = node[r]

        # Not Found
        if node[2] == self.EOT:
            return

        # Node has 2 children: swap min node in right tree and delete min node
        if node[0][2] != self.EOT and node[1][2] != self.EOT:
            stack.append((node, 1))
            min_node = self.get_min(node[1], stack)
            node[2] = min_node[2]
            node = min_node

        # Delete node is root
        if not stack:
            if node[0][2] == self.EOT:
                self.root = node[1]
            else:
                self.root = node[0]
            self.root[3] = 0
            return

        # Decrease count
        for parent, _ in stack:
            parent[4] -= 1

        # Node has 0/1 child
        parent, r = stack[-1]
        if node[0][2] == self.EOT:
            parent[r] = node[1]
            node[1][3] = 0
            # Balance is only needed if both children are EOT and self color is black.
            if node[1][2] != self.EOT or node[3] == 1:
                return
        elif node[1][2] == self.EOT:
            parent[r] = node[0]
            node[0][3] = 0
            return

        # Validate tree and rotate if needed.
        while stack:
            parent, r = stack.pop()
            node, flag = self._balance_delete(parent, r)

            if stack and flag == True:
                parent, r = stack.pop()
                parent[r] = node
                break
        else:
            # If the root might have changed, update its color.
            self.root = node
            self.root[3] = 0

    def get_min(self, node, stack):
        while node[0][2] != self.EOT:
            stack.append((node, 0))
            node = node[0]
        return node

    def _balance_delete(self, node, r):
        if node[r ^ 1][r][3] == 0 and node[r ^ 1][r ^ 1][3] == 0:
            if node[r ^ 1][3] == 0:
                node[r ^ 1][3] = 1
                if node[3] == 0:
                    return node, False
                node[3] = 0
            else:
                node = self._rotate(node, r)
                node[r], _ = self._balance_delete(node[r], r)
        else:
            if node[r ^ 1][r][3] == 1:
                node[r ^ 1] = self._rotate(node[r ^ 1], r ^ 1)
            node = self._rotate(node, r)
            node[r][3] = 0
            node[r ^ 1][3] = 0
        return node, True

    def upper_bound(self, x):
        """
        :return Smallest y satisfying x < y and its leftmost index.
                If not exists, (EOT, length) will be returned.
        """
        node = self.root
        y = self.EOT
        i = node[4]
        j = 0
        while node[2] != self.EOT:
            if x < node[2]:
                y = node[2]
                i = j + node[0][4]
                node = node[0]
            else:
                j += node[0][4] + 1
                node = node[1]
        return y, i

    def lower_bound(self, x):
        """
        :return Smallest y satisfying x <= y and its leftmost index.
                If not exists, (EOT, length) will be returned.
        """
        node = self.root
        y = self.EOT
        i = node[4]
        j = 0
        while node[2] != self.EOT:
            if x <= node[2]:
                y = node[2]
                i = j + node[0][4]
                node = node[0]
            else:
                j += node[0][4] + 1
                node = node[1]
        return y, i

    def get_by_index(self, i):
        """
        :return (0-indexed) i-th smallest item.
                If i is greater than length, EOT will be returned.
        """
        node = self.root
        if node[4] <= i:
            return self.EOT
        j = i
        while node[2] != self.EOT:
            left_count = node[0][4]
            if left_count == j:
                return node[2]
            elif left_count > j:
                node = node[0]
            else:
                j -= left_count + 1
                node = node[1]

    def debug_print(self):
        self._debug_print(self.root, 0)

    def _debug_print(self, node, depth):
        if node[2] != self.EOT:
            self._debug_print(node[0], depth + 1)
            print('      ' * depth, 'BR'[node[3]], node[2], node[4])
            self._debug_print(node[1], depth + 1)


def check():
    rbt = RedBlackTree()
    for x in [0, 1, 2, 3, 4, 5, 6, 3, 3, 3, 5, 5, 5, 0, 0, 0]:
        print('------ INSERT', x, '------')
        rbt.insert(x)
        rbt.debug_print()

    for x in [0, 1, 2, 3, 4, 5, 6, 7]:
        print('------ UPPER', x, '=', rbt.upper_bound(x))
        print('------ LOWER', x, '=', rbt.lower_bound(x))

    for x in [5, 0, 3, 3, 0, 5]:
        print('------ DELETE', x, '------')
        rbt.delete(x)
        rbt.debug_print()

    for i in range(11):
        print('------ INDEX', i, '=', rbt.get_by_index(i))


if __name__ == '__main__':
    check()


本WebサイトはcookieをPHPのセッション識別および左欄目次の開閉状況記憶のために使用しています。同意できる方のみご覧ください。More information about cookies
programming_algorithm/data_structure/balancing_binary_search_tree/redblacktree.txt · 最終更新: 2019/11/28 by ikatakos
CC Attribution 4.0 International
Driven by DokuWiki Recent changes RSS feed Valid CSS Valid XHTML 1.0