AtCoder Beginner Contest 208 E,F問題メモ
E - Digit Products
問題
解法
問題文が桁DPの形をしている。
$k$ が $0~K$ の範囲を取るので一見計算量的に無理そうだが、
実際は1桁の数の積なので素因数分解すると $2^a3^b5^c7^d$ と表せる数に限られる。
$a=0~\log_2{K}, b=0~\log_3{K}, ...$ の範囲を取るので、多く見積もっても $O((\log{K})^4)$ で収まる。
これとは別に「just: $N$ の上位 $i$ 桁目までの積」も管理しておけば、更新は以下の3つで行える。
$DP[i-1]$ で既に管理されているパターン
上位 $i-1$ 桁までの積が $m$ の時、$i$ 桁目を $t=0~9$ のいずれにするかで10通りに遷移する。
$mt \le K$ の時
$mt \gt K$ の時
積が $K$ 以下になるには、これ以降の桁に 0 が出てくるしかない
→0 が出てくるパターンをその場で答えに足し、DPの管理からは除外する
$ans += f(i) \times DP[i-1][m]$
上位 $i-1$ 桁が $N$ と一致し、$i$ 桁目で初めて小さいことが確定するパターン
DPに新しく追加される。
$N$ の $i$ 桁目が $x$ の時、
DPで管理する数の $i$ 桁目は $t=0~(x-1)$ にすることができる。
上記と同様に遷移する。ただし新規に追加されるパターン数は1通りとなる。
$just \times t \le K$ の時
$just \times t \gt K$ の時
$i$ 桁目が最上位桁のパターン
DPに新しく追加される。
$i \ge 2$ において、$t=1~9$ の9通りに遷移する。
$t \le K$ の時
$t \gt K$ の時
最終的に
途中でansに加えられた数と、DPで管理されている数をあわせたものが答え。
ただし、$N$ そのものが各桁の積が $K$ 以下である場合は数えられていないので、
$just \le K$ であれば答えに1足す。
10進数なので各状態からの遷移は $D=10$ 通り ずつあり、計算量は $O(\log{N}(\log{K})^4D)$ となる。
Python3
工夫1
定数 $INF$ を $K$ より大きい数(何でもよい)として定義しておく。
$mk \gt K$ となった場合、遷移先は $DP[i][INF] += DP[i-1][m]$ として1つの状態にまとめてしまえばよい。
この方法なら $f(i)$ を考える必要が無いため、追加実装が少なく抑えられる。
F - Cumulative Sum
問題
非負整数 $n,m$ に対して関数 $f(n,m)$ を正整数 $K$ を用いて以下で定義する
$\displaystyle f(n, m) = \begin{cases} 0 & (n = 0) \newline n^K & (n \gt 0, m = 0) \newline f(n-1, m) + f(n, m-1) & (n \gt 0, m \gt 0) \end{cases}$
$N,M,K$ が与えられるので、$f(N,M)$ を $\mod{10^9+7}$ で求めよ
$0 \le N \le 10^{18}$
$0 \le M \le 30$
$1 \le K \le 2.5 \times 10^6$
解法
制約が $N,M,K$ で意味ありげに異なっていて、特に $M$ が小さい。
問題ページのサンプル1の説明にも表があるが、
$n \gt 0, m \gt 0$ の部分に関しては、よくある経路数え上げのように「上と左の値を足す」という操作になっている。
で、左端が $n^K$、上端が 0 で初期化される感じ。
K=2
n\m 0 1 2 3 4
0 0 0 0 0 0
1 1 1 1 1 1
2 4 5 6 7 8
3 9 14 20 27 35
左端が特殊なものの、遷移は経路数え上げなので、二項係数が出てきそう。
左端の値毎に表を分離すると以下のようになり、
それぞれの表は二項係数の表(をシフトしたもの)に [ ]
でくくった数をかけたものとなっている。
n\m 0 1 2 3 4
0 0 0 0 0 0
1 [1] 1 1 1 1
2 0 1 2 3 4
3 0 1 3 6 10
n\m 0 1 2 3 4
0 0 0 0 0 0
1 0 0 0 0 0
2 [4] 4 4 4 4
3 0 4 8 12 16
n\m 0 1 2 3 4
0 0 0 0 0 0
1 0 0 0 0 0
2 0 0 0 0 0
3 [9] 9 9 9 9
$f(n,m)$ はその総和なので、以下で表せる。
……が、$N$ が巨大すぎてこんなんまともに計算できない。どうすんの?
ラグランジュ補間
知識ゲー的なところはあるが。
$x$ の $D$ 次多項式 $f(x)$ は、$D+1$ 個のサンプル点 $(x_i, f(x_i))$ を与えることで一意に決まる。
それを利用して、$f$ の形を機械的に特定し、
他の $x$ を代入しても答えを求められるようにするアルゴリズムが、ラグランジュ補間になる。
先ほどの $f(n,m)$ は、$m=M$ を固定すると $n$ の多項式として表せる。
まずΣの中の部分は、$i^K$ はそのまま $i$ の $K$ 次式、
${}_{n-i+M-1}C_{M-1}$ は展開すると $i$ の $M-1$ 次式になるので、
あわせて $K+M-1$ 次式となっている。
それの和を1から $n$ まで取ると、次数は1個増えて $n$ の $K+M$ 次式となることが証明されている。
従って、$K+M+1$ 個の $n$ に対する $f(n,M)$ を愚直に計算して、ラグランジュ補間すれば答えが求まる。
この時、一般的にラグランジュ補間は $O(D^2)$ かかってしまうが、
サンプル点の $n$ を適切に選ぶ($0~K+M$ の連番にする)ことで、
1つの値を求めるだけなら $O(D)$ で計算できるテクニックがあるので、それを使う。
Python3
import os
import sys
import numpy as np
def solve(n, m, k):
MOD = 10 ** 9 + 7
def mod_pow(x, a):
ret = np.ones_like(x)
cur = x
while a:
if a & 1:
ret = ret * cur % MOD
cur_ = cur
cur = cur * cur_ % MOD
a >>= 1
return ret
def prepare_factorials(n):
factorials = np.ones(n + 1, np.int64)
for m in range(1, n + 1):
factorials[m] = factorials[m - 1] * m % MOD
inversions = np.ones(n + 1, np.int64)
inversions[n] = mod_pow(factorials[n:], MOD - 2)[0]
for m in range(n, 1, -1):
inversions[m - 1] = inversions[m] * m % MOD
return factorials, inversions
def lagrange(arr, x):
d = arr.size
facts, finvs = prepare_factorials(d)
left_product = np.ones(d + 1, np.int64)
right_product = np.ones(d + 1, np.int64)
for i in range(d):
left_product[i + 1] = left_product[i] * (x - i) % MOD
for i in range(d - 1, -1, -1):
right_product[i] = right_product[i + 1] * (x - i) % MOD
result = 0
for i in range(d):
tmp = left_product[i] * right_product[i + 1] % MOD
tmp = tmp * finvs[i] % MOD
tmp = tmp * finvs[d - i - 1] % MOD
tmp = tmp * arr[i] % MOD
if (d - i - 1) % 2 == 0:
result += tmp
else:
result -= tmp
result %= MOD
return result
l = m + k + 5
base = np.arange(l, dtype=np.int64)
arr = mod_pow(base, k)
for _ in range(m):
arr = arr.cumsum() % MOD
if n < arr.size:
return arr[n]
n %= MOD
return lagrange(arr, n)
SIGNATURE = '(i8,i8,i8)'
if sys.argv[-1] == 'ONLINE_JUDGE':
from numba.pycc import CC
cc = CC('my_module')
cc.export('solve', SIGNATURE)(solve)
cc.compile()
exit()
if os.name == 'posix':
# noinspection PyUnresolvedReferences
from my_module import solve
else:
from numba import njit
solve = njit(SIGNATURE, cache=True)(solve)
print('compiled', file=sys.stderr)
n, m, k = map(int, input().split())
ans = solve(n, m, k)
print(ans)