Numpyのオーバーフロー回避
Pythonのintは勝手に多倍長で計算してくれるので、巨大な整数を計算してもオーバーフローの心配は無いが、NumpyはCで実装されているためC言語の型の制約を受ける。
# Pythonのintは問題なく計算できる print( (10 ** 10) * (10 ** 10) ) # => # 100000000000000000000 # Numpy由来の型では警告が出て、誤った値が出力される print( np.int64(10 ** 10) * np.int64(10 ** 10) ) # => # RuntimeWarning: overflow encountered in longlong_scalars # 7766279631452241920 # 記事の内容とはそれるが、numpy.seterr() でエラーとすることもできる np.seterr(over='raise') print( np.int64(10 ** 10) * np.int64(10 ** 10) ) # => # FloatingPointError: overflow encountered in longlong_scalars
回避方法として、以下のいずれかがある。
- object型で扱う
- 自力で多倍長演算的な処理を実装する
object型で扱う
お手軽といえばお手軽だが、遅くなるので、わざわざNumpyで実装する意味が薄くなる方法。
NumpyでもPythonのintをそのまま(Pythonオブジェクトのまま)扱う。
a = np.array([10 ** 10], dtype=object) b = np.array([10 ** 10], dtype=object) print( a * b ) # => # [100000000000000000000]
dtype
を書き換えればいいだけなのでお手軽だが、数値演算もPython側で行われるため、当然遅くなる。
a = np.array([10 ** 10] * 10 ** 8, dtype=object) b = np.array([10 ** 10] * 10 ** 8, dtype=object) c = a * b # 時間計測: 4.298 [s] a = np.array([10 ** 10] * 10 ** 8, dtype=np.int64) b = np.array([10 ** 10] * 10 ** 8, dtype=np.int64) c = a * b # 時間計測: 1.276 [s] # ※値は間違っている
多倍長的な処理を実装する
100桁や1000桁などの長い値に対応するのは、NTTやGarnerの定理を組み合わせたそれなりの処理が必要だが、
競プロでよくあるような 109 程度の剰余を考えればよい場合(かけ算により一時的に 1018 くらいになるのが上限である場合)は、15bitあたりで区切ることでそこそこ手軽に実装できる。
1018 自体は64bit整数で扱えるが、Numpyではさらにその累積値を取るまでを一度に計算してくれるような関数があったりするので、それを絡めるとオーバーフローしてしまう。
行列積
上限 109 程度の値になりうる N×M 配列と M×L 配列をかけ算(行列積)したい場合。
最大 109×109 の結果を M 個足し合わせることになるが、 符号付き64bit整数で表せる上限が 9.2×1018 程度なので、 わずか M≥10 になるともうオーバーフローの可能性が生じてしまう。
桁を2つに分けて、上下で別々に計算したのを復元することができる。
1234 * 5678 = 7006652 上2桁ずつ 12 * 56 = 672 上 * 下 12 * 78 = 936 下 * 上 34 * 56 = 1904 下2桁ずつ 34 * 78 = 2652 672 * 100^2 + (936 + 1904) * 100 + 2652 = 7006652
上例は10進数でやったが、プログラムにおいては2進数で桁を分けるのが適している。
230≥109 のため、各数値を上下15bitで分けて計算し、復元する。Karatsuba法を使うと若干の高速化につながる。
1 2 3 4 5 6 7 8 9 10 11 12 |
def matrix_mul(mat1, mat2, MOD): mask = ( 1 << 15 ) - 1 mat1h = mat1 >> 15 mat1l = mat1 & mask mat2h = mat2 >> 15 mat2l = mat2 & mask mathh = mat1h @ mat2h % MOD matll = mat1l @ mat2l % MOD mathl = (mathh + matll - (mat1h - mat1l) @ (mat2h - mat2l)) % MOD res = (mathh << 30 ) + (mathl << 15 ) + matll res % = MOD return res |
無いと思うが、行列サイズが 1010 とか巨大だった場合はこれでもオーバーフローするので、適宜調整する。
さすがNumpyと言うべきか、自分の環境では愚直に3重ループのfor文を回す処理をNumbaでコンパイルしたものより速かった。
NTTによるたたみ込み
上下bitで分ける処理は、NTTを用いた高速なたたみ込み演算でも活用できる。