Numba Library

競技プログラミングのNumba用スニペットを、作成するたびに追記していく。

基本的には素のPythonとあまり変わらず、せいぜいListをnumpy.ndarrayに置きかえただけのものがほとんどだが、 普通に実装するとNumbaでは使えない機能を踏んでしまうものも一部あり、それを避けた実装としてNumba用にまとまってた方が嬉しい。

  • classで作ると事前コンパイルしにくい(できない?)ので、複数の関数を用意する形にする
  • クロージャとして実装(@njit する大枠の関数の中に記述)すれば、各関数を逐一 @njit する必要は無い
  • Numbaのクロージャは再帰ができないので、再帰を用いた実装はなるべく避ける
    • どうしても再帰で書かざるを得ない or 再帰の方がアレンジしやすい場合は、単独の関数としてコンパイルする
      • その場合は @njit(型名) を並記する

クラス変数の代替方法

Numbaでは、事前コンパイルをする場合はクラスが使えないのが悩ましい。
おかげでインスタンス変数、つまりは状態を持てないので、常に外部から注入する必要がある。

使う際に何を注入するか意識しなければならないし、管理したい変数が増えると記述もどんどん冗長になる。


class UnionFind:
    def __init(self, n):
        self.table = [-1] * n  # ←インスタンス変数
    
    def unite(self, a, b):
        # ...略

# 使う際にはtableは隠蔽され、中でどう使われているかなんて気にしなくてよい

uft = UnionFind(10)
uft.unite(1, 5)
uft.unite(2, 6)

@njit
def main():
    def unite(table, a, b):
        # 略
    
    def find(table, a, b):
        # 略
    
    # 使う際は、外部でtableを定義し、毎回連れ回す必要が生じる
    
    table = [-1] * 10
    unite(table, 1, 5)
    unite(table, 2, 6)

また、もう1つの問題点として、Numbaは内部関数も含め、nested(入れ子)なリスト等を引数に取れない。Numpyの多次元配列ならOKだが、それでは表現できないものもある。

(上手くいくこともある? 条件調査中)

@njit
def main():
    def something_function(nested_list, a, b):
        # コンパイル時エラー
    
    nested_list = [[0] for _ in range(10)]
    something_function(nested_list, 1, 5)

Numba関数の中で、nestedなリストを作ることはできる。また、同じ関数内にクロージャ関数を作れば、クロージャ関数から関数外のリストを参照することができる。

これを用いて、以下のようにすれば、something_function の中で nested_list が使える。

@njit
def main():

    nested_list = [[0] for _ in range(10)]  # 関数より先にnonlocalな変数を定義

    def something_function(a, b):
        nested_list[a][b] = 1               # 変数を使う
    
    something_function(1, 5)

しかし、それでは関数と状態が1対1で結びついてしまい、クラスにおける「複数のインスタンスを作る」ようなことができない。

あまり綺麗ではないが、無理矢理解決するとしたら、以下のようになるだろうか。

関数外部にはインスタンス別のリストを記録する NESTED_LIST を用意し、init() ではそこに初期化したリストを加える(これが1つのインスタンス変数となる)。
init() は自身のIDを返すので、以降、something_function() など他の関数を呼ぶ際は、そのIDのみを連れ回す。

これなら、管理したい変数が増えても使う側で管理するのはIDのみで済み、極力、内部実装を意識しないで使える。

ちなみに、init()では、入れ子リストの中身が空だと型推定ができず、コンパイルエラーとなってしまう(9~14行目)。
ちょっと奇妙だが、入れる予定の型が分かるような書き方で lst を定義し、それを空にした後コピーするようにすると上手くいく。

@njit
def main():

    NESTED_LIST = []
    
    def init(n):
        _id = len(NESTED_LIST)
        
        # × コンパイルエラー
        # NESTED_LIST.append([[] for _ in range(n)])
        
        lst = [0]
        lst.clear()
        NESTED_LIST.append([lst.copy() for _ in range(n)])
        return _id

    def something_function(_id, a, b):
        nested_list = NESTED_LIST[_id]
        nested_list[a][b] = 1
    
    id1 = something_init(10)
    id2 = something_init(20)
    
    something_function(id1, 1, 5)
    something_function(id2, 2, 6)

あくまでコンパイルが通るというだけで、もっといい書き方があるなら使いたい。

実装例

bit count

2進数表記で'1'の立っている数。


def bit_count(x):
    x = (x & 0x55555555) + ((x >>  1) & 0x55555555)
    x = (x & 0x33333333) + ((x >>  2) & 0x33333333)
    x = (x & 0x0F0F0F0F) + ((x >>  4) & 0x0F0F0F0F)
    x = (x & 0x00FF00FF) + ((x >>  8) & 0x00FF00FF)
    x = (x & 0x0000FFFF) + ((x >> 16) & 0x0000FFFF)
    return x

bit length

