セグメント木

区間に対する処理をするときによく使われる。$a_1,a_2,\ldots,a_n$の配列に対し、任意の区間$a_i$~$a_j$の合計値や最小値を求めたい時に用いる。

(これを発展させて「区間$a_1$~$a_i$の合計値」に限定したものが、Binary Indexed Treeで、柔軟性と引き替えにメモリや計算効率が向上している)

点に対する更新と、区間に対するクエリ

class SegTreeMin:
    """
    以下のクエリを処理する
    1.update:  i番目の値をxに更新する
    2.get_min: 区間[l, r)の最小値を得る
    """

    def __init__(self, n, INF):
        """
        :param n: 要素数
        :param INF: 初期値(入りうる要素より十分に大きな数)
        """
        self.n = n
        # nより大きい2の冪数
        n2 = 1
        while n2 < n:
            n2 <<= 1
        self.n2 = n2
        self.tree = [INF] * (n2 << 1)
        self.INF = INF

    def update(self, i, x):
        """
        i番目の値をxに更新
        :param i: index(0-indexed)
        :param x: update value
        """
        i += self.n2
        self.tree[i] = x
        while i > 1:
            self.tree[i >> 1] = x = min(x, self.tree[i ^ 1])
            i >>= 1

    def get_min(self, a, b):
        """
        [a, b)の最小値を得る
        :param a: index(0-indexed)
        :param b: index(0-indexed)
        """
        result = self.INF
        
        # (k, l, r)
        # k   : 現在調べている区間のtree内index
        # l, r: kが表す区間の左右端index [l, r)
        q = [(1, 0, self.n2)]
        
        while q:
            k, l, r = q.pop()
            
            if a <= l and r <= b:
                result = min(result, self.tree[k])
                continue
            
            m = (l + r) // 2
            k <<= 1
            if a < m and l < b:
                q.append((k, l, m))
            if a < r and l < m:
                q.append((k + 1, m, r))

        return result

再帰を使った書き方(Python, PyPyでは相対的に遅い)

各値をオブジェクトで持つ場合

class SegmentTree:
    """
    値をObjectで持ち、更新方法を外部から与えられるSegmentTree
     (汎用性と引き替えに、速度がやや犠牲になる)

    update(i, x): iをxで更新
    get(a,b):     [a, b)を取得
    i,a,b は 0-indexed
    """

    def __init__(self, n, init_func, merge_func):
        """
        :param n: 要素数
        :param init_func: init_func() で初期状態の新しいオブジェクトを返す関数
        :param merge_func: merge_func(x, y) でxとyを統合する関数(xを破壊更新する)
        """
        self.n = n
        n2 = 1  # nより大きい2の冪数
        while n2 < n:
            n2 <<= 1
        self.n2 = n2
        self.tree = [init_func() for _ in range(n2 << 1)]
        self.ini = init_func
        self.merge = merge_func

    def update(self, i, x):
        i += self.n2
        self.merge(self.tree[i], x)
        while i > 1:
            self.merge(x, self.tree[i ^ 1])
            i >>= 1
            self.merge(self.tree[i], x)

    def get(self, a, b):
        result = self.ini()
        q = [(1, 0, self.n2)]
        
        while q:
            k, l, r = q.pop()
 
            if a <= l and r <= b:
                self.merge(result, self.tree[k])
                continue
 
            m = (l + r) // 2
            k <<= 1
            if a < m and l < b:
                q.append((k, l, m))
            if a < r and l < m:
                q.append((k + 1, m, r))
 
        return result

区間に対する更新と、区間に対するクエリ

区間足し込みという手法が使える。

programming_algorithm/data_structure/segment_tree.txt · 最終更新: 2019/09/11 by ikatakos
CC Attribution 4.0 International
Driven by DokuWiki Recent changes RSS feed Valid CSS Valid XHTML 1.0