モノグサプログラミングコンテスト2024(AtCoder Beginner Contest 345)E 問題メモ
E - Colorful Subsequence
問題
N 個のボールが1列に並び、i 番目のボールは色 Ci、価値 Vi
ちょうど K 個のボールを取り除き、同じ色のボールが連続しないようにしたい
できるか判定し、できる場合は残るボールの価値の最大値を求めよ
1≤N≤2×105
1≤K≤min(N,500)
解法
まず、以下のDPを考える。
以下のように求められる。
便宜的に、0番目に色がどれとも異なり価値は0のボールがあり、それは必ず使用すると考える。
i→
j | 0 1 2 3 4 5 6
↓--+---------------------
0 | 0 @
1 | - @
2 | - - @
3 | - - - @ [*]
4 | - - - -
DP[6,3] を求める場合、@ をつけた DP[5,3],DP[4,2],... の中で、
かつ「i が Ci≠C6」である中での最大値に、V6 を加えたものとなる。
これを愚直に求めると、O(NK) のマスそれぞれの遷移に O(K) かかり、O(NK2) となり間に合わない。
高速化
更新の際に参照するのは、常にナナメに並ぶ要素であることを利用する。
DP2[i,j,k]=DP[i,∗] までを考え、(i,j) から左斜め上に辿っていった中での、
k=0: 最大値
k=1: 最大値の時の最後に置いた色
k=2: 最大値とは異なる色で、2番目に大きな値
添字 * は、その添字の範囲全体を指すとする。
DP[6,3] を求める場合、DP2[5,3,∗] を参照し、
DP2[5,3,1] が C6 と同じなら DP2[5,3,2] を、異なれば DP1[5,3,0] を使えばよい。
そして、新しい DP[6,∗] をもって、DP2[6,∗] を求める。
この時に DP[6,3] と比較するのは DP2[5,2,∗] であり、答えを求める際に使用した DP2[5,3,∗] とは1段ずれる点に注意。
これで遷移が O(1) となり、O(NK) で間に合うようになる。
Python3
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
import os
import sys
import numpy as np
def solve(inp):
n, k = inp[:2]
ccc = inp[2::2]
vvv = inp[3::2]
INF = 1 << 60
dp1v = np.full(k + 1, -INF, np.int64)
dp1c = np.full(k + 1, -INF, np.int64)
dp2v = np.full(k + 1, -INF, np.int64)
dp1v[0] = 0
for i in range(n):
c = ccc[i]
v = vvv[i]
dp_tmp = np.full(k + 1, -INF, np.int64)
for j in range(min(i + 1, k + 1)):
dp_tmp[j] = dp1v[j] if dp1c[j] != c else dp2v[j]
dp_tmp[j] += v
if i < k:
dp_tmp[i + 1] = 0
for j in range(min(i + 1, k), 0, -1):
if dp1v[j - 1] < dp_tmp[j]:
if dp1c[j - 1] == c:
dp2v[j] = dp2v[j - 1]
dp1v[j] = dp_tmp[j]
dp1c[j] = c
else:
dp2v[j] = dp1v[j - 1]
dp1v[j] = dp_tmp[j]
dp1c[j] = c
elif dp1c[j - 1] != c and dp2v[j - 1] < dp_tmp[j]:
dp2v[j] = dp_tmp[j]
dp1v[j] = dp1v[j - 1]
dp1c[j] = dp1c[j - 1]
else:
dp2v[j] = dp2v[j - 1]
dp1v[j] = dp1v[j - 1]
dp1c[j] = dp1c[j - 1]
dp1v[0] = dp_tmp[0]
dp1c[0] = c
if dp1v[k] < 0:
return -1
else:
return dp1v[k]
SIGNATURE = '(i8[:],)'
if sys.argv[-1] == 'ONLINE_JUDGE':
from numba.pycc import CC
cc = CC('my_module')
cc.export('solve', SIGNATURE)(solve)
cc.compile()
exit()
if os.name == 'posix':
from my_module import solve
else:
from numba import njit
solve = njit(SIGNATURE, cache=True)(solve)
print('compiled', file=sys.stderr)
inp = np.fromstring(sys.stdin.read(), dtype=np.int64, sep=' ')
ans = solve(inp)
print(ans)
|