2進数表記の桁数(0は0)

    def bit_length(n):
        ret = 0
        while n:
            n >>= 1
            ret += 1
        return ret

mod累乗

$x^a$ を MOD で割った剰余。pythonなら pow(x, a, MOD) で求められるが、Numbaでは第3引数が未対応。二分累乗法で実装。


def mod_pow(x, a, MOD):
    ret = 1
    cur = x
    while a:
        if a & 1:
            ret = ret * cur % MOD
        cur = cur * cur % MOD
        a >>= 1
    return ret

mod階乗と逆数の事前計算

$0!~N!$ とそのモジュラ逆数を計算。上記の mod_pow を使用。
前提として、$n \lt MOD$ かつ $MOD$ は素数。


    def mod_pow(x, a, MOD):
        ret = 1
        cur = x
        while a:
            if a & 1:
                ret = ret * cur % MOD
            cur = cur * cur % MOD
            a >>= 1
        return ret

    def precompute_factorials(n, MOD):
        factorials = np.ones(n + 1, dtype=np.int64)
        for m in range(2, n + 1):
            factorials[m] = factorials[m - 1] * m % MOD
        inversions = np.ones(n + 1, dtype=np.int64)
        inversions[n] = mod_pow(factorials[n], MOD - 2, MOD)
        for m in range(n, 2, -1):
            inversions[m - 1] = inversions[m] * m % MOD
        return factorials, inversions

Union-Find

union_find_init()table 配列を生成し、返値をインスタンス番号と見なして各関数に与える。
table の根における値は自グループのサイズを表し、結合の際はサイズが小さい方を大きい方の子とする。


    UNIONFIND_TABLE = []

    def unionfind_init(n):
        UNIONFIND_TABLE.append(np.full(n, -1, dtype=np.int64))
        return len(UNIONFIND_TABLE) - 1

    def unionfind_getroot(ins, x):
        table = UNIONFIND_TABLE[ins]
        stack = []
        while table[x] >= 0:
            stack.append(x)
            x = table[x]
        for y in stack:
            table[y] = x
        return x

    def unionfind_unite(ins, x, y):
        table = UNIONFIND_TABLE[ins]
        r1 = unionfind_getroot(ins, x)
        r2 = unionfind_getroot(ins, y)
        if r1 == r2:
            return
        d1 = table[r1]
        d2 = table[r2]
        if d1 <= d2:
            table[r2] = r1
            table[r1] += d2
        else:
            table[r1] = r2
            table[r2] += d1

    def unionfind_find(ins, x, y):
        return unionfind_getroot(ins, x) == unionfind_getroot(ins, y)

    def unionfind_getsize(ins, x):
        table = UNIONFIND_TABLE[ins]
        return -table[unionfind_getroot(ins, x)]

Binary Indexed Tree (Fenwick Tree)

fenwick_init() に要素数を与えて初期化し、返値をインスタンス番号と見なして各関数に与える。

$i$ は $1~N$ の値を取るものとする(0始まりではない)。

lower_boundは、累積和が $x$ 以上になる最小の $i$ を返す。
使わない場合は、FENWICK_LOGN および fenwick_init 内でのそれを求める処理は不要(残しても大した計算量ではないが)。


    FENWICK_TREE = []
    FENWICK_LOGN = []

    def fenwick_init(n):
        log_n = 0
        m = n
        while m:
            log_n += 1
            m >>= 1
        FENWICK_TREE.append(np.zeros(n + 1, dtype=np.int64))
        FENWICK_LOGN.append(log_n)
        return len(FENWICK_TREE) - 1

    def fenwick_add(ins, i, x):
        arr = FENWICK_TREE[ins]
        n = arr.size - 1
        while i <= n:
            arr[i] += x
            i += i & -i

    def fenwick_sum(ins, i):
        arr = FENWICK_TREE[ins]
        result = 0
        while i > 0:
            result += arr[i]
            i ^= i & -i
        return result

    def fenwick_lower_bound(ins, x):
        arr = FENWICK_TREE[ins]
        log_n = FENWICK_LOGN[ins]
        n = arr.size - 1
        sum_ = 0
        pos = 0
        for i in range(log_n, -1, -1):
            k = pos + (1 << i)
            if k < n and sum_ + arr[k] < x:
                sum_ += arr[k]
                pos += 1 << i
        return pos + 1

最大流(Dinic法)

辺に容量 $cap_e$ が決められた有向グラフで、頂点 $s$ から $t$ に流せる最大流量を求める。
二部グラフのマッチングにも使える。

頂点番号は $0~N-1$。

