Toyota Programming Contest 2024#4(AtCoder Beginner Contest 348)F,G問題メモ
F - Oddly Similar
問題
$N \times M$ の行列 $A$ が与えられる
$A$ の $i$ 行目と $j$ 行目は、「$A_{i,k}=A_{j,k}$ となる $k$($1 \le k \le M$)」の個数が奇数の時、「似ている」という
似ている行のペア $(i,j)$(ただし $i \lt j$)の個数を求めよ
$1 \le N,M \le 2000$
$1 \le A_{i,j} \le 999$
解法
こんなん、愚直に見ていくしか方法が無さそうだが、愚直に見ていくと $O(N^2M)$ かかる。
(C++などではそれで間に合っちゃうらしいが)
一致する“個数”でなく“偶奇”だけがわかればよいことから、bitsetのxorを用いて定数倍を軽くする方法が考えられる。
「2 列目の値が 3 であるような行は、$i=0,3,4$ 行目」といったように、
$(j,A_{i,j})$ の組毎に、合致する $i$ の集合を持たせたデータを考える。
ある列 $j$ で同じ集合に属する行は、互いに $j$ 列目で値が一致していることになる。
ある行 $i$ について、自身と似ている行を探すには、
j 0 1 2 3 4
i
0 1 2 3 4 5
1 1 2 4 4 5
2 2 1 4 3 5
3 5 4 3 2 1
4 2 1 3 4 5
i=4 について
j,A iの集合(自身以外)
43210
(0,2) 00100
(1,1) 00100
(2,3) 01001
(3,4) 00011
(4,5) 00111
------------- xor
01101
このように、着目中の行の各 $(j,A_{i,j})$ に対する $i$ の集合をxorすると、最後に“1”が立っているものが「似ている」列となる。
実際は、$i \lt j$ のペアのみ探すので、$i$ が小さい方から順に、求解→$i$ の集合への登録、を処理していくと重複を除ける。
(自身を $j$ として、自身より小さい $i$ に限定したペアを求める)
$A$ の制約範囲が狭めなので、$(j,A)$ は高々 $2 \times 10^6$ 通りしか取り得ず、
それぞれに $N$ bitのbitsetを持たせても、$\dfrac{4 \times 10^9}{64} \simeq 6.25 \times 10^7$ でメモリ的には大丈夫。
計算量も、$N$ 個の行に付き、$M$ 回、$\dfrac{N}{64}$ サイズのxor演算を行うため、
総合して $\dfrac{N^2M}{64} \simeq 1.25 \times 10^8$ となり、遅い言語では厳しめだが、なんとか間に合う。
Python3
import os
import sys
import numpy as np
def solve(inp):
def bit_count(arr):
t = arr.dtype.type
mask = t(-1)
s55 = t(0x5555555555555555 & mask) # Add more digits for 128bit support
s33 = t(0x3333333333333333 & mask)
s0F = t(0x0F0F0F0F0F0F0F0F & mask)
s01 = t(0x0101010101010101 & mask)
arr = arr - ((arr >> 1) & s55)
arr = (arr & s33) + ((arr >> 2) & s33)
arr = (arr + (arr >> 4)) & s0F
return (arr * s01) >> (8 * (arr.itemsize - 1))
n, m = inp[:2]
k = (n - 1) // 61 + 1
d = np.zeros((m, 1000, k), np.int64)
result = np.zeros((n, k), np.int64)
for i in range(n):
p, q = divmod(i, 61)
bit = 1 << q
for j in range(m):
a = inp[i * m + j + 2]
result[i] ^= d[j, a]
d[j, a, p] |= bit
ans = bit_count(result).sum()
return ans
SIGNATURE = '(i8[:],)'
if sys.argv[-1] == 'ONLINE_JUDGE':
from numba.pycc import CC
cc = CC('my_module')
cc.export('solve', SIGNATURE)(solve)
cc.compile()
exit()
if os.name == 'posix':
# noinspection PyUnresolvedReferences
from my_module import solve
else:
from numba import njit
solve = njit(SIGNATURE, cache=True)(solve)
print('compiled', file=sys.stderr)
inp = np.fromstring(sys.stdin.read(), dtype=np.int64, sep=' ')
ans = solve(inp)
print(ans)
PyPyってたまに、わずかな違いで実行速度に差が生じることがある。
上手にコンパイルしてくれる書き方があるのだろうけど、根本的な要因がわからない。
G - Max (Sum - Max)
問題
長さ $N$ の2つの整数列 $A,B$ が与えられる
$k=1,2,...,N$ について、以下に答えよ
$I=\{1,2,...,N\}$ から $k$ 個選ぶ
「選んだindexの $A_i$ の和」-「選んだindexの $B_i$ の最大値」の最大値を求めよ
$1 \le N \le 2 \times 10^5$
解法
利用するテクニックも難しめだし、この問題がそれを利用できることに気付くのもなかなか難しい。
(類題を見たことあれば「あ、」と気付くものかもしれない)
最小費用流で解けるかと思ったが、引く数が「選んだindexの $B_i$ の最小値」ならいけるが、最大値なのでダメだった。
まず、$(A_i,B_i)$ を $B_i$ を基準にソートする。
こうすると、引く数は「選んだ中で最も右の $B_i$ を使う」と確定するので考えやすくなる。
分割統治
以下のDPを定義して分割統治をする。
初期値として、区間の長さが1の時、$dp[l,l+1,1]=A_l-B_l$ である。
ここから、区間を統合して倍の長さの区間の答えを求める、ということを繰り返して、$dp[0,N,k]$ の答えを求めたい。
(2の冪からはみ出す分は、$A_i=-\infty$、$B_i=\infty$ などで埋めておけばよい)
i 0 1 2 3 4 5 6 7
A 3 -1 4 1 -5 9 2 -∞
B -5 -2 -2 0 1 4 8 ∞
dp 8 1 6 1 -6 5 -6 -∞ ←長さ1の区間 dp[i,i+1,1] = Ai-Bi
|----||----||----||----|
|----------||----------| 区間を倍々にして dp[l,r,*] を求めていく
|----------------------|
$[l,m)$ と $[m,r)$ を統合して $dp[l,r,k]$ を求めるにあたり、$[l,m)$ の方から $s$ 個、取ることにすると、
$s \lt k$ のとき
$[l,m)$ の方からは、$B_i$ を気にする必要は無いので、$A_i$ の大きい方から $s$ 個、貪欲にとった総和としてよい
$[m,r)$ の方からは、$dp[m,r,k-s]$ が達成できる
この両区間の和が達成できる
$s = k$ のとき
$s=0,1,...,k$ を試して、この中の最大値が、$dp[l,r,k]$ となる。
当然、これを $k$ ごとにやっていたら、$k$ は $0,1,...,r-l$ に付いて求めるため、$O((r-l)^2)$ かかる。
分割統治全体では、$r-l$ が $N$ 以上になるまで、$2^2$ が $N/2$ 回、$4^2$ が $N/4$ 回、、、と繰り返されるため、
$O(N^2)$ かかることになる。
最大値畳み込み
$dp[l,r,*]$ を求めるのは、最大値の畳み込みに相当する。
要は、2つの数列 $X,Y$ で、$X_s$ を左区間(区間内の $A_i$ の大きい方から $s=0,1,...$ 個取ったもの)、$Y_t$ を右区間($dp[m,r,t]$)としたときに、
となる数列 $Z$ を求められれば、$dp[l,r,*]$ が一気に求まる。
この畳み込みは、普通は $O(|X||Y|)$ かかるが、
$X,Y$ のいずれかが上に凸である場合、$O(|X|+|Y|\log|X|)$ などで求める方法がある。
今回の場合、左区間の $X_s$ が「大きい方からの累積和」ということで上に凸であるため、この高速化手法を適用できる。
これを使ってdpを倍々に統合していくと、全体で $O(N \log N)$ で求まる。
Python3
from itertools import accumulate
from operator import itemgetter
from typing import Callable, List
def smawk(n: int, m: int, select: Callable[[int, int, int], bool]) -> List[int]:
"""
n×m 配列 A の、各行における最小値(※)を与える列index [j1,j2,...,jn] を得る。
制約:
A はTotally Monotone
A は陽に与える必要は無く、比較関数 select によって評価する
select(i,j,k) := A[i,j] > A[i,k] なら True を返す関数
(※)
最小値以外(最大値など)の指標にも適用できる。
Aはその指標にとっての Totally Monotone であればよく、select(i,j,k) はA[i,k]の方が"よい"場合にTrueを返す関数とする。
参考:
https://github.com/noshi91/Library/blob/master/algorithm/smawk.cpp
"""
columns_history = [list(range(m))]
row_bit_length = n.bit_length()
for depth in range(row_bit_length):
next_row_size = n >> depth
i_step = 1 << depth
i = i_step - 1
columns = []
for j in columns_history[depth]:
while columns and select(i - i_step, columns[-1], j):
columns.pop()
i -= i_step
if len(columns) < next_row_size:
columns.append(j)
i += i_step
columns_history.append(columns)
assert len(columns_history[row_bit_length]) == 1
result = [0] * n
result[(1 << (row_bit_length - 1)) - 1] = columns_history[row_bit_length][0]
for depth in range(row_bit_length - 1, 0, -1):
i_step = 1 << depth
h_step = i_step >> 1
columns = columns_history[depth]
ji = 0
for i in range(h_step - 1, n, i_step):
res = columns[ji]
until = columns[-1] if i + h_step >= n else result[i + h_step]
while columns[ji] < until:
ji += 1
if select(i, res, columns[ji]):
res = columns[ji]
result[i] = res
return result
def get_score(dp, acc, w, l, i, j):
if j == 0:
return dp[l + i - 1]
else:
return acc[l + w] - acc[l + w - (i - j)] + dp[l + w + j - 1]
def solve(n, ab):
ab.sort(key=itemgetter(1))
aaa, bbb = zip(*ab)
# dp[w][i*w+j] = 左端が i*w から始まる長さ w の区間から j+1 個選んだときの最大値 (w=1,2,4,... j=0,1,...,w-1)、w は省略
dp = [a - b for a, b in ab]
acc = [0] + list(accumulate(aaa))
w2 = 2
while (w2 >> 1) <= n:
w = w2 >> 1
ndp = [0] * n
sorted_aaa = [0] * n
l = 0
while l < n:
mid = l + w
if mid < n:
r = min(l + w2, n)
m = r - l
right_limit = r - mid
def select(i, j, k):
i += 1
if i - j > w:
return True
if i < k or k > right_limit:
return False
return get_score(dp, acc, w, l, i, j) < get_score(dp, acc, w, l, i, k)
res = smawk(m, m + 1, select)
for i in range(r - l):
ndp[l + i] = get_score(dp, acc, w, l, i + 1, res[i])
# aaa のソート
i = l
j = mid
k = l
while i < mid and j < r:
if aaa[i] <= aaa[j]:
sorted_aaa[k] = aaa[i]
i += 1
k += 1
else:
sorted_aaa[k] = aaa[j]
j += 1
k += 1
if i < mid:
sorted_aaa[k:r] = aaa[i:mid]
else:
sorted_aaa[k:r] = aaa[j:r]
else:
ndp[l:] = dp[l:]
sorted_aaa[l:] = aaa[l:]
l += w2
dp = ndp
aaa = sorted_aaa
acc = [0] + list(accumulate(aaa))
w2 <<= 1
return dp
n = int(input())
ab = [tuple(map(int, input().split())) for _ in range(n)]
ans = solve(n, ab)
print('\n'.join(map(str, ans)))