Numba (Python)

Pythonを高速に動くようにするモジュール。やってることは、LLVMでの機械語へのコンパイル。

「既存のPythonコードにあまり手を加えず、速くなればいいな程度でとりあえず軽く使う」こともできるし、「きちんとコードを書き換えて高速化する」こともできる。

他手段との比較

高速化の手段としては、他にPyPy、Cythonなどがある。Cythonは詳しく知らないが、PyPyと比較してはざっくり以下のようなイメージ。

  • PyPy
    • ほぼ書き換えなしでそれなりに高速化される 1)
    • Pythonとは別の実行環境としてインストールが必要
    • 基本的に実行時コンパイル(JIT)
      • なので、実行時間に常にコンパイル時間が含まれる
    • コンパイル時間は短め
  • Numba
    • 単純な関数ならともかく、通常は書き換え無しではそこまで高速化されない
      • 恩恵を受けようと思うと、素Pythonでは却って遅くなるような書き方をすることもある
      • ただしCythonのように素Pythonで動かなくなるようなコードにすることはない
    • Pythonのモジュールとして、pipなどでインストールできる
    • JITも出来るし、事前コンパイル(AOT)もできる
    • コンパイル時間はPyPyと比較すると長め
    • 効果は高い(勿論、書き方や得意分野による)

速度比較は、参考文献の1つめが詳しい。

PyPyの方が、労力が少なく効果がそこそこでコスパが良い感はある。

Numbaは、後述のNoPythonモードによって高速化できないコードをエラーにすることが出来るので、着実に高速化させやすい。

また、既存コードに未対応の外部モジュールを使っていると、PyPyなら丸ごと実行できなくなる一方 2)、Numbaでは関数を切り分けて、使っていない部分だけコンパイルさせることができる。

参考文献

環境

  • Python 3.8.2
  • Numba 0.47.0

Objectモード NoPythonモード

高速化する上で重要な概念。

基本的にNumbaは関数のデコレータに「@jit」を付けるだけで実行時コンパイラが動くが、どんな型でもちゃんとコンパイルできるわけではない。 未対応の関数やデータ型が含まれる場合、「Objectモード」でコンパイルされる。Objectモードだと、高速化の効果は少ない。

基本的に何が出来て何が出来ないかは、下記を見ればよい。

そうでなく、全てがNumbaが対応する機能のみで記述されていれば、「NoPythonモード」となり、高速化の効果が大きくなる。

デコレータとして @njit を使うと、NoPythonモードを強制し、未対応の型があると警告やエラーを出してくれる。 一方、@jit だとObjectモードでも許容する。

既存コードに手を加えたくない場合は @jit、高速化目的なら @njit を使うとよい。(ただ、Objectモードはそのうちdeprecatedになるかも?)

NoPythonモードで未対応の型(PythonのList)を使用した際の警告例

NumbaPendingDeprecationWarning: 
Encountered the use of a type that is scheduled for deprecation: type 'reflected list' found for argument 'a' of function 'function'.

以下、基本的にNoPythonモードを満たすように記述することを目標とする。

使える型

通常のPythonから大きく制限される。 数値型とnumpy.ndarrayしか使えないくらいに思っておいてよい。

一応SetやDictも使える。(ただしnumpy.ndarray等と比較して動作は速くない)

引数や戻り値のような、Numba関数の中と外の橋渡し的な役割をする変数は、特に留意が必要となる。

引数・戻り値に使える主な型

  • 数値(byte, int, float, complex)
  • numpy.ndarray
  • UniTuple, Tuple
  • (*1) list, set(reflected list, reflected set)
  • (*2) Unicode文字列(Python3における通常の文字列)
  • (*2) numba.typed.List, numba.typed.Dict
  • コンパイル済み関数
    • 引数としてのみ使用可

その他もあるかも知れないけど、どこ見れば書いてあるのかよくわかってない。 なるべく上の3つのみを使うようにした方がよさそう。(個々の問題点は後述)

(*1)はdeprecatedで今後使えなくなる可能性が高い。

(*2)は、対応はしているが、まだ十分高速に動くようコンパイル出来ない場合があるとリファレンスに書かれている。

UniTuple, Tuple

タプルは、複数の要素をまとめることが出来る。要素の型が全て統一されているかどうかで区別される。

  • UniTuple: homogeneous = 統一されている
  • Tuple: heterogeneous = 複数の型が混在している

heterogeneousなTupleは、変数によるindexアクセスなどいくつかの機能が使えない(型を特定できないのでそりゃそう)。 また、イテレートするときに特殊な書き方が必要となる。

