モノグサプログラミングコンテスト2024(AtCoder Beginner Contest 345)E 問題メモ
E - Colorful Subsequence
問題
N 個のボールが1列に並び、i 番目のボールは色 Ci、価値 Vi
ちょうど K 個のボールを取り除き、同じ色のボールが連続しないようにしたい
できるか判定し、できる場合は残るボールの価値の最大値を求めよ
1≤N≤2×105
1≤K≤min
解法
まず、以下のDPを考える。
以下のように求められる。
便宜的に、0番目に色がどれとも異なり価値は0のボールがあり、それは必ず使用すると考える。
i→
j | 0 1 2 3 4 5 6
↓--+---------------------
0 | 0 @
1 | - @
2 | - - @
3 | - - - @ [*]
4 | - - - -
DP[6,3] を求める場合、@
をつけた DP[5,3],DP[4,2],... の中で、
かつ「i が C_i \neq C_6」である中での最大値に、V_6 を加えたものとなる。
これを愚直に求めると、O(NK) のマスそれぞれの遷移に O(K) かかり、O(NK^2) となり間に合わない。
高速化
更新の際に参照するのは、常にナナメに並ぶ要素であることを利用する。
DP2[i,j,k]=DP[i,*] までを考え、(i,j) から左斜め上に辿っていった中での、
k=0: 最大値
k=1: 最大値の時の最後に置いた色
k=2: 最大値とは異なる色で、2番目に大きな値
添字 *
は、その添字の範囲全体を指すとする。
DP[6,3] を求める場合、DP2[5,3,*] を参照し、
DP2[5,3,1] が C_6 と同じなら DP2[5,3,2] を、異なれば DP1[5,3,0] を使えばよい。
そして、新しい DP[6,*] をもって、DP2[6,*] を求める。
この時に DP[6,3] と比較するのは DP2[5,2,*] であり、答えを求める際に使用した DP2[5,3,*] とは1段ずれる点に注意。
これで遷移が O(1) となり、O(NK) で間に合うようになる。
Python3
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
import os
import sys
import numpy as np
def solve(inp):
n, k = inp[: 2 ]
ccc = inp[ 2 :: 2 ]
vvv = inp[ 3 :: 2 ]
INF = 1 << 60
dp1v = np.full(k + 1 , - INF, np.int64)
dp1c = np.full(k + 1 , - INF, np.int64)
dp2v = np.full(k + 1 , - INF, np.int64)
dp1v[ 0 ] = 0
for i in range (n):
c = ccc[i]
v = vvv[i]
dp_tmp = np.full(k + 1 , - INF, np.int64)
for j in range ( min (i + 1 , k + 1 )):
dp_tmp[j] = dp1v[j] if dp1c[j] ! = c else dp2v[j]
dp_tmp[j] + = v
if i < k:
dp_tmp[i + 1 ] = 0
for j in range ( min (i + 1 , k), 0 , - 1 ):
if dp1v[j - 1 ] < dp_tmp[j]:
if dp1c[j - 1 ] = = c:
dp2v[j] = dp2v[j - 1 ]
dp1v[j] = dp_tmp[j]
dp1c[j] = c
else :
dp2v[j] = dp1v[j - 1 ]
dp1v[j] = dp_tmp[j]
dp1c[j] = c
elif dp1c[j - 1 ] ! = c and dp2v[j - 1 ] < dp_tmp[j]:
dp2v[j] = dp_tmp[j]
dp1v[j] = dp1v[j - 1 ]
dp1c[j] = dp1c[j - 1 ]
else :
dp2v[j] = dp2v[j - 1 ]
dp1v[j] = dp1v[j - 1 ]
dp1c[j] = dp1c[j - 1 ]
dp1v[ 0 ] = dp_tmp[ 0 ]
dp1c[ 0 ] = c
if dp1v[k] < 0 :
return - 1
else :
return dp1v[k]
SIGNATURE = '(i8[:],)'
if sys.argv[ - 1 ] = = 'ONLINE_JUDGE' :
from numba.pycc import CC
cc = CC( 'my_module' )
cc.export( 'solve' , SIGNATURE)(solve)
cc. compile ()
exit()
if os.name = = 'posix' :
from my_module import solve
else :
from numba import njit
solve = njit(SIGNATURE, cache = True )(solve)
print ( 'compiled' , file = sys.stderr)
inp = np.fromstring(sys.stdin.read(), dtype = np.int64, sep = ' ' )
ans = solve(inp)
print (ans)
|