ACL Beginner Contest C,D,E,F問題メモ
C - Connect Cities
問題
解法
Union-Find。
既存の道路による接続を全てuniteしたあと、残っている根(リーダー)の個数を数える。
$K$ 個のリーダーが残っていたら、それを繋ぐために必要な道路の個数は $K-1$ 本。
Python3
class UnionFind:
def __init__(self, n):
self.table = [-1] * n
def _root(self, x):
stack = []
tbl = self.table
while tbl[x] >= 0:
stack.append(x)
x = tbl[x]
for y in stack:
tbl[y] = x
return x
def find(self, x, y):
return self._root(x) == self._root(y)
def unite(self, x, y):
r1 = self._root(x)
r2 = self._root(y)
if r1 == r2:
return
d1 = self.table[r1]
d2 = self.table[r2]
if d1 <= d2:
self.table[r2] = r1
self.table[r1] += d2
else:
self.table[r1] = r2
self.table[r2] += d1
def get_size(self, x):
return -self.table[self._root(x)]
n, m = map(int, input().split())
uft = UnionFind(n)
for _ in range(m):
a, b = map(int, input().split())
a -= 1
b -= 1
uft.unite(a, b)
ans = 0
for p in uft.table:
if p < 0:
ans += 1
print(ans - 1)
D - Flat Subsequence
問題
数列 $A_1,A_2,...,A_N$ と整数 $K$ が与えられる
以下の条件を満たす数列 $B$ の長さとして考えられる最大値を求めよ
$1 \le N \le 3 \times 10^5$
$0 \le A_i \le 3 \times 10^5$
解法
セグメント木を動的に使ってのDP。
以下のDPを定義する。
すると、$i+1$ について更新したければ、$DP[i][A_{i+1}-K]~DP[i][A_{i+1}+K]$ の範囲から最大の長さを持ってきて、そこに+1すればよい。
これは、区間MAXの取得と、一点更新を行えるデータ構造があればできる。セグメント木で実装する。
$DP[i+1]$ についての更新は $DP[i]$ の情報しか不要なので、破壊的に更新していってよい。
Python3
import sys
class SegTreeMax:
"""
以下のクエリを処理する
1.update: i番目の値をxに更新する
2.get_max: 区間[l, r)の最大値を得る
"""
def __init__(self, n, e):
"""
:param n: 要素数
:param e: 初期値(maxの場合は入りうる値より必ず小さな値)
"""
n2 = 1 << (n - 1).bit_length()
self.offset = n2
self.tree = [e] * (n2 << 1)
self.e = e
@classmethod
def from_array(cls, arr, e):
ins = cls(len(arr), e)
ins.tree[ins.offset:ins.offset + len(arr)] = arr
for i in range(ins.offset - 1, 0, -1):
l = i << 1
r = l + 1
ins.tree[i] = max(ins.tree[l], ins.tree[r])
return ins
def update(self, i, x):
"""
i番目の値をxに更新
:param i: index(0-indexed)
:param x: update value
"""
i += self.offset
self.tree[i] = x
while i > 1:
y = self.tree[i ^ 1]
if y >= x:
break
i >>= 1
self.tree[i] = x
def get_max(self, a, b):
"""
[a, b)の最大値を得る
:param a: index(0-indexed)
:param b: index(0-indexed)
"""
result = self.e
l = a + self.offset
r = b + self.offset
while l < r:
if r & 1:
result = max(result, self.tree[r - 1])
if l & 1:
result = max(result, self.tree[l])
l += 1
l >>= 1
r >>= 1
return result
n, k, *aaa = map(int, sys.stdin.buffer.read().split())
sgt = SegTreeMax(300001, 0)
for a in aaa:
nxt = sgt.get_max(max(0, a - k), min(a + k + 1, 300001))
sgt.update(a, nxt + 1)
print(sgt.get_max(0, 300001))
E - Replace Digits
問題
$N$ 個の '1
' が繋がった文字列 $S$ がある
$Q$ 個のクエリを処理する
各クエリの後、$S$ を10進数の整数とみなした値を $\mod{998244353}$ で求めよ
$1 \le N,Q \le 2 \times 10^5$
解法
遅延伝播セグメント木。
$111111 = 100000+10000+1000+100+10+1$ と分けて、1桁ずつ管理する。
まず、セグメント木のそれぞれのノードが表す1単位を計算する。
たとえば6桁なら、こんな感じ。$S$ の初期状態でもある。
| 111111 |
| 111100 | 11 |
| 110000 | 1100 | 11 | 0 |
|100000| 10000| 1000| 100| 10| 1| 0 | 0 |
これを $U_i$ とすると、「ノード $i$ 以下全体が $D$ に書き換えられた」=「そのノード以下全体の合計は $U_i \times D$」となる。
そして、遅延伝播セグメント木には、
というデータを乗っける。遅延データは当てはまらない場合は-1とかにしておけばいい。
そして、遅延セグ木に定義すべき演算は、以下のようになる。
op(データ同士の演算)
mapping(遅延データをデータに反映)
composition(遅延データ同士の合成)
感想
適当に実装したところ上手く動かなかったので、いちから流れを確かめてバグ取りしてたら時間かかった。
どうも、「今ある値に加える」処理と「今ある値を書き換える」処理の区別が付いておらず、後者に対して実装の仕方が分かってなかった。
1)
書き換えるタイプの更新処理では上記のように、
「ノード以下が全て書き換えられるか」というフラグを持たせ、
Trueなら遅延データを参照して、Falseなら現在のデータを参照する、というのが肝要っぽい。
遅延データ側に載せるものには $mapping(a,e)=a$ を満たす単位元のような存在 $e$ が必要らしいが、これが「フラグがFalseであること」に相当するのか。なるほど。
Python3
import os
import sys
import numpy as np
def solve(inp):
SEGTREE_TABLES = []
COMMON_STACK = np.zeros(10 ** 7, dtype=np.int64)
MOD = 998244353
def bit_length(n):
ret = 0
while n:
n >>= 1
ret += 1
return ret
def segtree_init(n):
n2 = 1 << bit_length(n)
table = np.zeros((n2 << 1, 3), np.int64)
# 問題依存
k = 1
for i in range(n2 + n - 1, n2 - 1, -1):
table[i, 0] = k
k = k * 10 % MOD
for i in range(n2 - 1, 0, -1):
ch = i << 1
table[i, 0] = (table[ch, 0] + table[ch + 1, 0]) % MOD
table[:, 1] = -1
table[:, 2] = table[:, 0]
SEGTREE_TABLES.append(table)
return len(SEGTREE_TABLES) - 1
def segtree_debug_print(ins):
table = SEGTREE_TABLES[ins]
offset = table.shape[0] >> 1
for t in range(table.shape[1]):
i = 1
while i <= offset:
print(table[i:2 * i, t])
i <<= 1
def segtree_eval(table, offset, i):
d = table[i, 1]
if d == -1:
return
if i < offset:
ch = i << 1
table[ch, 1] = d
table[ch + 1, 1] = d
table[i, 0] = table[i, 2] * d % MOD
table[i, 1] = -1
def segtree_bottomup(table, i):
lch = i << 1
rch = lch + 1
l_dat = table[lch, 0] if table[lch, 1] == -1 else table[lch, 1] * table[lch, 2] % MOD
r_dat = table[rch, 0] if table[rch, 1] == -1 else table[rch, 1] * table[rch, 2] % MOD
table[i, 0] = (l_dat + r_dat) % MOD
def segtree_range_update(ins, l, r, d):
table = SEGTREE_TABLES[ins]
offset = table.shape[0] >> 1
stack = COMMON_STACK
stack[:3] = (1, 0, offset)
si = 3
updated = []
while si:
i, a, b = stack[si - 3:si]
segtree_eval(table, offset, i)
if b <= l or r <= a:
si -= 3
continue
if l <= a and b <= r:
table[i, 1] = d
si -= 3
continue
updated.append(i)
m = (a + b) // 2
stack[si - 3:si] = (i << 1, a, m)
stack[si:si + 3] = ((i << 1) + 1, m, b)
si += 3
while updated:
i = updated.pop()
segtree_bottomup(table, i)
n = inp[0]
q = inp[1]
lll = inp[2::3] - 1
rrr = inp[3::3]
ddd = inp[4::3]
ins = segtree_init(n)
buf = []
for i in range(q):
l = lll[i]
r = rrr[i]
d = ddd[i]
segtree_range_update(ins, l, r, d)
buf.append(SEGTREE_TABLES[ins][1, 0])
return buf
if sys.argv[-1] == 'ONLINE_JUDGE':
from numba.pycc import CC
cc = CC('my_module')
cc.export('solve', '(i8[:],)')(solve)
cc.compile()
exit()
if os.name == 'posix':
# noinspection PyUnresolvedReferences
from my_module import solve
else:
from numba import njit
solve = njit('(i8[:],)', cache=True)(solve)
print('compiled', file=sys.stderr)
inp = np.fromstring(sys.stdin.read(), dtype=np.int64, sep=' ')
ans = solve(inp)
print('\n'.join(map(str, ans)))
F - Heights and Pairs
問題
$2N$ 人の人がいて、人 $i$ の身長は $h_i$ である
以下の条件を満たしつつ、2人ずつ $N$ 個のペアを作りたい
$N$ 個のペアの作る方法としてあり得るものは何通りか、$\mod 998244353$ で求めよ
$1 \le N \le 50000$
$1 \le h_i \le 10^5$
解法
たたみ込みを使った包除原理。
まずこのような問題は、全てのペアを作るパターン数から、当てはまらないものを除くことができるか考える。
すると、同じ身長となってしまうペアの個数で包除原理を使えそう。
包除原理の適用
まず、$2K$ 人から $K$ 個のペアを作るパターン数というのは、1個飛ばしの階乗(二重階乗)で求められる。これを最初に計算しておく。
次に、同じ身長のペアが少なくとも $k$ 個できてしまうパターン数を、$k=0,1,2,...$ で求める。
たとえば「身長10の人が7人」「身長20の人が10人」他の身長は1人ずついたとして、以下を計算していく。
すると、答えは $Pat[0]-Pat[1]+Pat[2]-Pat[3]+...$ で求められる。
問題は、$Pat$ をどう求めるか。ここにたたみ込みを用いる。
たたみ込み
同じ身長の人が $n$ 人いたとして(これを $n$ 人グループと称する)、この中だけで、同じ身長のペアを $k$ 組作るパターン数を考える。
これは、$n$ 人の中からどの人をペアにするかで ${}_{n}C_{2k}$、その $2k$ 人をどう組み合わせるかで $Pair[k]=(2k-1)!!$、これを掛け合わせたものとなる。
たとえば7人なら、以下のようになる。
7人グループからk個のペアを作る方法
k 0 1 2 3
1 21 105 105
次に、グループが2つあったとして、その中から同じ身長のペアを合計 $k$ 組作るパターン数を考える。
たとえば7人グループと10人グループがあったとして、
10人グループからk個のペアを作る方法
k 0 1 2 3 4 5
1 45 630 3150 4725 945
「2つのグループのどちらかから、同じ身長のペアを合計 $k$ 組作るパターン数」は、たとえば $k=4$ なら以下のようになる。
7人から 10人から
のペア数 のペア数
0 4 1 x 4725 = 4725
1 3 21 x 3150 = 66150
2 2 105 x 630 = 66150
3 1 105 x 45 = 4725
-----------------------
141750
この操作は、まさにたたみ込みである。
たたみ込んだ結果は以下のようになる。
k 0 1 2 3 4 ... 8
1 66 1680 21210 141750 ... 99225
MODを取りながらのたたみ込みは、AtCoder-Libraryのconvolutionを使うと2つの配列長の和を $N$ として $O(N \log{N})$ で求めることができる。
同じ身長のグループが3個以上でも、このようなマージを順番に適用していくことで求められる。
ただし、マージにかかる計算量は、前述の通り配列の合計の長さに比例する。
適当に順番を決めてしまうと、何度も長い配列をマージすることになってしまい、TLE。
サイズの短い方から順番に行っていくと、全体で $O(N (\log{N})^2)$ になる。
「マージは小さい方から」。よく忘れて、見当違いの高速化に走りがち。
感想
マージする順番以外は(コンテスト終了後に)自力でたどり着けたが、
正直、たたみ込みに気付いたのはAtCoder-Library Contestだからというのが大きい気がする。
Python3
import os
import sys
from heapq import heappop, heappush
import numpy as np
# https://github.com/atcoder/ac-library/blob/master/atcoder/convolution.hpp
def solve(inp):
def ceil_pow2(n):
x = 0
while 1 << x < n:
x += 1
return x
def bit_scan_forward(n):
x = 0
while n & 1 == 0:
n >>= 1
x += 1
return x
def pow_mod(x, n, m):
r = 1
y = x % m
while n:
if n & 1:
r = (r * y) % m
y = y * y % m
n >>= 1
return r
def get_primitive_root(m):
if m == 2: return 1
if m == 167772161: return 3
if m == 469762049: return 3
if m == 754974721: return 11
if m == 998244353: return 3
divs = [2]
x = (m - 1) // 2
while x & 1 == 0:
x >>= 1
i = 3
while i * i <= x:
if x % i == 0:
divs.append(i)
x //= i
while x % i == 0:
x //= i
i += 2
if x > 1:
divs.append(x)
g = 2
while True:
ok = True
for d in divs:
if pow_mod(g, (m - 1) // d, m) == 1:
ok = False
break
if ok:
return g
def butterfly_prepare(mod, primitive_root):
sum_e = np.zeros(30, np.int64)
sum_ie = np.zeros(30, np.int64)
es = np.zeros(30, np.int64)
ies = np.zeros(30, np.int64)
cnt2 = bit_scan_forward(mod - 1)
e = pow_mod(primitive_root, (mod - 1) >> cnt2, mod)
ie = pow_mod(e, mod - 2, mod)
for i in range(cnt2, 1, -1):
es[i - 2] = e
ies[i - 2] = ie
e = e * e % mod
ie = ie * ie % mod
now_e = 1
now_ie = 1
for i in range(cnt2 - 1):
sum_e[i] = es[i] * now_e % mod
sum_ie[i] = ies[i] * now_ie % mod
now_e = now_e * ies[i] % mod
now_ie = now_ie * es[i] % mod
return sum_e, sum_ie
def butterfly(aaa, mod, sum_e):
n = aaa.size
h = ceil_pow2(n)
for ph in range(1, h + 1):
w = 1 << (ph - 1)
p = 1 << (h - ph)
now = 1
for s in range(w):
offset = s << (h - ph + 1)
for i in range(p):
l = aaa[i + offset]
r = aaa[i + offset + p] * now % mod
aaa[i + offset] = (l + r) % mod
aaa[i + offset + p] = (l - r) % mod
now = now * sum_e[bit_scan_forward(~s)] % mod
def butterfly_inv(aaa, mod, sum_ie):
n = aaa.size
h = ceil_pow2(n)
for ph in range(h, 0, -1):
w = 1 << (ph - 1)
p = 1 << (h - ph)
inow = 1
for s in range(w):
offset = s << (h - ph + 1)
for i in range(p):
l = aaa[i + offset]
r = aaa[i + offset + p]
aaa[i + offset] = (l + r) % mod
aaa[i + offset + p] = ((l - r) * inow) % mod
inow = inow * sum_ie[bit_scan_forward(~s)] % mod
MOD = 998244353
primitive_root = get_primitive_root(MOD)
sum_e, sum_ie = butterfly_prepare(MOD, primitive_root)
def convolution(aaa, bbb, MOD):
n = aaa.size
m = bbb.size
z = 1 << ceil_pow2(n + m - 1)
raaa = np.zeros(z, np.int64)
rbbb = np.zeros(z, np.int64)
raaa[:n] = aaa
rbbb[:m] = bbb
butterfly(raaa, MOD, sum_e)
butterfly(rbbb, MOD, sum_e)
ccc = raaa * rbbb % MOD
butterfly_inv(ccc, MOD, sum_ie)
iz = pow_mod(z, MOD - 2, MOD)
result = ccc[:n + m - 1] * iz % MOD
return result
def mod_pow(x, a, MOD):
ret = 1
cur = x
while a:
if a & 1:
ret = ret * cur % MOD
cur = cur * cur % MOD
a >>= 1
return ret
def precompute_factorials(n, MOD):
factorials = np.ones(n + 1, dtype=np.int64)
for m in range(2, n + 1):
factorials[m] = factorials[m - 1] * m % MOD
inversions = np.ones(n + 1, dtype=np.int64)
inversions[n] = mod_pow(factorials[n], MOD - 2, MOD)
for m in range(n, 2, -1):
inversions[m - 1] = inversions[m] * m % MOD
return factorials, inversions
def counter(arr):
cnt = {0: 0}
for a in arr:
if a in cnt:
cnt[a] += 1
else:
cnt[a] = 1
if cnt[0] == 0:
del cnt[0]
return cnt
n = inp[0]
hhh = inp[1:]
hhh_cnt = counter(hhh)
same_heights = []
max_pair_cnt = 0
max_same_height = 0
for c in hhh_cnt.values():
if c <= 1:
continue
max_pair_cnt += c // 2
max_same_height = max(max_same_height, c)
same_heights.append(c)
pairs = np.zeros(n + 1, np.int64)
pairs[0] = 1
for i in range(1, n + 1):
pairs[i] = pairs[i - 1] * (i * 2 - 1) % MOD
facts, finvs = precompute_factorials(max_same_height, MOD)
same_heights.sort()
pc = -1
cmbs = [np.ones(1, np.int64)]
q = [(1, 0)]
for i, c in enumerate(same_heights):
if pc == c:
cmb = cmbs[-1].copy()
else:
cmb = np.zeros(c // 2 + 1, np.int64)
for j in range(c // 2 + 1):
cmb[j] = facts[c] * finvs[j * 2] % MOD * finvs[c - j * 2] % MOD * pairs[j] % MOD
pc = c
cmbs.append(cmb)
q.append((cmb.size, i + 1))
while len(q) > 1:
_, i = heappop(q)
_, j = heappop(q)
cmbs[i] = convolution(cmbs[i], cmbs[j], MOD)
heappush(q, (cmbs[i].size, i))
result = cmbs[q[0][1]]
ans = 0
for i, v in enumerate(result):
if i & 1:
ans = (ans - v * pairs[n - i]) % MOD
else:
ans = (ans + v * pairs[n - i]) % MOD
return ans
if sys.argv[-1] == 'ONLINE_JUDGE':
from numba.pycc import CC
cc = CC('my_module')
cc.export('solve', '(i8[:],)')(solve)
cc.compile()
exit()
if os.name == 'posix':
# noinspection PyUnresolvedReferences
from my_module import solve
else:
from numba import njit
solve = njit('(i8[:],)', cache=True)(solve)
print('compiled', file=sys.stderr)
inp = np.fromstring(sys.stdin.read(), dtype=np.int64, sep=' ')
ans = solve(inp)
print(ans)