引数や戻り値のために一瞬使うだけなら大した影響はないだろうが、内部でガッツリ使う場合はなるべくUniTupleな構造にした方がよい。

型指定の書き方は、以下のようにする。

@njit('UniTuple(i8, 5)()')    ... 64bit整数5個のUniTupleを返す

@njit('Tuple(i8, string, f4)()')   ... (64bit整数,文字列,32bit小数)のTupleを返す
reflected list と numba.typed.List

どちらも、Pythonのlist機能をNumbaで出来る範囲で表現することを目的としたデータ構造。

  • reflected list: Pythonのlistをそのまま受け取ってNumba内で解釈できるようにしたもの
  • typed list: Numba独自のリスト構造クラス

reflected list は deprecated。下記を読むと、ネスト等で複雑になると限界があるので typed listに置きかえていく方針らしい。

そのため、reflected list は(問題なく使えるものの)ver.0.45以降では警告が出る。

また、typed listの方は「実験的機能」とされていて、バグがあったり、高速化の恩恵が少なくなる可能性が言及されている。

両者細かい注意点はあるが、現状、過渡期なこともありどちらも中途半端な状態。 今後の進化に期待しつつ、今のところ配列に関しては numpy.ndarray でいい気がする。

違いとしては、以下のようなものがある。 いずれも、要素の型は統一されている必要があり、統一できないような組合せで初期化・追加したりするとエラーとなる。

  • reflected list
    • 外部から与える際は、ネストは不可。要素の型は統一されている必要がある
    • Numba関数内部では、ネスト可能
    • Numba関数内部で「a = [0, 1]」などとすると、こちらになる
  • typed.List
    • ネスト可
    • numba.typed.List() で空のインスタンスを生成し、1つずつappendする

また、Pythonのデータ構造にはListの他にSet, Dictがあるが、

  • setについては、reflected set はあるが、typed set は未実装
  • dictについては逆に、typed dict はあるが、reflected dict は存在しない

Pythonのdictは使えないため、辞書構造を引数にしたければ numba.typed.Dict で生成して使うしかない。 ただこれも「実験的機能」であり、実際、動作はちょっと遅い。

typed.Listの初期化方法例

コンパイル済み関数

引数としてのみ使え、戻り値には出来ない。

また、関数を引数とする関数を事前コンパイルする方法がわからない。感触的には無理っぽい。

まず、型指定の記述方法が不明。型指定しないと事前コンパイルは出来ない。

一応、型指定しない @njit 関数に対し実際に引数を与えて呼び出すと その関数.nopython_signatures に、呼び出した引数・戻り値の型指定がList形式で追加される。 これを他の関数の型指定に用いることが出来る。

@njit('i8(i8)')
def double(a):
    return 2 * a

@njit
def fumidai(func, a):
    return func(a)

fumidai(double, 2)

@njit(fumidai.nopython_signatures[0])
def hontai(func, a):
    return func(a)

hontai(double, 3)

しかし、

  • hontai@njit(cache=True) でキャッシュしようとすると「TypeError: cannot pickle 'weakref' object」が発生する
  • double以外の関数を定義して渡すと「TypeError: No matching definition for argument type(s) type(CPUDispatcher(<function 関数名 at 0x……))」が発生する

実体のあるオブジェクトでなく参照として渡しているので、関数を置きかえたり、ファイルとして残すことはできないらしい。

有効に使える場面は限られるか。

関数内部での型

list, set は、通常のPython表記 [0, 1], {2, 3} のように書くと reflected list, reflected set となる。

dict は {2: 10, 3: 15} のように書くと typed dict となる。

いずれも、初期化時・追加時に型が混在しているとエラー。

特にdictは、イテレータのように渡されるのか、最初のkey-valueの組合せで型が確定する。例えば、int→floatはキャストできるが、float→intはできないため、以下のようになる。

d = {2: 2.0, 3: 3}  はOK   (key, value) = (int64, float64)

d = {2: 2, 3: 3.0}  はエラー

ただ、使ってみた感触としてはこれらのパフォーマンスはあまり優れているとは言えない。実行時間が数倍遅くなることもある。なるべくならNumpy配列を使った方がよい。

また、Set, Dictに関しては、リスト内包表記で生成することは出来ない。(イテレート元とすることは出来る)

