競技プログラミングでは、答えの非常に大きくなる解に対しては、オーバーフローを防ぐため「○○で割った剰余で答えよ」という指定が入ることがある。○○は素数であることが多い。($10^9+7$など。以下$m$とする)
また、解法に組み合わせ数${}_n\mathrm{C}_r$を用いる計算も、出てくることがある。
これらが組み合わさったとき、計算量を減らすため、mod演算の定理を用いたテクニックがある。
まず、組み合わせ数の公式は、以下になる。
$$\binom{n}{r} = {}_n\mathrm{C}_r = \dfrac{n!}{r!(n-r)!}$$
なので、上限 $n$ が決まっている場合は、$1!~n!$ をあらかじめ求めておくと(当然その中に$r!$や$(n-r)!$も含まれる)同じ計算を繰り返さずに済む。
この時、$n!$ も相当に大きな数になるためmodを取りながら計算しないとオーバーフローする。 しかし、残念ながらmodの世界では加減乗($+-\times$)はいいが割り算は正常に機能しない。
オーバーフローせず、かつ正しい値になるように、事前計算できないか。
$x$のモジュラ逆数とは、「modの世界で $x$ にかけたら1になる数」である。$x^{-1}$ で表記し、これをかけることで割り算を表現する。
普通の算数では、$x$ の逆数は $1/x$ である。$x \times \dfrac{1}{x} = 1$となる。
だが、modは整数論の世界なので、分数は扱えない。しかし、逆数に相当する数なら条件付きだが存在する。
$$3 \times x^{-1} \equiv 1 \pmod{11}$$
$$3 \times 4 = 12 \equiv 1 \pmod{11}$$
試しに $x^{-1}$ に4を入れてみると成立する。4は、mod11における3のモジュラ逆数の1つということになる。
逆数が存在する条件は、$x$ と $m$ が互いに素である場合に限られる。$m$ が素数の場合には、$1 \le x \lt m$ の範囲では必ず満たすので、必ず存在する。
で、逆数を事前計算しておけば、組み合わせ数は以下の式で求められる。
$${}_n\mathrm{C}_r \equiv n! \times r!^{-1} \times (n-r)!^{-1} \pmod{m}$$
では逆数を求めるにはどうしたらよいのか。$m$ が素数の場合、フェルマーの小定理が利用できる。
$$a^{m-1} \equiv 1 \pmod{m}$$
$$a^{-1} \equiv a^{m-2} \pmod{m} \ \ (m \ge 3)$$
$m$ が素数の場合、$a^{m-2}$ が $a$ の逆数となる。これを事前計算しておくことで、modでの組み合わせ数の算出を高速化できる。
この計算も、$m-2$ が大きいと時間がかかるが、バイナリ法やモンゴメリ乗算などを使うことで高速に求められる。
いくら $a^{-1}$ が高速に求められるとしても、$1!^{-1}~n!^{-1}$ の全てをこの方法で求めていてはやはり時間がかかる。そこで、
\begin{align} n!^{-1} &= \frac{1}{1 \times 2 \times ... \times (n-1) \times n} \\ (n-1)!^{-1} &= \frac{1}{1 \times 2 \times ... \times (n-1)} = n!^{-1} \times n \\ (n-2)!^{-1} &= \frac{1}{1 \times 2 \times ... \times (n-2)} = (n-1)!^{-1} \times (n-1) \end{align}
であることを利用すれば、$n!^{-1}$ さえ求めれば、かけ算の繰り返しで全ての階乗の逆数を求められる。
まとめると、($\mod{m}$ は省略)
def prepare(n, MOD): # 1! - n! の計算 f = 1 factorials = [1] # 0!の分 for m in range(1, n + 1): f *= m f %= MOD factorials.append(f) # n!^-1 の計算 inv = pow(f, MOD - 2, MOD) # n!^-1 - 1!^-1 の計算 invs = [1] * (n + 1) invs[n] = inv for m in range(n, 1, -1): inv *= m inv %= MOD invs[m - 1] = inv return factorials, invs
${}_n\mathrm{C}_r$ の左項には $n$ しか来ない場合、1!~(n-1)!は保持しなくてよいバージョン
def prepare(n, MOD): # n! の計算 f = 1 for m in range(1, n + 1): f *= m f %= MOD fn = f # n!^-1 の計算 inv = pow(f, MOD - 2, MOD) # n!^-1 - 1!^-1 の計算 invs = [1] * (n + 1) invs[n] = inv for m in range(n, 1, -1): inv *= m inv %= MOD invs[m - 1] = inv return fn, invs
階乗の計算は、NumPyを用いて $\sqrt{N} \times \sqrt{N}$ 行列の縦横1列をまとめて計算することで高速化が可能になる。 (ただし上記サイトにもあるが、ややトリッキーな方法であり、競プロを外れた文脈ではCython, Numbaなどで高速化した方が素直)
modを取りながらの累積積を高速化できるので、$N=10^6$ で、だいたいNumPyを使わないコードと比較して2.5倍(332ms→125ms)くらいの速度になる。 200msの差が生きるときもあるかも知れない。
import numpy as np def prepare(n, MOD): nrt = int(n ** 0.5) + 1 nsq = nrt * nrt facts = np.arange(nsq, dtype=np.int64).reshape(nrt, nrt) facts[0, 0] = 1 for i in range(1, nrt): facts[:, i] = facts[:, i] * facts[:, i - 1] % MOD for i in range(1, nrt): facts[i] = facts[i] * facts[i - 1, -1] % MOD facts = facts.ravel().tolist() invs = np.arange(1, nsq + 1, dtype=np.int64).reshape(nrt, nrt) invs[-1, -1] = pow(facts[-1], MOD - 2, MOD) for i in range(nrt - 2, -1, -1): invs[:, i] = invs[:, i] * invs[:, i + 1] % MOD for i in range(nrt - 2, -1, -1): invs[i] = invs[i] * invs[i + 1, 0] % MOD invs = invs.ravel().tolist() return facts, invs