基本は、dinic_init で初期化→dinic_add_links でグラフ生成→dinic_maximum_flow で最大流量を計算。

    DINIC_LINKS = []

    def dinic_init(n):
        lst = [[0]]
        lst.clear()
        DINIC_LINKS.append([lst.copy() for _ in range(n)])
        return len(DINIC_LINKS) - 1

    def dinic_add_link(ins, frm, to, cap):
        links = DINIC_LINKS[ins]
        links[frm].append([to, cap, len(links[to])])
        links[to].append([frm, 0, len(links[frm]) - 1])

    def dinic_bfs(ins, n, s):
        links = DINIC_LINKS[ins]
        depth = np.full(n, -1, dtype=np.int64)
        depth[s] = 0
        deq = np.zeros(n + 5, dtype=np.int64)
        dl, dr = 0, 1
        deq[0] = s
        while dl < dr:
            v = deq[dl]
            dl += 1
            for link in links[v]:
                if link[1] > 0 and depth[link[0]] == -1:
                    depth[link[0]] = depth[v] + 1
                    deq[dr] = link[0]
                    dr += 1
        return depth

    def dinic_dfs(ins, depth, progress, s, t):
        links = DINIC_LINKS[ins]
        stack = [(s, 10 ** 18)]
        flow = 0
        while stack:
            v, f = stack.pop()
            if v == t:
                flow = f
                continue
            if flow == 0:
                i = progress[v]
                if i == len(links[v]):
                    continue
                progress[v] += 1
                stack.append((v, f))
                to, cap, rev = links[v][i]
                if cap == 0 or depth[v] >= depth[to]:
                    continue
                stack.append((to, min(f, cap)))
            else:
                i = progress[v] - 1
                link = links[v][i]
                link[1] -= flow
                links[link[0]][link[2]][1] += flow
        return flow

    def dinic_maximum_flow(ins, n, s, t):
        flow = 0
        while True:
            depth = dinic_bfs(ins, n, s)
            if depth[t] == -1:
                return flow
            progress = np.zeros(n, dtype=np.int64)
            path_flow = dinic_dfs(ins, depth, progress, s, t)
            while path_flow != 0:
                flow += path_flow
                path_flow = dinic_dfs(ins, depth, progress, s, t)

最小費用流

辺に容量 $cap_e$ と、1単位を流したときのコスト $cost_e$ が決められた有向グラフで、頂点 $s$ から $t$ に、流量 $Q$ を流した時の最小コストを求める。

頂点番号は $0~N-1$。

基本的な使い方は、mincostflow_init で初期化→mincostflow_add_links でグラフ生成→mincostflow_flow で最小費用を計算。


    from heapq import heappop, heappush

    MINCOSTFLOW_LINKS = []
    INF = 10 ** 10

    def mincostflow_init(n):
        """ n: 頂点数 """
        lst = [[0]]
        lst.clear()
        MINCOSTFLOW_LINKS.append([lst.copy() for _ in range(n)])
        return len(MINCOSTFLOW_LINKS) - 1

    def mincostflow_add_link(ins, frm, to, capacity, cost):
        """ インスタンスID, 辺始点頂点番号, 辺終点頂点番号, 容量, コスト """
        links = MINCOSTFLOW_LINKS[ins]
        links[frm].append([to, capacity, cost, len(links[to])])
        links[to].append([frm, 0, -cost, len(links[frm]) - 1])

    def mincostflow_flow(ins, s, t, quantity):
        """ インスタンスID, フロー始点頂点番号, フロー終点頂点番号, 要求流量 """
        links = MINCOSTFLOW_LINKS[ins]
        n = len(links)
        res = 0
        potentials = np.zeros(n, dtype=np.int64)
        dist = np.full(n, INF, dtype=np.int64)
        prev_v = np.full(n, -1, dtype=np.int64)
        prev_e = np.full(n, -1, dtype=np.int64)

        while quantity:
            dist.fill(INF)
            dist[s] = 0
            que = [(0, s)]

            while que:
                total_cost, v = heappop(que)
                if dist[v] < total_cost:
                    continue
                for i, (u, cap, cost, _) in enumerate(links[v]):
                    new_cost = dist[v] + potentials[v] - potentials[u] + cost
                    if cap > 0 and new_cost < dist[u]:
                        dist[u] = new_cost
                        prev_v[u] = v
                        prev_e[u] = i
                        heappush(que, (new_cost, u))

            # Cannot flow quantity
            if dist[t] == INF:
                return -1

            potentials += dist

            cur_flow = quantity
            v = t
            while v != s:
                cur_flow = min(cur_flow, links[prev_v[v]][prev_e[v]][1])
                v = prev_v[v]
            quantity -= cur_flow
            res += cur_flow * potentials[t]

            v = t
            while v != s:
                link = links[prev_v[v]][prev_e[v]]
                link[1] -= cur_flow
                links[v][link[3]][1] += cur_flow
                v = prev_v[v]

        return res

本WebサイトはcookieをPHPのセッション識別および左欄目次の開閉状況記憶のために使用しています。同意できる方のみご覧ください。More information about cookies
programming_algorithm/python_tips/numba_library.txt · 最終更新: 2020/09/11 by ikatakos
CC Attribution 4.0 International
Driven by DokuWiki Recent changes RSS feed Valid CSS Valid XHTML 1.0