@njit
def function():
    a = [0, 1, 2]         # int64型のreflected list
    b = {0, 1, 1.5}       # float64型のreflected set
    c = {1: 1.5, 2: 2.5}  # (int64, float64)型のtyped dict
    d = {1: 1, 2: 2.5}    # エラー
    
    g = [v + 1 for v in a]             # [1, 2, 3]
    h = [v + 1 for v in b]             # [1.0, 2.0, 3.0]
    i = [k + v for k, v in c.items()]  # [2.5, 4.5]
    j = {v + 1 for v in a}             # エラー
    k = {v: v+1 for v in a}            # エラー

使える関数

ビルトイン関数

通常の演算処理で使うものはほぼ使えるが、たまに一部のオプションが未実装だったり。

基本はリファレンス参照。注意すべきいくつかの関数は以下。

  • pow(x, a, mod)
    • mod で返す機能は未実装
    • 繰り返し二乗法などで自力実装する
  • List.sort(), sorted()
    • key は指定できない
  • int(n, base)
    • 何進数かを指定する機能は未実装

標準モジュール

よく使うもので関係ありそうなのは以下。

    • NamedTuple以外は未実装
    • deque, Counter, defaultdict 等は使えない
    • 未実装
    • 未実装
    • accumulate, combinations 等は使えない
    • 意外と(?)、ほぼ全ての機能を使える

グローバル変数

Numba関数内からグローバル変数にアクセスしても、それがNumbaに対応した型なら使える。

ただし、AOT や jit(cache=True) などでコンパイル結果をキャッシュする場合、グローバル変数はコンパイル当時のもので固定される

glb = 5

@njit('i8()', cache=True)
def global_test():
    return glb

print(global_test())  # => 5

glb = 6

print(global_test())  # => 5

型指定

@njit() の引数には、関数の引数と戻り値の型(Signature)を指定することができる。

指定が無いと、関数が呼ばれたときの引数に応じて型推論してからコンパイルされるのに対し、 型を指定すると指定した型に向けたコンパイルしか行われないので、 「より頑健なコードになる」「8bitで済むような変数に64bit割り当てられるなどが無くなり、より高速化できる」などの恩恵が期待できる。

「戻り値の型(引数1の型, 引数2の型, …)」のように記述する。 numbaで定義された型指定用クラスを使う方法と、文字列を使う方法があるが、文字列を使った方がimportなどの手間が省ける。 戻り値の型を省略した場合は推論される。

文字列とデータ型の対応は以下の通り。

型指定用のクラスは int64 など数値がbit数を表すのに対し、文字列指定は i8 などbyte数を表す点に注意。

また、配列は np.ndarray を指す。

# 戻り値: 無し,   引数: int64, int64
@njit('void(i8, i8)')
def func(a, b):
    print(a + b)


# 戻り値: int8,   引数: float32の配列, complex128の二次元配列
@njit('i1(f4[:], c16[:, :])')


# 戻り値が複数ある場合(同じ型)
@njit('UniTuple(i8[:], 2)()')
def func():
    x = np.zeros(10, np.int64)
    y = np.ones(20, np.int64)
    return x, y


# 戻り値が複数ある場合(異なる型)
@njit('Tuple(i8, f8)()')
def func():
    return 1, 1.0

事前コンパイル(AOT)

numba.pycc モジュールを使うことで、関数をコンパイルした結果を保存することができる。

PCにCコンパイラがインストールされている必要がある。 未インストールの場合、Windowsならこの辺とかを参考に Build Tools for Visual Studio をインストールすればいけるはず。

コンパイル時点で引数・返値の型がわかっている必要がある。 明示的に型を与える場合はよいが、@njit のみを付けて型が未定義の関数はサンプルデータを与えて一度JITを走らせる、という方法がある。

使う際は、通常のモジュールと同様 import すればよい。 コンパイル後に出来るファイルは my_module.cp38-win_amd64.pyd などとプラットフォーム名が付く場合があるが、import 時の記述は import my_module でよい。 IDEでは、「そんなモジュール無いで」という構文エラーが出たり、引数の型補間が行われない可能性はある。

1)
SciPyなどの外部モジュールや、一部標準モジュールは未対応のものがある
2)
何かしら方法はあるかも
本WebサイトはcookieをPHPのセッション識別および左欄目次の開閉状況記憶のために使用しています。同意できる方のみご覧ください。More information about cookies
programming/python/packages/numba.txt · 最終更新: 2020/07/01 by ikatakos
CC Attribution 4.0 International
Driven by DokuWiki Recent changes RSS feed Valid CSS Valid XHTML 1.0