モンゴメリ乗算
a \times b \pmod{n}の答えを高速に求めるアルゴリズム。a,b,nは整数、0 \le a,b \lt n。数字が大きくなっても平気。
上記の計算のどこに時間かかるって、コンピュータは、加算は得意、乗算もそこそこできるけど、除算は四則演算の中では苦手。モンゴメリ乗算は、mod nの部分を工夫して、除算を1回のみで行えるようにしてしまう。
1回のみだったら変わらないじゃん、となるのだが、a \times b \times c \times d ... \pmod{n}と複数回かけても1回の除算で済む。このような場合、正攻法では、数字が大きくなるとオーバーフローの問題や、桁数に比例して計算時間がかかるため、現実的には((a \times b \bmod n) \times c \bmod n) \times d \bmod n ...と、乗算毎に割ることになる。それが1回で済むようになるため、効果が大きい。
概要としては、うまく式変形することで、\bmod nの式からnによる除算を取り除き、任意の定数Rによる除算と剰余(と、その他除算以外の演算)に置き換えてしまう。このRは任意に選べるので、コンピュータの得意な2の冪乗数にすることで、実質的にビット演算で行えるようになる。
我々も割り算は苦手でも、123456789 \div 1000 = 123456 \ \text{あまり} \ 789のように、割る数が10,100,1000…なら数字を区切るだけで求められるのと同じ。
詳細な説明はWikiとか他のサイトに譲るとして、pythonのコード。冪剰余も求められる。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 |
class Montgomery: def __init__( self , n): self .n = n self .nb = n.bit_length() # Rを、Nより大きい最小の2の冪乗数とする # R^2 mod n : この1回だけ除算が必要になる self .r2 = ( 1 << ( self .nb * 2 )) % n # Rを2の冪乗とすることで、mod Rをビットマスクで求められるようになる self .mask = ( 1 << self .nb) - 1 # N * N' = -1 mod R となるN'の導出 # Rを2の冪乗とすることで加算とビットシフトで求められるようになる self .nr = 0 t = 0 vi = 1 for _ in range ( self .nb): if t & 1 = = 0 : t + = n self .nr + = vi t >> = 1 vi << = 1 def reduction( self , t): """モンゴメリリダクション""" c = t * self .nr c & = self .mask c * = self .n c + = t c >> = self .nb if c > = self .n: c - = self .n return c def mul( self , a, b): """a * b mod n を計算""" return self .reduction( self .reduction(a * b) * self .r2) def exp( self , a, b): """a ^ b mod n を計算""" p = self .reduction(a * self .r2) x = self .reduction( self .r2) y = b while y: if y & 1 : x = self .reduction(x * p) p = self .reduction(p * p) y >> = 1 return self .reduction(x) |
ぶっちゃけ
pythonには、a^b \pmod{n}の計算ならpow(a,b,n)があるから、わざわざ自分で書く必要性は無いんだけどね。組み込みな分、圧倒的に速いし。
1 2 3 4 5 6 7 8 |
import timeit mg = Montgomery( 1000000007 ) print (timeit.timeit( 'mg.exp(123456789, 987654321)' , number = 1000 , globals = globals ())) # => 0.0361 print (timeit.timeit( 'pow(123456789, 987654321, 1000000007)' , number = 1000 , globals = globals ())) # => 0.0037 |