nCr mod m の求め方
競技プログラミングでは、答えの非常に大きくなる解に対しては、オーバーフローを防ぐため「○○で割った剰余で答えよ」という指定が入ることがある。○○は素数であることが多い。(109+7など。以下mとする)
また、解法に組み合わせ数nCrを用いる計算も、出てくることがある。
これらが組み合わさったとき、計算量を減らすため、mod演算の定理を用いたテクニックがある。
組み合わせ数
まず、組み合わせ数の公式は、以下になる。
(nr)=nCr=n!r!(n−r)!
なので、上限 n が決まっている場合は、1!~n! をあらかじめ求めておくと(当然その中にr!や(n−r)!も含まれる)同じ計算を繰り返さずに済む。
この時、n! も相当に大きな数になるためmodを取りながら計算しないとオーバーフローする。 しかし、残念ながらmodの世界では加減乗(+−×)はいいが割り算は正常に機能しない。
オーバーフローせず、かつ正しい値になるように、事前計算できないか。
モジュラ逆数
xのモジュラ逆数とは、「modの世界で x にかけたら1になる数」である。x−1 で表記し、これをかけることで割り算を表現する。
普通の算数では、x の逆数は 1/x である。x×1x=1となる。
だが、modは整数論の世界なので、分数は扱えない。しかし、逆数に相当する数なら条件付きだが存在する。
3×x−1≡1(mod11)
3×4=12≡1(mod11)
試しに x−1 に4を入れてみると成立する。4は、mod11における3のモジュラ逆数の1つということになる。
逆数が存在する条件は、x と m が互いに素である場合に限られる。m が素数の場合には、1≤x<m の範囲では必ず満たすので、必ず存在する。
で、逆数を事前計算しておけば、組み合わせ数は以下の式で求められる。
nCr≡n!×r!−1×(n−r)!−1(modm)
フェルマーの定理
では逆数を求めるにはどうしたらよいのか。m が素数の場合、フェルマーの小定理が利用できる。
am−1≡1(modm)
a−1≡am−2(modm) (m≥3)
m が素数の場合、am−2 が a の逆数となる。これを事前計算しておくことで、modでの組み合わせ数の算出を高速化できる。
この計算も、m−2 が大きいと時間がかかるが、バイナリ法やモンゴメリ乗算などを使うことで高速に求められる。
1!~n!の逆数
いくら a−1 が高速に求められるとしても、1!−1~n!−1 の全てをこの方法で求めていてはやはり時間がかかる。そこで、
n!−1=11×2×...×(n−1)×n(n−1)!−1=11×2×...×(n−1)=n!−1×n(n−2)!−1=11×2×...×(n−2)=(n−1)!−1×(n−1)
であることを利用すれば、n!−1 さえ求めれば、かけ算の繰り返しで全ての階乗の逆数を求められる。
まとめると、(modm は省略)
- 1から順にかけ算して、1!~n! を計算・保持
- n!−1≡n!m−2 を計算
- n!−1 に n~1 を逆順にかけ算して、(n−1)!−1~1!−1 を計算・保持
- 1!~n! と 1!−1~n!−1 より、nCr を計算
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 |
def prepare(n, MOD): # 1! - n! の計算 f = 1 factorials = [ 1 ] # 0!の分 for m in range ( 1 , n + 1 ): f * = m f % = MOD factorials.append(f) # n!^-1 の計算 inv = pow (f, MOD - 2 , MOD) # n!^-1 - 1!^-1 の計算 invs = [ 1 ] * (n + 1 ) invs[n] = inv for m in range (n, 1 , - 1 ): inv * = m inv % = MOD invs[m - 1 ] = inv return factorials, invs |
nCr の左項には n しか来ない場合、1!~(n-1)!は保持しなくてよいバージョン
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
def prepare(n, MOD): # n! の計算 f = 1 for m in range ( 1 , n + 1 ): f * = m f % = MOD fn = f # n!^-1 の計算 inv = pow (f, MOD - 2 , MOD) # n!^-1 - 1!^-1 の計算 invs = [ 1 ] * (n + 1 ) invs[n] = inv for m in range (n, 1 , - 1 ): inv * = m inv % = MOD invs[m - 1 ] = inv return fn, invs |
さらなる高速化
階乗の計算は、NumPyを用いて √N×√N 行列の縦横1列をまとめて計算することで高速化が可能になる。 (ただし上記サイトにもあるが、ややトリッキーな方法であり、競プロを外れた文脈ではCython, Numbaなどで高速化した方が素直)
modを取りながらの累積積を高速化できるので、N=106 で、だいたいNumPyを使わないコードと比較して2.5倍(332ms→125ms)くらいの速度になる。 200msの差が生きるときもあるかも知れない。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 |
import numpy as np def prepare(n, MOD): nrt = int (n * * 0.5 ) + 1 nsq = nrt * nrt facts = np.arange(nsq, dtype = np.int64).reshape(nrt, nrt) facts[ 0 , 0 ] = 1 for i in range ( 1 , nrt): facts[:, i] = facts[:, i] * facts[:, i - 1 ] % MOD for i in range ( 1 , nrt): facts[i] = facts[i] * facts[i - 1 , - 1 ] % MOD facts = facts.ravel().tolist() invs = np.arange( 1 , nsq + 1 , dtype = np.int64).reshape(nrt, nrt) invs[ - 1 , - 1 ] = pow (facts[ - 1 ], MOD - 2 , MOD) for i in range (nrt - 2 , - 1 , - 1 ): invs[:, i] = invs[:, i] * invs[:, i + 1 ] % MOD for i in range (nrt - 2 , - 1 , - 1 ): invs[i] = invs[i] * invs[i + 1 , 0 ] % MOD invs = invs.ravel().tolist() return facts, invs |