UNICORNプログラミングコンテスト2022(AtCoder Beginner Contest 269) G,Ex問題メモ
G - Reversible Cards 2
問題
$N$ 枚のカードがあり、それぞれ表に $A_i$、裏に $B_i$ の整数が書かれている
$\sum_{i=1}^{N}(A_i+B_i)=M$ として、$k=0,1,2,...,M$ について、以下の答えを求めよ
$1 \le N \le 2 \times 10^5$
$1 \le M \le 2 \times 10^5$
$0 \le A_i,B_i \le M$
解法
$S=\sum A_i$ とする。これが初期状態で上になっている面に書かれた整数の和である。
$D_i=B_i-A_i$ とする。カード $i$ を裏返すと、これだけ総和が増減する。
以下のナップサック(負の価値がありえるver.)をすると答えが求まるが、$O(NM)$ のためTLEとなる。
高速化したい。ここで、以下の性質を利用できる。
$D_i$ の「種類数」をなるべく多くするようなテストケースを考えると、
$(A_i,B_i)=(0,1),(1,0),(0,2),(2,0),(0,3),(3,0),...$ のようなものがある。
(※(0,0)はいくらでも増やせるし、意味が無いので省略)
このとき $D_i=1,-1,2,-2,3,-3,...$ であり、$A_i+B_i=|D_i|$、$\sum_i(A_i+B_i)=M \le 2 \times 10^5$ である。
制約内で、$|D_i|=t$ となるところまで作れるとして、$M \ge \dfrac{t(t+1)}{2} \times 2$ となる。
$t \lt \sqrt{M}$ であり、種類数は $O(\sqrt{M})$ となる。
次に、$D_i$ が同じものをまとめて遷移させる方法を考える。$D_i=d$ となる $i$ が $x_d$ 個あったとする。
同じ品物が複数個あるナップサック問題のテクニックとして、2の冪乗のグループを作るというものがある。
$x_d=1+2+4+8+...+2^k+端数$ と分解し、それぞれの個数分をまとめた $log{x_d}$ 個のグループに分ける。
すると、$1~x_d$ のどのような値もいくつかのグループの組合せで実現でき、遷移回数を $O(\log{x_d})$ まで落とせる。
こうすることで計算量は、単純な見積もりで $O(M \sqrt{M} \log{M})$ となる。
実際には「$M$ の種類数」と「$D_i$ が同じものの個数の多さ」はトレードオフの関係にあるので、より少なくなる。
ちゃんとした計算量見積もりは公式Editorial参照。$O(N+M^{1.5})$ となるらしい。
Python3
import os
import sys
import numpy as np
def solve(inp):
n, m = inp[:2]
aaa = inp[2::2]
bbb = inp[3::2]
s = aaa.sum()
diff = bbb - aaa
diff_count = {n: n}
diff_count.clear()
for d in diff:
if d in diff_count:
diff_count[d] += 1
else:
diff_count[d] = 1
INF = 1 << 60
dp = np.full(m + 1, INF, np.int64)
dp[s] = 0
for d, x in diff_count.items():
if d == 0:
continue
k = 1
while x:
ndp = dp.copy()
c = min(k, x)
if d > 0:
for a in range(m + 1 - d * c):
b = a + d * c
ndp[b] = min(ndp[b], dp[a] + c)
else:
for a in range(-d * c, m + 1):
b = a + d * c
ndp[b] = min(ndp[b], dp[a] + c)
dp = ndp
x -= c
k <<= 1
dp[dp == INF] = -1
return dp
SIGNATURE = '(i8[:],)'
if sys.argv[-1] == 'ONLINE_JUDGE':
from numba.pycc import CC
cc = CC('my_module')
cc.export('solve', SIGNATURE)(solve)
cc.compile()
exit()
if os.name == 'posix':
# noinspection PyUnresolvedReferences
from my_module import solve
else:
from numba import njit
solve = njit(SIGNATURE, cache=True)(solve)
print('compiled', file=sys.stderr)
inp = np.fromstring(sys.stdin.read(), dtype=np.int64, sep=' ')
ans = solve(inp)
print('\n'.join(map(str, ans)))
Ex - Antichain
問題
$N$ 頂点の、頂点1を根とした根付き木が与えられる
$k=1,2,...,N$ について、以下の条件を満たす頂点集合 $S$ の個数を $\mod{998244353}$ で求めよ
$2 \le N \le 2 \times 10^5$
実行制限時間: 8sec.
解法
素朴な木DP(TLE)
以下の木DPを定義し、
ある頂点から見て異なる子に属する頂点同士は、互いに影響しないので、それぞれ独立に自由に選べる。
よって子のDPの結果をFFTを使って合成していく、という解法が思いつくが、ムカデグラフのような場合に上手くいかない。
○
/\ ← ③合成後サイズ N のFFT
○ ○
/\ ← ②合成後サイズ N-2 のFFT
○ ○
/\ ← ①合成後サイズ N-4 のFFT
○ ○
/\ (丸付き数字は処理順)
: :
既にサイズが肥大化した配列を何度もFFTすることになってしまうので、$O(N^2)$ となる。
こういう時、DPの計算が結合法則を満たすなら、長いpathを分割統治できるHL分解が有効になる場合がある。
HL分解
ある1つの Heavy-path を考える。各頂点にくっつくHeavy-path以外の頂点を、Light-children とする。
○--○--○--○--○--○--○--○ ←Heavy path
| | |\ | | | |
○ ○ ○○ ○ : : : ←Light children
`○ `○
以下を定義する。
$DP_1[l,r,k]$
$DP_2[l,r,k]$
Heavy-path上の区間 $[l,r]$ 上の頂点の、Light-children以下の部分森および $[l,r]$ 上の頂点からの選び方の個数
ただし $[l,r]$ 上の頂点は必ず選ばれている
l r
◎--◎--◎--◎--○--○--○--○ DP1[l,r]=●からの選び方
| | |\ | | | | DP2[l,r]=●および◎からの選び方、ただし◎からは必ず選ぶ
● ● ●● ● ○ ○ ○
`● `●
こうすると結合法則が成り立ち、Heavy path上のDPを半分ずつ統合していくことができる。
l1 r1 l2 r2
◎--◎--◎--◎--○--○--○--○
| | | | | | |
: : : : : : :
DP1[l1,r2] = DP1[l1,r1] * DP1[l2,r2]
DP2[l1,r2] = DP1[l1,r1] * DP2[l2,r2] + DP2[l1,r1]
※ただし * は添字 k に対する畳み込み
Heavy-path の最左頂点を $L$、最右頂点を $R$ として、$DP[L] = DP1[L,R]+DP2[L,R]$ が、Heavy-path全体の結果となる。
この結果を $L$ の親に伝えれば、親の Heavy-path のDPを計算するのに必要な情報となる。
P
○--○--○--○--○--○ 親のHeavy-path
| | | ↑↖
: : : │ ○--○--○--○--○--○ Light-child1のHeavy-path
│ L1
○--○--○--○--○--○ Light-child2のHeavy-path
L2
DP1[P,P] = DP[L1] * DP[L2] * ...(子の数だけ畳み込み)
DP2[P,P] = [0, 1]
DP1の畳み込みも、考えなくやると冒頭のムカデグラフと同様の問題が発生しうる。
要素数が小さい方から優先的におこなっていくことで、全体で $O(N \log{N})$ となる。
Heavy-path上の畳み込みも、隣り合った箇所のサイズが小さい方から上手くやると $O(N \log{N})$ になるらしいが、
順番を崩さないようにする必要があり難しいので、特に工夫せず2個ずつ統合していく実装でも、$O(N \log^2{N})$ となる。
実行制限時間が長いので、$O(N \log^2{N})$ でも間に合う。
Python3
from heapq import heapify, heappop, heappush
from typing import List
import numpy as np
def heavy_light_decomposition(n: int, children: List[List[int]], root: int = 0) -> List[List[int]]:
"""
①--②--④--⑦--⑨ → [[1,2,4,7,9], [10], [3,5,8], [6]]
`--③--⑤-⑧ `-⑩
`--⑥
Heavy-path のリストに分解する。
Heavy-path の並びは、(最も重い子を最初に訪れるような)オイラーツアーの訪問順となる。
"""
weights = [-1] * n
q = [root]
while q:
u = q[-1]
if weights[u] == -1:
weights[u] = -2
q.extend(children[u])
else:
q.pop()
weights[u] = 1 + sum(weights[v] for v in children[u])
q = [root]
progress = [0] * n
current_path = []
result = [current_path]
while q:
u = q[-1]
if progress[u] == 0:
children[u].sort(key=weights.__getitem__, reverse=True)
current_path.append(u)
if progress[u] >= len(children[u]):
q.pop()
continue
v = children[u][progress[u]]
if progress[u] > 0:
current_path = []
result.append(current_path)
progress[u] += 1
q.append(v)
return result
def _convolve(f, g):
# 小さいうちは愚直の方が高速。convolve_mod()と併用するなら大丈夫だが、単独で使用する場合はオーバーフローに注意
if len(f) * len(g) < 10000:
return np.convolve(f, g)
fft_len = 1
true_len = len(f) + len(g) - 1
while fft_len < true_len:
fft_len <<= 1
Ff = np.fft.rfft(f, fft_len)
Fg = np.fft.rfft(g, fft_len)
Fh = Ff * Fg
h = np.fft.irfft(Fh, fft_len)
h = np.rint(h).astype(np.int64)
return h[:true_len]
def convolve_mod(f, g, p):
f1, f2 = np.divmod(f, 1 << 15)
g1, g2 = np.divmod(g, 1 << 15)
a = _convolve(f1, g1) % p
c = _convolve(f2, g2) % p
b = (_convolve(f1 + f2, g1 + g2) - (a + c)) % p
h = (a << 30) + (b << 15) + c
return h % p
n = int(input())
ppp = list(map(int, input().split()))
children = [[] for _ in range(n)]
parent = [-1] * n
for i, p in enumerate(ppp, start=1):
p -= 1
children[p].append(i)
parent[i] = p
MOD = 998244353
hld_paths = heavy_light_decomposition(n, children, 0)
light_children_results = [[] for _ in range(n)]
default_dp1 = np.array([1], np.int64)
default_dp2 = np.array([0, 1], np.int64)
ans = []
for heavy_path in reversed(hld_paths):
merge_array = []
for v in heavy_path:
if len(light_children_results[v]) == 0:
merge_array.append((default_dp1.copy(), default_dp2.copy()))
elif len(light_children_results[v]) == 1:
lcr = light_children_results[v][0]
merge_array.append((lcr, default_dp2.copy()))
else:
merge2_array = light_children_results[v]
merge2_queue = []
for i, lcr in enumerate(merge2_array):
merge2_queue.append((len(lcr), i))
heapify(merge2_queue)
while len(merge2_queue) > 1:
l1, i1 = heappop(merge2_queue)
l2, i2 = heappop(merge2_queue)
lcr3 = convolve_mod(merge2_array[i1], merge2_array[i2], MOD)
i3 = len(merge2_array)
merge2_array.append(lcr3)
heappush(merge2_queue, (len(lcr3), i3))
lcr = merge2_array[-1]
merge_array.append((lcr, default_dp2.copy()))
while len(merge_array) > 1:
new_array = []
l = len(merge_array)
for i in range(0, l, 2):
if i + 1 == l:
new_array.append(merge_array[i])
continue
dp11, dp12 = merge_array[i]
dp21, dp22 = merge_array[i + 1]
dp31 = convolve_mod(dp11, dp21, MOD)
dp32 = convolve_mod(dp11, dp22, MOD)
dp32[:len(dp12)] += dp12
dp32 %= MOD
new_array.append((dp31, dp32))
merge_array = new_array
res1, res2 = merge_array[0]
if len(res1) >= len(res2):
res1[:len(res2)] += res2
res = res1 % MOD
else:
res2[:len(res1)] += res1
res = res2 % MOD
leader = heavy_path[0]
if leader != 0:
p = parent[leader]
light_children_results[p].append(res)
else:
ans = res
ans = ans.tolist()
if len(ans) < n + 1:
ans.extend([0] * (n + 1 - len(ans)))
print('\n'.join(map(str, ans[1:])))