Numba Library
競技プログラミングのNumba用スニペットを、作成するたびに追記していく。
基本的には素のPythonとあまり変わらず、せいぜいListをnumpy.ndarrayに置きかえただけのものがほとんどだが、 普通に実装するとNumbaでは使えない機能を踏んでしまうものも一部あり、それを避けた実装としてNumba用にまとまってた方が嬉しい。
- classで作ると事前コンパイルしにくい(できない?)ので、複数の関数を用意する形にする
- クロージャとして実装(
@njit
する大枠の関数の中に記述)すれば、各関数を逐一@njit
する必要は無い - Numbaのクロージャは再帰ができないので、再帰を用いた実装はなるべく避ける
- どうしても再帰で書かざるを得ない or 再帰の方がアレンジしやすい場合は、単独の関数としてコンパイルする
- その場合は
@njit(型名)
を並記する
クラス変数の代替方法
ちょっと長くなるが、Numba用に移植する際の問題点の1つへの対処法に関する考察を書いておく。
Numbaでは、事前コンパイルをする場合はクラスが使えないのが悩ましい。
おかげでインスタンス変数、つまりは状態を持てないので、常に外部から注入する必要がある。
使う際に何を注入するか意識しなければならないし、管理したい変数が増えると記述もどんどん冗長になる。
1 2 3 4 5 6 7 8 9 10 11 12 |
class UnionFind: def __init( self , n): self .table = [ - 1 ] * n # ←インスタンス変数 def unite( self , a, b): # ...略 # 使う際にはtableは隠蔽され、中でどう使われているかなんて気にしなくてよい uft = UnionFind( 10 ) uft.unite( 1 , 5 ) uft.unite( 2 , 6 ) |
1 2 3 4 5 6 7 8 9 10 11 12 13 |
@njit def main(): def unite(table, a, b): # 略 def find(table, a, b): # 略 # 使う際は、外部でtableを定義し、毎回連れ回す必要が生じる table = [ - 1 ] * 10 unite(table, 1 , 5 ) unite(table, 2 , 6 ) |
また、もう1つの問題点として、Numbaは内部関数も含め、nested(入れ子)なリスト等を引数に取れない。Numpyの多次元配列ならOKだが、それでは表現できないものもある。
(上手くいくこともある? 条件調査中)
1 2 3 4 5 6 7 |
@njit def main(): def something_function(nested_list, a, b): # コンパイル時エラー nested_list = [[ 0 ] for _ in range ( 10 )] something_function(nested_list, 1 , 5 ) |
Numba関数の中で、nestedなリストを作ることはできる。また、同じ関数内にクロージャ関数を作れば、クロージャ関数から関数外のリストを参照することができる。
これを用いて、以下のようにすれば、something_function
の中で nested_list
が使える。
1 2 3 4 5 6 7 8 9 |
@njit def main(): nested_list = [[ 0 ] for _ in range ( 10 )] # 関数より先にnonlocalな変数を定義 def something_function(a, b): nested_list[a][b] = 1 # 変数を使う something_function( 1 , 5 ) |
しかし、それでは関数と状態が1対1で結びついてしまい、クラスにおける「複数のインスタンスを作る」ようなことができない。
あまり綺麗ではないが、無理矢理解決するとしたら、以下のようになるだろうか。
関数外部にはインスタンス別のリストを記録する NESTED_LIST
を用意し、init()
ではそこに初期化したリストを加える(これが1つのインスタンス変数となる)。
init()
は自身のIDを返すので、以降、something_function()
など他の関数を呼ぶ際は、そのIDのみを連れ回す。
これなら、管理したい変数が増えても使う側で管理するのはIDのみで済み、極力、内部実装を意識しないで使える。
ちなみに、init()
では、入れ子リストの中身が空だと型推定ができず、コンパイルエラーとなってしまう(9~14行目)。
ちょっと奇妙だが、入れる予定の型が分かるような書き方で lst
を定義し、それを空にした後コピーするようにすると上手くいく。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
@njit def main(): NESTED_LIST = [] def init(n): _id = len (NESTED_LIST) # × コンパイルエラー # NESTED_LIST.append([[] for _ in range(n)]) lst = [ 0 ] lst.clear() NESTED_LIST.append([lst.copy() for _ in range (n)]) return _id def something_function(_id, a, b): nested_list = NESTED_LIST[_id] nested_list[a][b] = 1 id1 = something_init( 10 ) id2 = something_init( 20 ) something_function(id1, 1 , 5 ) something_function(id2, 2 , 6 ) |
あくまでコンパイルが通るというだけで、もっといい書き方があるなら使いたい。
実装例
bit count
2進数表記で'1
'の立っている数。
1 2 3 4 5 6 7 |
def bit_count(x): x = (x & 0x55555555 ) + ((x >> 1 ) & 0x55555555 ) x = (x & 0x33333333 ) + ((x >> 2 ) & 0x33333333 ) x = (x & 0x0F0F0F0F ) + ((x >> 4 ) & 0x0F0F0F0F ) x = (x & 0x00FF00FF ) + ((x >> 8 ) & 0x00FF00FF ) x = (x & 0x0000FFFF ) + ((x >> 16 ) & 0x0000FFFF ) return x |
bit length
2進数表記の桁数(0は0)
1 2 3 4 5 6 |
def bit_length(n): ret = 0 while n > 0 : n >> = 1 ret + = 1 return ret |
なお、n が1以上かを確認するのに while n:
としても素のPythonは通るが、Numbaではバージョンにより n をbool値だと推定してコンパイルしてしまう。
そうなると、引数にどんな正整数を渡しても n=1 となってしまい、おかしくなる。ちゃんと int 型であることがわかるような書き方をする。
mod累乗
xa を MOD
で割った剰余。pythonなら pow(x, a, MOD)
で求められるが、Numbaでは第3引数が未対応。二分累乗法で実装。
1 2 3 4 5 6 7 8 9 |
def mod_pow(x, a, MOD): ret = 1 cur = x while a > 0 : if a & 1 : ret = ret * cur % MOD cur = cur * cur % MOD a >> = 1 return ret |
mod階乗と逆数の事前計算
0!~N! とそのモジュラ逆数を計算。上記の mod_pow
を使用。
前提として、n<MOD かつ MOD は素数。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 |
def mod_pow(x, a, MOD): ret = 1 cur = x while a > 0 : if a & 1 : ret = ret * cur % MOD cur = cur * cur % MOD a >> = 1 return ret def precompute_factorials(n, MOD): factorials = np.ones(n + 1 , dtype = np.int64) for m in range ( 2 , n + 1 ): factorials[m] = factorials[m - 1 ] * m % MOD inversions = np.ones(n + 1 , dtype = np.int64) inversions[n] = mod_pow(factorials[n], MOD - 2 , MOD) for m in range (n, 2 , - 1 ): inversions[m - 1 ] = inversions[m] * m % MOD return factorials, inversions |
Union-Find
union_find_init()
で table
配列を生成し、返値をインスタンス番号と見なして各関数に与える。
table
の根における値は自グループのサイズを表し、結合の際はサイズが小さい方を大きい方の子とする。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 |
UNIONFIND_TABLE = [] def unionfind_init(n): UNIONFIND_TABLE.append(np.full(n, - 1 , dtype = np.int64)) return len (UNIONFIND_TABLE) - 1 def unionfind_getroot(ins, x): table = UNIONFIND_TABLE[ins] stack = [] while table[x] > = 0 : stack.append(x) x = table[x] for y in stack: table[y] = x return x def unionfind_unite(ins, x, y): table = UNIONFIND_TABLE[ins] r1 = unionfind_getroot(ins, x) r2 = unionfind_getroot(ins, y) if r1 = = r2: return d1 = table[r1] d2 = table[r2] if d1 < = d2: table[r2] = r1 table[r1] + = d2 else : table[r1] = r2 table[r2] + = d1 def unionfind_find(ins, x, y): return unionfind_getroot(ins, x) = = unionfind_getroot(ins, y) def unionfind_getsize(ins, x): table = UNIONFIND_TABLE[ins] return - table[unionfind_getroot(ins, x)] |
Binary Indexed Tree (Fenwick Tree)
fenwick_init()
に要素数を与えて初期化し、返値をインスタンス番号と見なして各関数に与える。
i は 1~N の値を取るものとする(0始まりではない)。
lower_boundは、累積和が x 以上になる最小の i を返す。
使わない場合は、FENWICK_LOGN
および fenwick_init
内でのそれを求める処理は不要(残しても大した計算量ではないが)。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 |
FENWICK_TREE = [] FENWICK_LOGN = [] def fenwick_init(n): log_n = 0 m = n while m: log_n + = 1 m >> = 1 FENWICK_TREE.append(np.zeros(n + 1 , dtype = np.int64)) FENWICK_LOGN.append(log_n) return len (FENWICK_TREE) - 1 def fenwick_add(ins, i, x): arr = FENWICK_TREE[ins] n = arr.size - 1 while i < = n: arr[i] + = x i + = i & - i def fenwick_sum(ins, i): arr = FENWICK_TREE[ins] result = 0 while i > 0 : result + = arr[i] i ^ = i & - i return result def fenwick_lower_bound(ins, x): arr = FENWICK_TREE[ins] log_n = FENWICK_LOGN[ins] n = arr.size - 1 sum_ = 0 pos = 0 for i in range (log_n, - 1 , - 1 ): k = pos + ( 1 << i) if k < n and sum_ + arr[k] < x: sum_ + = arr[k] pos + = 1 << i return pos + 1 |
外部注入できる Fenwick Tree
単位元と演算を外部から注入する版。
ただし、型があまり自由すぎると扱いきれないので、単位元 identity_element
の型は np.int64
型固定とし、演算関数 func
も「np.int64型の引数を2つとって、1つ返す関数」固定とする。
func
は、add, min, xor
など operatorモジュールにあるものはそのまま使えるし、自分で定義したものでもよい。
最大流(Dinic法)
辺に容量 cape が決められた有向グラフで、頂点 s から t に流せる最大流量を求める。
二部グラフのマッチングにも使える。
頂点番号は 0~N−1。
基本は、dinic_init
で初期化→dinic_add_links
でグラフ生成→dinic_maximum_flow
で最大流量を計算。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
DINIC_LINKS = [] def dinic_init(n): lst = [[ 0 ]] lst.clear() DINIC_LINKS.append([lst.copy() for _ in range (n)]) return len (DINIC_LINKS) - 1 def dinic_add_link(ins, frm, to, cap): links = DINIC_LINKS[ins] links[frm].append([to, cap, len (links[to])]) links[to].append([frm, 0 , len (links[frm]) - 1 ]) def dinic_bfs(ins, n, s): links = DINIC_LINKS[ins] depth = np.full(n, - 1 , dtype = np.int64) depth[s] = 0 deq = np.zeros(n + 5 , dtype = np.int64) dl, dr = 0 , 1 deq[ 0 ] = s while dl < dr: v = deq[dl] dl + = 1 for link in links[v]: if link[ 1 ] > 0 and depth[link[ 0 ]] = = - 1 : depth[link[ 0 ]] = depth[v] + 1 deq[dr] = link[ 0 ] dr + = 1 return depth def dinic_dfs(ins, depth, progress, s, t): links = DINIC_LINKS[ins] stack = [(s, 10 * * 18 )] flow = 0 while stack: v, f = stack.pop() if v = = t: flow = f continue if flow = = 0 : i = progress[v] if i = = len (links[v]): continue progress[v] + = 1 stack.append((v, f)) to, cap, rev = links[v][i] if cap = = 0 or depth[v] > = depth[to]: continue stack.append((to, min (f, cap))) else : i = progress[v] - 1 link = links[v][i] link[ 1 ] - = flow links[link[ 0 ]][link[ 2 ]][ 1 ] + = flow return flow def dinic_maximum_flow(ins, n, s, t): flow = 0 while True : depth = dinic_bfs(ins, n, s) if depth[t] = = - 1 : return flow progress = np.zeros(n, dtype = np.int64) path_flow = dinic_dfs(ins, depth, progress, s, t) while path_flow ! = 0 : flow + = path_flow path_flow = dinic_dfs(ins, depth, progress, s, t) |
最小費用流
辺に容量 cape と、1単位を流したときのコスト coste が決められた有向グラフで、頂点 s から t に、流量 Q を流した時の最小コストを求める。
頂点番号は 0~N−1。
基本的な使い方は、mincostflow_init
で初期化→mincostflow_add_links
でグラフ生成→mincostflow_flow
で最小費用を計算。
- 実装に当たっての参考
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
from heapq import heappop, heappush MINCOSTFLOW_LINKS = [] INF = 10 * * 10 def mincostflow_init(n): """ n: 頂点数 """ lst = [[ 0 ]] lst.clear() MINCOSTFLOW_LINKS.append([lst.copy() for _ in range (n)]) return len (MINCOSTFLOW_LINKS) - 1 def mincostflow_add_link(ins, frm, to, capacity, cost): """ インスタンスID, 辺始点頂点番号, 辺終点頂点番号, 容量, コスト """ links = MINCOSTFLOW_LINKS[ins] links[frm].append([to, capacity, cost, len (links[to])]) links[to].append([frm, 0 , - cost, len (links[frm]) - 1 ]) def mincostflow_flow(ins, s, t, quantity): """ インスタンスID, フロー始点頂点番号, フロー終点頂点番号, 要求流量 """ links = MINCOSTFLOW_LINKS[ins] n = len (links) res = 0 potentials = np.zeros(n, dtype = np.int64) dist = np.full(n, INF, dtype = np.int64) prev_v = np.full(n, - 1 , dtype = np.int64) prev_e = np.full(n, - 1 , dtype = np.int64) while quantity: dist.fill(INF) dist[s] = 0 que = [( 0 , s)] while que: total_cost, v = heappop(que) if dist[v] < total_cost: continue for i, (u, cap, cost, _) in enumerate (links[v]): new_cost = dist[v] + potentials[v] - potentials[u] + cost if cap > 0 and new_cost < dist[u]: dist[u] = new_cost prev_v[u] = v prev_e[u] = i heappush(que, (new_cost, u)) # Cannot flow quantity if dist[t] = = INF: return - 1 potentials + = dist cur_flow = quantity v = t while v ! = s: cur_flow = min (cur_flow, links[prev_v[v]][prev_e[v]][ 1 ]) v = prev_v[v] quantity - = cur_flow res + = cur_flow * potentials[t] v = t while v ! = s: link = links[prev_v[v]][prev_e[v]] link[ 1 ] - = cur_flow links[v][link[ 3 ]][ 1 ] + = cur_flow v = prev_v[v] return res |
留意点
同じNumpy配列同士を演算するとエラー
AtCoderで使われて いる いた過去の Numba 0.48.0 では、同じNumPy配列同士を演算すると(?)エラーになる。(※詳細な条件はちゃんと調べてない)
0.53では修正されているのを確認している。
一方を別の名前の変数で定義してやると大丈夫になる。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
# エラー(配列 x の各値をそれぞれ a 乗するコード) @njit def mod_pow(x, a): ret = np.ones_like(x) cur = x while a > 0 : if a & 1 : ret = ret * cur % MOD cur = cur * cur % MOD # ←エラー a >> = 1 return ret # おっけー def mod_pow(x, a): ret = np.ones_like(x) cur = x while a > 0 : if a & 1 : ret = ret * cur % MOD cur_ = cur cur = cur * cur_ % MOD a >> = 1 return ret |
整数が0か0以外かの判定はちゃんと書く
Numba 0.57.0 で確認。
Pythonでは、整数が0か0以外かの判定に「if a:
」「while a:
」などと書いても解釈してくれるが、Numbaではbool値として解釈されてしまうことがある。
その場合、コンパイルされた関数中の a は全てbool値となるので、正整数を渡しても強制的に 1 になるなど、おかしくなる。(詳細な条件は不明)
1 2 3 4 5 6 7 8 9 10 11 12 13 |
@njit def main(): def mod_pow(x, a): ret = 1 cur = x while a: # ← ここの記述からか、a は関数全体を通してbool値として扱われる if a & 1 : ret = ret * cur cur = cur * cur a >> = 1 return ret print (mod_pow( 10 , 5 )) # 10^1 として渡ってしまい、10 が返る |
「while a > 0:
」など、明示的にint型であることが分かるような書き方をする必要がある。