Numba Library
競技プログラミングのNumba用スニペットを、作成するたびに追記していく。
基本的には素のPythonとあまり変わらず、せいぜいListをnumpy.ndarrayに置きかえただけのものがほとんどだが、 普通に実装するとNumbaでは使えない機能を踏んでしまうものも一部あり、それを避けた実装としてNumba用にまとまってた方が嬉しい。
- classで作ると事前コンパイルしにくい(できない?)ので、複数の関数を用意する形にする
- クロージャとして実装(
@njitする大枠の関数の中に記述)すれば、各関数を逐一@njitする必要は無い - Numbaのクロージャは再帰ができないので、再帰を用いた実装はなるべく避ける
- どうしても再帰で書かざるを得ない or 再帰の方がアレンジしやすい場合は、単独の関数としてコンパイルする
- その場合は
@njit(型名)を並記する
クラス変数の代替方法
ちょっと長くなるが、Numba用に移植する際の問題点の1つへの対処法に関する考察を書いておく。
Numbaでは、事前コンパイルをする場合はクラスが使えないのが悩ましい。
おかげでインスタンス変数、つまりは状態を持てないので、常に外部から注入する必要がある。
使う際に何を注入するか意識しなければならないし、管理したい変数が増えると記述もどんどん冗長になる。
class UnionFind:
def __init(self, n):
self.table = [-1] * n # ←インスタンス変数
def unite(self, a, b):
# ...略
# 使う際にはtableは隠蔽され、中でどう使われているかなんて気にしなくてよい
uft = UnionFind(10)
uft.unite(1, 5)
uft.unite(2, 6)
@njit
def main():
def unite(table, a, b):
# 略
def find(table, a, b):
# 略
# 使う際は、外部でtableを定義し、毎回連れ回す必要が生じる
table = [-1] * 10
unite(table, 1, 5)
unite(table, 2, 6)
また、もう1つの問題点として、Numbaは内部関数も含め、nested(入れ子)なリスト等を引数に取れない。Numpyの多次元配列ならOKだが、それでは表現できないものもある。
(上手くいくこともある? 条件調査中)
@njit
def main():
def something_function(nested_list, a, b):
# コンパイル時エラー
nested_list = [[0] for _ in range(10)]
something_function(nested_list, 1, 5)
Numba関数の中で、nestedなリストを作ることはできる。また、同じ関数内にクロージャ関数を作れば、クロージャ関数から関数外のリストを参照することができる。
これを用いて、以下のようにすれば、something_function の中で nested_list が使える。
@njit
def main():
nested_list = [[0] for _ in range(10)] # 関数より先にnonlocalな変数を定義
def something_function(a, b):
nested_list[a][b] = 1 # 変数を使う
something_function(1, 5)
しかし、それでは関数と状態が1対1で結びついてしまい、クラスにおける「複数のインスタンスを作る」ようなことができない。
あまり綺麗ではないが、無理矢理解決するとしたら、以下のようになるだろうか。
関数外部にはインスタンス別のリストを記録する NESTED_LIST を用意し、init() ではそこに初期化したリストを加える(これが1つのインスタンス変数となる)。
init() は自身のIDを返すので、以降、something_function() など他の関数を呼ぶ際は、そのIDのみを連れ回す。
これなら、管理したい変数が増えても使う側で管理するのはIDのみで済み、極力、内部実装を意識しないで使える。
ちなみに、init()では、入れ子リストの中身が空だと型推定ができず、コンパイルエラーとなってしまう(9~14行目)。
ちょっと奇妙だが、入れる予定の型が分かるような書き方で lst を定義し、それを空にした後コピーするようにすると上手くいく。
@njit
def main():
NESTED_LIST = []
def init(n):
_id = len(NESTED_LIST)
# × コンパイルエラー
# NESTED_LIST.append([[] for _ in range(n)])
lst = [0]
lst.clear()
NESTED_LIST.append([lst.copy() for _ in range(n)])
return _id
def something_function(_id, a, b):
nested_list = NESTED_LIST[_id]
nested_list[a][b] = 1
id1 = something_init(10)
id2 = something_init(20)
something_function(id1, 1, 5)
something_function(id2, 2, 6)
あくまでコンパイルが通るというだけで、もっといい書き方があるなら使いたい。
実装例
bit count
2進数表記で'1'の立っている数。
def bit_count(x):
x = (x & 0x55555555) + ((x >> 1) & 0x55555555)
x = (x & 0x33333333) + ((x >> 2) & 0x33333333)
x = (x & 0x0F0F0F0F) + ((x >> 4) & 0x0F0F0F0F)
x = (x & 0x00FF00FF) + ((x >> 8) & 0x00FF00FF)
x = (x & 0x0000FFFF) + ((x >> 16) & 0x0000FFFF)
return x
bit length
2進数表記の桁数(0は0)
def bit_length(n):
ret = 0
while n > 0:
n >>= 1
ret += 1
return ret
なお、$n$ が1以上かを確認するのに while n: としても素のPythonは通るが、Numbaではバージョンにより $n$ をbool値だと推定してコンパイルしてしまう。
そうなると、引数にどんな正整数を渡しても $n=1$ となってしまい、おかしくなる。ちゃんと int 型であることがわかるような書き方をする。
mod累乗
$x^a$ を MOD で割った剰余。pythonなら pow(x, a, MOD) で求められるが、Numbaでは第3引数が未対応。二分累乗法で実装。
def mod_pow(x, a, MOD):
ret = 1
cur = x
while a > 0:
if a & 1:
ret = ret * cur % MOD
cur = cur * cur % MOD
a >>= 1
return ret
mod階乗と逆数の事前計算
$0!~N!$ とそのモジュラ逆数を計算。上記の mod_pow を使用。
前提として、$n \lt MOD$ かつ $MOD$ は素数。
def mod_pow(x, a, MOD):
ret = 1
cur = x
while a > 0:
if a & 1:
ret = ret * cur % MOD
cur = cur * cur % MOD
a >>= 1
return ret
def precompute_factorials(n, MOD):
factorials = np.ones(n + 1, dtype=np.int64)
for m in range(2, n + 1):
factorials[m] = factorials[m - 1] * m % MOD
inversions = np.ones(n + 1, dtype=np.int64)
inversions[n] = mod_pow(factorials[n], MOD - 2, MOD)
for m in range(n, 2, -1):
inversions[m - 1] = inversions[m] * m % MOD
return factorials, inversions
Union-Find
union_find_init() で table 配列を生成し、返値をインスタンス番号と見なして各関数に与える。
table の根における値は自グループのサイズを表し、結合の際はサイズが小さい方を大きい方の子とする。
UNIONFIND_TABLE = []
def unionfind_init(n):
UNIONFIND_TABLE.append(np.full(n, -1, dtype=np.int64))
return len(UNIONFIND_TABLE) - 1
def unionfind_getroot(ins, x):
table = UNIONFIND_TABLE[ins]
stack = []
while table[x] >= 0:
stack.append(x)
x = table[x]
for y in stack:
table[y] = x
return x
def unionfind_unite(ins, x, y):
table = UNIONFIND_TABLE[ins]
r1 = unionfind_getroot(ins, x)
r2 = unionfind_getroot(ins, y)
if r1 == r2:
return
d1 = table[r1]
d2 = table[r2]
if d1 <= d2:
table[r2] = r1
table[r1] += d2
else:
table[r1] = r2
table[r2] += d1
def unionfind_find(ins, x, y):
return unionfind_getroot(ins, x) == unionfind_getroot(ins, y)
def unionfind_getsize(ins, x):
table = UNIONFIND_TABLE[ins]
return -table[unionfind_getroot(ins, x)]
Binary Indexed Tree (Fenwick Tree)
fenwick_init() に要素数を与えて初期化し、返値をインスタンス番号と見なして各関数に与える。
$i$ は $1~N$ の値を取るものとする(0始まりではない)。
lower_boundは、累積和が $x$ 以上になる最小の $i$ を返す。
使わない場合は、FENWICK_LOGN および fenwick_init 内でのそれを求める処理は不要(残しても大した計算量ではないが)。
FENWICK_TREE = []
FENWICK_LOGN = []
def fenwick_init(n):
log_n = 0
m = n
while m:
log_n += 1
m >>= 1
FENWICK_TREE.append(np.zeros(n + 1, dtype=np.int64))
FENWICK_LOGN.append(log_n)
return len(FENWICK_TREE) - 1
def fenwick_add(ins, i, x):
arr = FENWICK_TREE[ins]
n = arr.size - 1
while i <= n:
arr[i] += x
i += i & -i
def fenwick_sum(ins, i):
arr = FENWICK_TREE[ins]
result = 0
while i > 0:
result += arr[i]
i ^= i & -i
return result
def fenwick_lower_bound(ins, x):
arr = FENWICK_TREE[ins]
log_n = FENWICK_LOGN[ins]
n = arr.size - 1
sum_ = 0
pos = 0
for i in range(log_n, -1, -1):
k = pos + (1 << i)
if k < n and sum_ + arr[k] < x:
sum_ += arr[k]
pos += 1 << i
return pos + 1
外部注入できる Fenwick Tree
単位元と演算を外部から注入する版。
ただし、型があまり自由すぎると扱いきれないので、単位元 identity_element の型は np.int64 型固定とし、演算関数 func も「np.int64型の引数を2つとって、1つ返す関数」固定とする。
func は、add, min, xor など operatorモジュールにあるものはそのまま使えるし、自分で定義したものでもよい。
最大流(Dinic法)
辺に容量 $cap_e$ が決められた有向グラフで、頂点 $s$ から $t$ に流せる最大流量を求める。
二部グラフのマッチングにも使える。
頂点番号は $0~N-1$。
基本は、dinic_init で初期化→dinic_add_links でグラフ生成→dinic_maximum_flow で最大流量を計算。
DINIC_LINKS = []
def dinic_init(n):
lst = [[0]]
lst.clear()
DINIC_LINKS.append([lst.copy() for _ in range(n)])
return len(DINIC_LINKS) - 1
def dinic_add_link(ins, frm, to, cap):
links = DINIC_LINKS[ins]
links[frm].append([to, cap, len(links[to])])
links[to].append([frm, 0, len(links[frm]) - 1])
def dinic_bfs(ins, n, s):
links = DINIC_LINKS[ins]
depth = np.full(n, -1, dtype=np.int64)
depth[s] = 0
deq = np.zeros(n + 5, dtype=np.int64)
dl, dr = 0, 1
deq[0] = s
while dl < dr:
v = deq[dl]
dl += 1
for link in links[v]:
if link[1] > 0 and depth[link[0]] == -1:
depth[link[0]] = depth[v] + 1
deq[dr] = link[0]
dr += 1
return depth
def dinic_dfs(ins, depth, progress, s, t):
links = DINIC_LINKS[ins]
stack = [(s, 10 ** 18)]
flow = 0
while stack:
v, f = stack.pop()
if v == t:
flow = f
continue
if flow == 0:
i = progress[v]
if i == len(links[v]):
continue
progress[v] += 1
stack.append((v, f))
to, cap, rev = links[v][i]
if cap == 0 or depth[v] >= depth[to]:
continue
stack.append((to, min(f, cap)))
else:
i = progress[v] - 1
link = links[v][i]
link[1] -= flow
links[link[0]][link[2]][1] += flow
return flow
def dinic_maximum_flow(ins, n, s, t):
flow = 0
while True:
depth = dinic_bfs(ins, n, s)
if depth[t] == -1:
return flow
progress = np.zeros(n, dtype=np.int64)
path_flow = dinic_dfs(ins, depth, progress, s, t)
while path_flow != 0:
flow += path_flow
path_flow = dinic_dfs(ins, depth, progress, s, t)
最小費用流
辺に容量 $cap_e$ と、1単位を流したときのコスト $cost_e$ が決められた有向グラフで、頂点 $s$ から $t$ に、流量 $Q$ を流した時の最小コストを求める。
頂点番号は $0~N-1$。
基本的な使い方は、mincostflow_init で初期化→mincostflow_add_links でグラフ生成→mincostflow_flow で最小費用を計算。
- 実装に当たっての参考
from heapq import heappop, heappush
MINCOSTFLOW_LINKS = []
INF = 10 ** 10
def mincostflow_init(n):
""" n: 頂点数 """
lst = [[0]]
lst.clear()
MINCOSTFLOW_LINKS.append([lst.copy() for _ in range(n)])
return len(MINCOSTFLOW_LINKS) - 1
def mincostflow_add_link(ins, frm, to, capacity, cost):
""" インスタンスID, 辺始点頂点番号, 辺終点頂点番号, 容量, コスト """
links = MINCOSTFLOW_LINKS[ins]
links[frm].append([to, capacity, cost, len(links[to])])
links[to].append([frm, 0, -cost, len(links[frm]) - 1])
def mincostflow_flow(ins, s, t, quantity):
""" インスタンスID, フロー始点頂点番号, フロー終点頂点番号, 要求流量 """
links = MINCOSTFLOW_LINKS[ins]
n = len(links)
res = 0
potentials = np.zeros(n, dtype=np.int64)
dist = np.full(n, INF, dtype=np.int64)
prev_v = np.full(n, -1, dtype=np.int64)
prev_e = np.full(n, -1, dtype=np.int64)
while quantity:
dist.fill(INF)
dist[s] = 0
que = [(0, s)]
while que:
total_cost, v = heappop(que)
if dist[v] < total_cost:
continue
for i, (u, cap, cost, _) in enumerate(links[v]):
new_cost = dist[v] + potentials[v] - potentials[u] + cost
if cap > 0 and new_cost < dist[u]:
dist[u] = new_cost
prev_v[u] = v
prev_e[u] = i
heappush(que, (new_cost, u))
# Cannot flow quantity
if dist[t] == INF:
return -1
potentials += dist
cur_flow = quantity
v = t
while v != s:
cur_flow = min(cur_flow, links[prev_v[v]][prev_e[v]][1])
v = prev_v[v]
quantity -= cur_flow
res += cur_flow * potentials[t]
v = t
while v != s:
link = links[prev_v[v]][prev_e[v]]
link[1] -= cur_flow
links[v][link[3]][1] += cur_flow
v = prev_v[v]
return res
留意点
同じNumpy配列同士を演算するとエラー
AtCoderで使われて いる いた過去の Numba 0.48.0 では、同じNumPy配列同士を演算すると(?)エラーになる。(※詳細な条件はちゃんと調べてない)
0.53では修正されているのを確認している。
一方を別の名前の変数で定義してやると大丈夫になる。
# エラー(配列 x の各値をそれぞれ a 乗するコード)
@njit
def mod_pow(x, a):
ret = np.ones_like(x)
cur = x
while a > 0:
if a & 1:
ret = ret * cur % MOD
cur = cur * cur % MOD # ←エラー
a >>= 1
return ret
# おっけー
def mod_pow(x, a):
ret = np.ones_like(x)
cur = x
while a > 0:
if a & 1:
ret = ret * cur % MOD
cur_ = cur
cur = cur * cur_ % MOD
a >>= 1
return ret
整数が0か0以外かの判定はちゃんと書く
Numba 0.57.0 で確認。
Pythonでは、整数が0か0以外かの判定に「if a:」「while a:」などと書いても解釈してくれるが、Numbaではbool値として解釈されてしまうことがある。
その場合、コンパイルされた関数中の $a$ は全てbool値となるので、正整数を渡しても強制的に 1 になるなど、おかしくなる。(詳細な条件は不明)
@njit
def main():
def mod_pow(x, a):
ret = 1
cur = x
while a: # ← ここの記述からか、a は関数全体を通してbool値として扱われる
if a & 1:
ret = ret * cur
cur = cur * cur
a >>= 1
return ret
print(mod_pow(10, 5)) # 10^1 として渡ってしまい、10 が返る
「while a > 0:」など、明示的にint型であることが分かるような書き方をする必要がある。

