i 0 1 2 3 4 5 6 7
C -2 -2 4 4 4 -3 -3 -3
B -2 -4 0 4 8 5 2 -1
A -2 -6 -6 -2 6 11 13 12
~~~~ ←最大値
$A$ の最大値を求める上では、$B_i$ と比べて $B_{i+1}$ が正なら次も足した方がいいし、負ならそこでひとまず止めた方がいい。
よって、数列の長さは非常に長くなりうるが、最大値をチェックするのは $B$ が正から負に切り替わるタイミング(と、最初と最後)だけでよい。
それがいつ切り替わるかを直接求めるのは難しいが、$B$ における $N$ 個の $x_i,y_i$ の切り替わり点を追っていけばわかる。
j1 j2 j3
i 0 1 2 3 4 5 6 7
C -2 -2 4 4 4 -3 -3 -3
B -4 8 -1
切り替わり点の $i$ を $j_1,j_2,...$ と表す。
$B_{j_k}$ は、単に $x_k \times y_k$ を累積的に足していけば求められる。
$B_{j_k}$ と $B_{j_{k+1}}$ を比べると、その中に $A$ の最大値があり得るかどうかわかる。
正→負の区間だけチェックすればよいとわかった。
この中で切り替わる具体的な箇所 $i$ を特定し、そこから最大値候補である $A_i$ を求めたい。
切り替わり点から $i$ が1つ進む毎に $x_k$ だけ減っていくので、$\dfrac{B_{j_{k-1}}}{-x_k}$(切り捨て)だけ進んだ箇所が正、その次が負となる。
j1 j2 j3
i 0 1 2 3 4 5 6 7
C -2 -2 4 4 4 -3 -3 -3
B -4 8→ 5→ 2→-1
B[j2]/-x3 = 8/3 = 2 より、
i=4 から 2 進んだ i=6 がこの区間の中で最大を取る箇所
前の区間の終わり $i=4$ から $d=2$ 進んだ箇所が最大値をとるとわかったので、$A_6$ を求める。
先頭から累積的に計算することで、$A_{j_1},A_{j_2},...$ は計算できる。
j1 j2 j3
i 0 1 2 3 4 5 6 7
C -2 -2 4 4 4 -3 -3 -3
B -4 8 -1
A -6 6 [] 12
1つ前の区間の終わりが $A_r$ で、その次に $C$ において $n$ 個の $x_i$ が並んでいると、
と表現できる。
これを使うと、$A_6 = A_4 + B_4 \times d + \dfrac{d(d+1)}{2} \times x_3 = 13$ となり、正しく求められている。
正→負となる区間毎に調べていって、$O(N)$ で全て求められる。
Python3
t = int(input())
for _ in range(t):
n, m = map(int, input().split())
ccc = [tuple(map(int, input().split())) for _ in range(n)]
bbb = [0]
for x, y in ccc:
bbb.append(bbb[-1] + x * y)
aaa = [0]
ans = ccc[0][0]
for i in range(n):
x, y = ccc[i]
aaa.append(aaa[-1] + bbb[i] * y + y * (y + 1) // 2 * x)
l = bbb[i]
r = bbb[i + 1]
if r >= 0:
ans = max(ans, aaa[-1])
elif l <= 0:
ans = max(ans, aaa[-1])
else:
d = l // -x
tmp = aaa[-2] + bbb[i] * d + d * (d + 1) // 2 * x
ans = max(ans, tmp)
print(ans)
$N$ 回ちょうどで $(0,0)$ から $(X,Y)$ に移動する。
操作回数が最低限の $|X|+|Y|$ に足りなかったり、偶奇が合わない場合は最初に除く。
最低限から余る移動回数については、
無駄な移動(同じ軸の正方向と負方向の移動のセット)を $k=\dfrac{N-(|X|+|Y|)}{2}$ 回、しなくちゃいけない。
素直に考えると、$x$ 軸方向で何回、$y$ 軸方向で何回行うか決めれば計算できる。
片方の無駄を $0~k$ に固定して場合分けすればよい。
しかし、これだと $O(k)$ かかる。
テクニックとして、45度回転させ、$(+1,+1),(+1,-1),(-1,+1),(-1,-1)$ の4つの移動で $(X+Y,X-Y)$ に到達する、と問題を言い換えると、
2つの軸を独立に考えることができ、2つの二項係数の積で $O(1)$ で計算できてしまう。
今回の制約では、単純に1次元の無駄回数を $0~k$ で固定して、2次元に帰着させるのを全通り試せばよい。
最初の階乗計算を除けば、全体 $O(N)$ で求められる。
Python3
import numba
import numpy as np
@numba.njit('i8(i8,i8,i8,i8)', cache=True)
def solve(n, x, y, z):
def mod_pow(x, a, MOD):
ret = 1
cur = x
while a:
if a & 1:
ret = ret * cur % MOD
cur = cur * cur % MOD
a >>= 1
return ret
def precompute_factorials(n, MOD):
factorials = np.ones(n + 1, np.int64)
for m in range(1, n + 1):
factorials[m] = factorials[m - 1] * m % MOD
inversions = np.ones(n + 1, np.int64)
inversions[n] = mod_pow(factorials[n], MOD - 2, MOD)
for m in range(n, 1, -1):
inversions[m - 1] = inversions[m] * m % MOD
return factorials, inversions
def ncr(n, r, facts, finvs, MOD):
return facts[n] * finvs[r] % MOD * finvs[n - r] % MOD
MOD = 998244353
x = abs(x)
y = abs(y)
z = abs(z)
t = x + y + z
if n < t:
return 0
d, m = divmod(n - t, 2)
if m == 1:
return 0
facts, finvs = precompute_factorials(n, MOD)
ans = 0
for dk in range(d + 1):
z_cnt = z + dk * 2
xy_cnt = n - z_cnt
pz = facts[n] * finvs[z_cnt - dk] % MOD * finvs[dk] % MOD * finvs[xy_cnt] % MOD
r1 = (xy_cnt - (x - y)) // 2
r2 = (xy_cnt - (x + y)) // 2
pxy = ncr(xy_cnt, r1, facts, finvs, MOD) * ncr(xy_cnt, r2, facts, finvs, MOD) % MOD
ans = (ans + pxy * pz) % MOD
return ans
n, x, y, z = map(int, input().split())
ans = solve(n, x, y, z)
print(ans)