std::setを使わない代替テクニック
競技プログラミングでは、C++ に std::set という(事実上)平衡二分探索木ライブラリがあるから、たまにその使用を前提とした問題が出る。
std::set, multiset は、集合を順序を持って管理するデータ構造であり、以下のことができる(set - cpprefjp C++日本語リファレンス より一部抜粋)
機能 | 概略 | 計算量 |
insert | 集合に要素を追加 | $O(\log{N})$ |
erase | 集合から要素を削除 | $O(\log{N})$ |
find | 値の検索 | $O(\log{N})$ |
upper_bound
lower_bound | ある値より大きい最小の要素の検索
ある値以上の最小の要素の検索 | $O(\log{N})$ |
Pythonでは標準ライブラリにないし、実装してもリストアクセスの遅さからTLEに直面することがままある。
いや、実際には、別に使わなくても何とか別の方法がある場合が多いが、std::setが「使えば一発」なのに対し、問題依存の考察や実装の手間を要するのが常である。
(素直にC++
で書けって?そりゃそうだ)
std::setを使いたいけど使えない……という場合に、代わりになりうる方法をメモ。
2023/08 の言語アップデートにより、AtCoderでは、sortedcontainers ライブラリが入るようになった。
これが上記の機能を持っており、十分高速に動作するため、ライブラリにない問題は解消されたといっていい。
ただし、他のコンテストサイトでは基本的に入っていないので、依然として以下の方法が必要になる場合もある。
代替方法一覧
問題に出くわすごとに列挙していき、ある程度蓄積されたら分類方法を考える。
Binary Indexed Tree
汎用性の高い置き換え。
できること
値の追加・削除 $O(\log{N})$
$k$ 番目の値を求める $O(\log{N})$
$x$ が存在するか検索 $O(\log{N})$
$x$ より小さい・大きい直近の値を検索 $O(\log{N})$
Binary Indexed Tree で管理できる。ただし、BIT上で累積和の二分探索を実装する必要がある。
詳細はBinary Indexed Tree(Fenwick Tree)参照。
以下のように置きかえる。
値 $x$ を追加 → $bit.add(x, 1)$
値 $x$ を削除 → $bit.add(x, -1)$
$k$ 番目の値を取得 → $bit.lower\_bound(k)$(累積和が $k$ 以上になる最小のindexを取得)
$x$ の検索 → $bit.sum(x)-bit.sum(x-1) \gt 0$
$x$ 以下の最大の値を取得 $bit.lower\_bound(bit.sum(x))$
$x$ 以上の最小の値を取得 $bit.lower\_bound(bit.sum(x - 1) + 1)$
$x$ 以下の最大の値が存在しない場合、$1$ が返る。
$x$ 以上の最小の値が存在しない場合、$N+1$ が返る。
(※BITを1-indexedで実装していた場合。lower_boundの実装内容にも依るが)
そのような操作が発生しうる場合、indexをずらして、$i=1$ は空白要素として $i=2$ から本来の値を入れる、などで本当の値と区別できる。
例
累積和 std::set
初期状態 1 2 3 4 5 6 7 8
[0 0 0 0 0 0 0 0] {}
5を追加
[0 0 0 0 1 1 1 1] {5}
3を追加
[0 0 1 1 2 2 2 2] {3, 5}
小さい方から2番目を取得
[0 0 1 1 2 2 2 2] {3, 5}
^ 累積和で"2"が始まるiは「5」
2,3,7を追加
[0 1 3 3 4 4 5 5] {2, 3, 3, 5, 7}
4以下の最大の値を取得
[0 1 3 3 4 4 5 5] {2, 3, 3, 5, 7}
~ 累積和で 4 の箇所は"3"
^ 累積和で "3" が始まるiは「3」
6以上の最小の値を取得
[0 1 3 3 4 4 5 5] {2, 3, 3, 5, 7}
~ 累積和で 6-1=5 の箇所は"4"
^ 累積和で 4+1=5 が始まるiは「7」
7を削除
[0 1 3 3 4 4 4 4] {2, 3, 3, 5}
6以上の最小の値を取得
[0 1 3 3 4 4 4 4] {2, 3, 3, 5}
^ 存在しない場合、N+1 が返る点には注意
バリエーション
値の範囲がこれを越える場合
値の取り得る種類数が $N=10^6$ 程度で、かつ先読みできる場合、前もって圧縮して $1~N$ に対応づければよい。
前後の値を取得
BIT上の二分探索では、累積和が $k$ 以上になる要素そのものは分かっても、前後の値はたどれない。
この方法では前後の値を求めたい場合は $k \pm 1$ で累積和の二分探索を改めて行うしかない(と思う)。
ただし、小さい方の値を求める際、以下の方法で少しだけ省略できる場合がある。
累積和の探索では、その過程で、累積和が $k$ 以上になる最小の要素を $x$ として「$x$ 未満の要素の個数」も求めることになる。
(これを $L[k]$ と表す)
優先度付きキュー
2つの優先度付きキューを用いることで代替できる。
Binary Indexed Treeでは実現できない、$10^9$ など配列を作るのが厳しい範囲のクエリや、整数値以外にも対応できる。
答えは常にlowerの先頭に位置する($x$ とする)。新しい値 $y$ が来たとき、
バリエーション
$k$ 番目でなく中央値を求めよ
lowerとupperの要素数が同じになるように保ちながら追加していく。
$k$ 番目の値を削除せよ
削除されるのが $k$ 番目の値に限定されるなら、つねに優先度キューの先頭にあるので可能。
任意の値を削除せよ
優先度キューとは別に、「値が生き残っているか」を管理するデータを用意する。
要素に重複が無い場合
要素に重複がある場合
lowerで生き残っている要素数を管理する変数 $lc$(lowerの個数が $k$ 未満になる可能性のある場合)
lower, upperで生き残っている要素数を管理する変数 $lc, uc$(中央値を求める場合)
値 $y$ を削除するとき、lowerの先頭 $x$ と比較して、
popするとき、それが生き残っている値かどうかを確認し、削除済みなら無かったこととして続けてpopする。
配列・特殊アイデア系①
一概には言えないけど、だいたい以下のような問題に適用できうる。
Binary Indexed Tree でもいいんだけど、より高速で、なかなかアイデアが綺麗な解法。
これを使って解ける問題(ネタバレ注意)
std::set解(想定解)
std::setであれば、値の小さい順に「要素のindex」をinsertしていく方法が考えられる。
ある要素 $A_i$ を挿入しようとしたとき、既にset内にあるのは自身より小さな要素のindexに限られる。
そこから $i$ 未満で最も大きい数を検索すれば、それが $A_i$ にとって、自身より小さな左で直近の要素の位置である。
代替手法
各要素の暫定解を表す配列を用意して、値の大きい順に処理して更新していく。
ある要素の処理が終わったら、後から処理される要素のために、自身をスキップさせるよう更新する。
横一列に繋がったグラフの辺を繋ぎなおすと考えるとわかりやすい。
【初期状態】
i 1 2 3 4 5 6
③--①--④--⑥--⑤--②
L 0 1 2 3 4 5 自身より左で自身より小さい要素の i (暫定)
R 2 3 4 5 6 7 自身より右で自身より小さい要素の i (暫定)
はじめはすぐ左・すぐ右の位置を入れておく
【最も大きい要素⑥】
処理された時点のL,Rの値がそのまま答えとなる。
最初はどの要素も自分より小さいので、すぐ左・すぐ右のindexとなるのは当然。
i 1 2 3 4 5 6
③--①--④--⑥--⑤--②
L 0 1 2 [3] 4 5
R 2 3 4 [5] 6 7
次の処理のために、⑥を抜いて、④と⑤を直結させたい。
i 1 2 3 4 5 6
③--①--④------⑤--②
L 0 1 2 _ 5
R 2 3 _ 6 7
そのためには、以下のように更新すればよい
L[R[i]] = L[i]
R[L[i]] = R[i]
今回に当てはめると、
L[5] = 3
L[3] = 5
i 1 2 3 4 5 6
③--①--④------⑤--②
L 0 1 2 [3] 5
R 2 3 [5] 6 7
【2番目に大きい要素⑤】
同様に、この時点での L[i]=3, R[i]=6 が⑤にとっての答え
i 1 2 3 4 5 6
③--①--④------⑤--②
L 0 1 2 [3] 5
R 2 3 5 [6] 7
同様に更新すると、ちゃんと④と②がつながる
i 1 2 3 4 5 6
③--①--④----------②
L 0 1 2 [3]
R 2 3 [6] 7
これを繰り返すと、$O(N)$ で求められる。std::setを使うと $O(N\log{N})$ なので、それより高速である。
バリエーション
順列で無い場合(同じ要素は含まれない)
値が飛び飛びなら、先読みできるなら圧縮すればよい。先読みできない場合は……わからん。
同じ要素が含まれる場合
同じ値が含まれると、同率の処理が上手くいかない。
一応、自身の値 “以下” の直近の位置を求めればいいなら、方法がある。
自身の値 “未満” なら……どうやればいいだろうね。
配列・特殊アイデア系②
①と似ているが、クエリ2で同じ $i$ について何回も聞かれる点が異なる。
(①は既に処理が終わった $i$ は更新しないので、その後に書かれている値は正しくないものになっていく)
std::set解
最初、setに $1~N$ の値を全て入れておく。
クエリ1に対しては、setから $i$ をdeleteする。
クエリ2に対しては、lower_boundが使える。
または、setに入れる値を値が残っている区間 $[L,R)$ にして、取り除く操作を区間の分割で表現することもできる。
代替手法
取り除くのみで、加えることはしない場合に限られる。
「整数が存在するか」と
「存在しなかった場合に確認する次の整数」を管理する長さNの配列を用意
i 1 2 3 4 5 6 7 8 9
[ 1 1 1 1 1 1 1 1 1 ] 0:存在しない 1:する
[ 2 3 4 5 6 7 8 9 x ] 次の整数
クエリ2: i=7 に対する答え
[ 1 1 1 1 1 1 [1] 1 1 ] 7が存在するのでそのまま7が答え
[ 2 3 4 5 6 7 8 9 x ]
クエリ1: i=4,5 を取り除く
[ 1 1 1 [0 0] 1 1 1 1 ] ←該当位置を0にする
[ 2 3 4 5 6 7 8 9 x ] ←クエリ1では特に何もせずそのまま
クエリ2: i=4 に対する答え
[ 1 1 1 [0] 0 1 1 1 1 ] i=4がないので
[ 2 3 4 [5] 6 7 8 9 x ] 次に確認すべき5を見に行く
[ 1 1 1 0 [0] 1 1 1 1 ] i=5もないので
[ 2 3 4 5 [6] 7 8 9 x ] 次の6を見に行く
[ 1 1 1 0 0 [1] 1 1 1 ] i=6はあったので、6が答え
[ 2 3 4 5 6 7 8 9 x ]
後処理
[ 1 1 1 0 0 1 1 1 1 ] 辿ってきたiについて、
[ 2 3 4 [6 6] 7 8 9 x ] 存在が確認できた整数に更新しておく
クエリ1: i=6 を取り除く
[ 1 1 1 0 0 [0] 1 1 1 ] ←該当位置を0にする
[ 2 3 4 6 6 7 8 9 x ]
クエリ2: i=4 に対する答え
[ 1 1 1 [0] 0 [0][1] 1 1 ] 同様に4→6→7と辿って、
[ 2 3 4 [6] 6 [7] 8 9 x ] 7があったので7が答え
後処理
[ 1 1 1 0 0 0 1 1 1 ]
[ 2 3 4 [7] 6 [7] 8 9 x ] 更新
一度辿った値についてはスキップできるので、償却するとクエリ2を1回あたり $O(\log{N})$ で行えるようになる。
バリエーション
値が復活する
ピボット木
上記で紹介されている平衡二分探索木。結構速いみたい。
メリットとして、入れる値が整数なら上限値が $10^9$ などのように大きくても座標圧縮が必要ない。
(BIT木のように最初に配列を作るのでは無く、挿入ごとにノードインスタンスを作るので)
アイデアとしては、各ノードは、自身の位置に応じたpivot値を持つ。
たとえば値 $x$ を挿入するとき「上から挿入箇所を辿っていって、pivot値が $x$ と等しいノードに来たら $x$ はそこでFIXし、代わりにそこに入っていた元の値の挿入箇所をその地点から探し始める」ようなことを行う。
これにより、高さが最初に決めた $K$ 以上にならないことが保証されるとともに、他の平衡二分探索木で発生する「回転操作」、つまりは左右の子を繋ぎ替えるなど、値の書き換えを減らしている。
バリエーション
実数値を載せる
ピボット値を小数範囲に拡張して、$0.5, 0.25, 0.125, ...$ 単位まで探索できるようにする。
ただし、探索深さの上限が保証されなくなる。(まぁ、極端に狭い範囲に多くの値が集中しない限り実用的に問題となることは少ないはず)
複数の整数のTupleを載せる
bitshiftして1整数にまとめてやる。
値の上限が大きくなると計算量も増えるが、例えば上限 $10^6$ の整数を3つ管理しても約 $2^{60}$、$K=60$ なので、そこまで無理なものではない。
同じ値を持つ
辞書などで各値の個数を管理して、削除時に個数が“0”にならない限りノードは削除しない、みたいな拡張でいけそう。
$k$ 番目を求める
ノードに自身の部分木以下のノードの個数を持たせておく。
同じ値を持て、$k$ 番目も求められるようにした
class BalancingPivotTree:
"""
元とさせていただいたアイデア・コード
https://qiita.com/Kiri8128/items/6256f8559f0026485d90
・個数を管理し、多重集合を管理できるようにした
・K番目を取得できるようにした
Features:
append(v)
delete(v)
get_kth(k)
find_l(v)
find_r(v)
min
max
"""
def __init__(self, n):
self.N = n
self.root = self.node(1 << n, 1 << n)
self.count = {1 << n: 1}
def append(self, v):
""" v (0 <= v <= 2^n-2) を追加 """
v += 1
if v in self.count:
self.count[v] += 1
else:
self.count[v] = 1
inc = 1
nd = self.root
while True:
nd.subtree_count += inc
if v == nd.value:
# v がすでに存在する場合に何か処理が必要ならここに書く
return 0
else:
mi, ma = min(v, nd.value), max(v, nd.value)
if mi < nd.pivot:
if nd.value != ma:
inc = self.count[mi]
nd.value = ma
if nd.left:
nd = nd.left
v = mi
else:
p = nd.pivot
nd.left = self.node(mi, p - (p & -p) // 2)
nd.left.subtree_count = self.count[mi]
break
else:
if nd.value != mi:
inc = self.count[ma]
nd.value = mi
if nd.right:
nd = nd.right
v = ma
else:
p = nd.pivot
nd.right = self.node(ma, p + (p & -p) // 2)
nd.right.subtree_count = self.count[ma]
break
def leftmost(self, nd):
if nd.left: return self.leftmost(nd.left)
return nd
def rightmost(self, nd):
if nd.right: return self.rightmost(nd.right)
return nd
def find_l(self, v):
""" vより真に小さいやつの中での最大値(なければ-1) """
v += 1
nd = self.root
prev = 0
if nd.value < v: prev = nd.value
while True:
if v <= nd.value:
if nd.left:
nd = nd.left
else:
return prev - 1
else:
prev = nd.value
if nd.right:
nd = nd.right
else:
return prev - 1
def find_r(self, v):
""" vより真に大きいやつの中での最小値(なければRoot) """
v += 1
nd = self.root
prev = 0
if nd.value > v: prev = nd.value
while True:
if v < nd.value:
prev = nd.value
if nd.left:
nd = nd.left
else:
return prev - 1
else:
if nd.right:
nd = nd.right
else:
return prev - 1
@property
def max(self):
return self.find_l((1 << self.N) - 1)
@property
def min(self):
return self.find_r(-1)
def delete(self, v, nd=None, prev=None, dec=1):
""" 値がvの要素を1個削除(なければ何もしない) """
v += 1
needs_delete = True
if nd is None:
if v not in self.count:
return
elif self.count[v] == 1:
del self.count[v]
else:
self.count[v] -= 1
needs_delete = False
nd = self.root
if prev is None:
prev = nd
while v != nd.value:
prev = nd
if v <= nd.value:
if nd.left:
nd.subtree_count -= dec
nd = nd.left
else:
#####
return
else:
if nd.right:
nd.subtree_count -= dec
nd = nd.right
else:
#####
return
nd.subtree_count -= dec
if not needs_delete:
return
if (not nd.left) and (not nd.right):
if not prev.left:
prev.right = None
elif not prev.right:
prev.left = None
else:
if nd.pivot == prev.left.pivot:
prev.left = None
else:
prev.right = None
elif nd.right:
# print("type A", v)
nd.value = self.leftmost(nd.right).value
self.delete(nd.value - 1, nd.right, nd, self.count[nd.value])
else:
# print("type B", v)
nd.value = self.rightmost(nd.left).value
self.delete(nd.value - 1, nd.left, nd, self.count[nd.value])
def get_kth(self, k: int):
"""
k番目を取得。現存する要素数より大きい数を指定すると-1
:param k:
:return:
"""
nd = self.root
if nd.subtree_count - 1 < k:
return -1
while True:
cnt = self.count[nd.value]
if nd.left is None:
if k <= cnt:
return nd.value - 1
assert nd.right is not None
k -= cnt
nd = nd.right
else:
if nd.left.subtree_count >= k:
nd = nd.left
elif nd.left.subtree_count + cnt >= k:
return nd.value - 1
else:
assert nd.right is not None
k -= nd.left.subtree_count + cnt
nd = nd.right
def __contains__(self, v: int) -> bool:
return v + 1 in self.count
class node:
def __init__(self, v, p):
self.value = v
self.pivot = p
self.left = None
self.right = None
self.subtree_count = 1
def __repr__(self):
lch = self.left.value - 1 if self.left else -1
rch = self.right.value - 1 if self.right else -1
return f'({self.value - 1}, {self.pivot}, {lch}, {rch}, {self.subtree_count})'
def debug(self):
def debug_node(nd):
re = []
if nd.left:
re += debug_node(nd.left)
if nd.value: re.append(str(nd))
if nd.right:
re += debug_node(nd.right)
return re
print("Debug - root =", self.root.value - 1, debug_node(self.root)[:50])
print('Debug ', self.count)
def debug_list(self):
def debug_node(nd):
re = []
if nd.left:
re += debug_node(nd.left)
if nd.value:
re.extend([nd.value - 1] * self.count[nd.value])
if nd.right:
re += debug_node(nd.right)
return re
return debug_node(self.root)[:-1]
def debug_count(self, nd=None):
if nd is None:
nd = self.root
lch = nd.left.subtree_count if nd.left is not None else 0
rch = nd.right.subtree_count if nd.right is not None else 0
if nd.subtree_count != lch + rch + self.count[nd.value]:
print('NG!!', nd, lch, rch, self.count[nd.value])
print(self.debug_list())
print(self.count)
if lch != 0:
self.debug_count(nd.left)
if rch != 0:
self.debug_count(nd.right)
bpt = BalancingPivotTree(5) # 0-30の要素を入れられる
bpt.append(3)
bpt.append(20)
bpt.append(5)
bpt.append(10)
bpt.append(5)
bpt.append(13)
bpt.append(20)
bpt.append(3)
print(bpt.debug_list())
assert (3 in bpt) == True
assert (4 in bpt) == False
assert (5 in bpt) == True
assert bpt.find_l(12) == 10
assert bpt.find_l(13) == 10
assert bpt.find_l(14) == 13
assert bpt.find_r(3) == 5
assert bpt.find_r(4) == 5
assert bpt.find_r(5) == 10
assert bpt.find_r(6) == 10
assert bpt.min == 3
assert bpt.max == 20
assert bpt.get_kth(1) == 3
assert bpt.get_kth(2) == 3
assert bpt.get_kth(3) == 5
assert bpt.get_kth(4) == 5
assert bpt.get_kth(5) == 10
assert bpt.get_kth(6) == 13
assert bpt.get_kth(7) == 20
assert bpt.get_kth(8) == 20
assert bpt.get_kth(9) == -1
bpt.delete(20)
print(bpt.debug_list())
bpt.delete(3)
print(bpt.debug_list())
bpt.delete(10)
print(bpt.debug_list())
bpt.delete(20)
print(bpt.debug_list())
bpt.delete(3)
print(bpt.debug_list())
bpt.delete(5)
print(bpt.debug_list())
assert bpt.find_l(5) == -1
assert bpt.find_l(6) == 5
assert bpt.find_r(12) == 13
assert bpt.find_r(13) == 31
assert bpt.min == 5
assert bpt.max == 13
assert bpt.get_kth(1) == 5
assert bpt.get_kth(2) == 13
assert bpt.get_kth(3) == -1
print()
# 愚直チェック
from random import randrange
from bisect import insort
bpt = BalancingPivotTree(6) # 0 ~ 62 までの要素を入れられるピボット木
S = []
for _ in range(10000):
a = randrange(63)
if randrange(2) == 0:
print(f'append {a}')
bpt.append(a)
insort(S, a)
else:
print(f'delete {a}')
bpt.delete(a)
if a in S:
S.remove(a)
if bpt.debug_list() != S:
print('NG!! Arrays are not same.')
print('BT:', bpt.debug_list())
print('LS:', S)
elif len(S) > 0:
k = randrange(len(S))
bpt_k = bpt.get_kth(k + 1)
if bpt_k != S[k]:
print(f'NG!! k({k + 1})th item is wrong.')
print(f'BT: {bpt_k} vs LS: {S[k]}')
bpt.debug_count()
print("END")
平方分割で管理
そもそも単なるリスト(配列)でも、ソートされていれば検索は二分探索を使えば $O(\log{N})$ でいける。
計算量的に問題となるのは、ソート状態を保ちつつの要素の挿入や削除。要素を1つずつずらすのに $O(N)$ かかる。
[ 2 3 5 9 ] 4 を挿入
↘ ↘
[ 2 3 4 5 9 ] 入ってる要素数、または後ろにある要素数だけ値のコピーが発生
これを、「$\sqrt{N}$ 個のリストに $\sqrt{N}$ 個の要素が入った状態」として管理することで、
挿入・削除を平均 $O(\sqrt{N})$ にしてしまおうというアイデア。
どのリストに入っているかの検索に新たなコストが発生するものの、そこそこ高速に動作する。
挿入・削除を繰り返すと徐々に要素数にばらつきが出てくるので、定期的に再構築する。
またデータ型が整数値など単純なものであれば、listを使うよりarray.arrayを使った方が高速に動作する?
メリットとして、巨大な値、小数値、タプル、その他も大小が定義できれば、そのまま入れられる。