AtCoder Beginner Contest 221 E,F,G,H問題メモ
E - LEQ
問題
長さ $N$ の数列 $A_1,A_2,...,A_N$
長さ2以上の部分列のうち、(最初の要素) $\le$ (最後の要素) となるものの個数を $\mod{998244353}$ で求めよ
部分列として同じであっても、取り出す添字が異なるものは区別する
$2 \le N \le 3 \times 10^5$
解法
たとえば取り出す部分列の左端と右端のindexを $(2,7)$ とすると、
これを左端、右端ともに全て調べて合計すると答えとなるが、計算量が $O(N^2)$ となってしまう。
各右端につき、左端はまとめて計算できないか考える。
転倒数や増加部分列と似た考え方が適用できそう。
これらの典型問題では、$A_i$ を添字に持つFenwick Treeなどを用意することで、
「$i$ より左に位置する、$A_i$ 以下の値の個数」を高速に得られる。
【iより左に位置するAi以下の値の個数を、各iについて求める】
i 0 1 2 3 4 5 6 7 8
A = [ 2 7 1 2 4 5 7 4 9 ... ] (※以降、Aiは0-indexとする)
Ai 0 1 2 3 4 5 6 7 ...
Tree 0 1 2 0 1 0 0 1 i=0~4 までのAiの個数を反映
~~~~~~~~~~~~~~~~ i=5 より左に出てくる
A5=5 以下の数の個数は、波線部の合計=4
Tree 0 1 2 0 1 1 0 1 "5" の位置に+1して次へ
~~~~~~~~~~~~~~~~~~~~~~ i=6 より左に出てくる
A6=7 以下の数の個数は、波線部の合計=6
Tree 0 1 2 0 1 1 0 2 "7" の位置に+1して次へ
...
繰り返し
これと似たことをしたい。
今回は、たとえば右端を $i=5$ とするのだったら、
$j=0$ の “2” を左端とするときは、$2^4$ 個
$j=1$ の “7” を左端にはできない(が、“7” の位置に足しておくのは $2^3$ 個)
$j=2$ の “1” を左端とするときは、$2^2$ 個
$j=3$ の “2” を左端とするときは、$2^1$ 個
$j=4$ の “4” を左端とするときは、$2^0$ 個
と、位置毎に $2^{i-j-1}$ が足し込まれていると、今回の問題の答えが求められる。
具体的には、$16+4+2+1=23$ 個となる。
ただし、次の $i=6$ を考える際には、
$j=0$ の “2” を左端とするときは、$2^5$ 個
$j=1$ の “7” を左端とするときは、$2^4$ 個
$j=2$ の “1” を左端とするときは、$2^3$ 個
…
のように、足されている数は2倍ずつになっていないと正しく求まらない。
実際に毎回2倍にするわけにはいかないので、これは累積和を取った後で調整する。
つまり、
$A_0=2$ の位置には $1$
$A_1=7$ の位置には $\frac{1}{2}$
$A_2=1$ の位置には $\frac{1}{2^2}$
$A_3=2$ の位置には $\frac{1}{2^3}$
…
このように値を加算しておく。
こうすると、$i=5$ のとき、Fenwick Tree上で $5$ 以下の値の合計は
$1 + \dfrac{1}{2^2} + \dfrac{1}{2^3} + \dfrac{1}{2^4} = \dfrac{23}{2^4}$ となるが、
ここに $2^{i-1}$ をかけることで、正しい値である $23$ が得られる。
これを各右端につき合計すると答えとなる。
Python3
class FenwickTreeInjectable:
def __init__(self, n, identity_factory, func):
self.size = n
self.tree = [identity_factory() for _ in range(n + 1)]
self.func = func
self.idf = identity_factory
self.depth = n.bit_length()
def add(self, i, x):
i += 1
tree = self.tree
func = self.func
while i <= self.size:
tree[i] = func(tree[i], x)
i += i & -i
def sum(self, i):
i += 1
s = self.idf()
tree = self.tree
func = self.func
while i > 0:
s = func(s, tree[i])
i -= i & -i
return s
def lower_bound(self, x, lt):
"""
累積和がx以上になる最小のindexと、その直前までの累積和
:param lt: lt(a, b) で a < b ならTrueを返す関数
"""
total = self.idf()
pos = 0
tree = self.tree
func = self.func
for i in range(self.depth, -1, -1):
k = pos + (1 << i)
if k > self.size:
continue
new_total = func(total, tree[k])
if lt(new_total, x):
total = new_total
pos += 1 << i
return pos + 1, total
def debug_print(self):
prev = 0
arr = []
for i in range(self.size):
curr = self.sum(i)
arr.append(curr - prev)
prev = curr
print(arr)
n = int(input())
aaa = list(map(int, input().split()))
MOD = 998244353
aaa_list = sorted(set(aaa))
aaa_dict = {a: i for i, a in enumerate(aaa_list)}
fwt = FenwickTreeInjectable(n, int, lambda a, b: (a + b) % MOD)
pow2 = [1]
pow2inv = [1]
INV2 = pow(2, MOD - 2, MOD)
for _ in range(n + 5):
pow2.append(pow2[-1] * 2 % MOD)
pow2inv.append(pow2inv[-1] * INV2 % MOD)
ans = 0
for i in range(n):
a = aaa[i]
j = aaa_dict[a]
ans += fwt.sum(j) * pow2[i - 1]
ans %= MOD
fwt.add(j, pow2inv[i])
print(ans)
F - Diameter set
問題
解法
どんなケースが当てはまるか、過不足無く見つけるのが難しい。
例えば以下のような木だと、
①-, ,-③
○-○-○- ... -○-○-④
②-' `-⑤
なので6通りとなる。
だが、このようなグループがたくさんある木も考えられる。
,-①
○
②-, | ,-④
③-○-⑪-○-⑤
| `-⑥
⑦-○-⑩
⑧'`⑨
①、②③、④⑤⑥、⑦⑧⑨⑩がそれぞれグループとなっており、各グループからは多くとも1つしか選べない。
このグループをどうやって見つけるか?
グループの見つけ方
木の直径を見つけるアルゴリズムといえばBFSを2回行う方法がある。
たとえば、適当な頂点から1回目のBFSを行った結果、①が直径をなす頂点の1つとわかったとする。
①から2回目のBFSを行って、最も遠い頂点として②~⑩が得られる。
だが、そのどれとどれがグループなのかを判別しないといけない。
これには、木の直径が持つ1つの性質を利用する。
木の中心とは、直径をなすパスの真ん中の点または辺のことを指す。
上記の例では、⑪が中心である。
木の中心から、直径をなす頂点までの距離を「木の半径」という。
まず木の中心を求め、そこから伸びる各辺ごとに、
中心からの距離が木の半径と一致する頂点を求めれば、それがグループとなる。
⑪から↑方向の辺の先には①の1点
⑪から←方向の辺の先には②③の2点
⑪から→方向の辺の先には④⑤⑥の3点
⑪から↓方向の辺の先には⑦⑧⑨⑩の4点
木の直径が奇数の場合は辺が木の中心となるが、
便宜的に辺の間に $N+1$ 個目の頂点を追加するようにつなぎ替えれば、頂点と同様に考えられる。
①--②------③--④
↓
①--②--⑤--③--④
木の中心から、各辺方向の距離を求める際には、TLEに注意。
スターグラフなど、木の中心から多数の辺が伸びていることがある。
1つの辺ごとに、毎回BFSを初期化して、$N$ 要素の距離配列を作って、などとしていると、いつの間にか $O(N^2)$ の計算量になっていることがある。
なるべく配列は使い回すなど、全体として $O(N)$ に収まるように実装する。
パターン数の求め方
木の中心からそれぞれの辺の先に、1,2,3,4個のグループが存在するとわかった。
各グループにつき、「どれか1点を選ぶ」か「1点も選ばない」かの、(size+1) 通りの選択肢がある。
これをまず掛け合わせる。
ただしこの中には、1点以下の頂点しか選ばれていないものが混ざっている。
その数は、
なので、これを掛け合わせた結果から除けば、答えとなる。
Python3
from collections import deque
def first_bfs(links):
q = deque()
q.append(0)
stacked = [False] * n
stacked[0] = True
v = -1
while q:
v = q.popleft()
for u in links[v]:
if stacked[u]:
continue
q.append(u)
stacked[u] = True
return v
def second_bfs(links, s):
dists = [-1] * n
q = deque()
q.append(s)
dists[s] = 0
predecessors = [-1] * n
v = -1
while q:
v = q.popleft()
for u in links[v]:
if dists[u] != -1:
continue
q.append(u)
dists[u] = dists[v] + 1
predecessors[u] = v
md = max(dists)
assert dists[v] == md
if md % 2 == 0:
hd = md // 2
i = v
while i != s:
if dists[i] == hd:
return i, -1, hd
i = predecessors[i]
else:
hd = (md + 1) // 2
i = v
while i != s:
if dists[i] == hd:
return i, predecessors[i], hd
i = predecessors[i]
return -1, -1, -1
def third_bfs(links, s, hd):
def sub(t):
if hd == 1:
return 1
q = deque()
q.append(t)
dists[t] = 0
count = 0
while q:
v = q.popleft()
for u in links[v]:
if dists[u] != -1:
continue
dists[u] = dists[v] + 1
q.append(u)
if dists[u] == hd - 1:
count += 1
return count
res = 1
ex = 0
dists = [-1] * (n + 1)
dists[s] = -2
for t in links[s]:
st = sub(t)
res = res * (st + 1) % MOD
ex += st
res = (res - ex - 1) % MOD
return res
n = int(input())
MOD = 998244353
links = [set() for _ in range(n)]
for _ in range(n - 1):
u, v = map(int, input().split())
u -= 1
v -= 1
links[u].add(v)
links[v].add(u)
v1 = first_bfs(links)
c1, c2, hd = second_bfs(links, v1)
if c2 == -1:
assert c1 != -1
ans = third_bfs(links, c1, hd)
else:
links[c1].remove(c2)
links[c2].remove(c1)
links[c1].add(n)
links[c2].add(n)
links.append({c1, c2})
ans = third_bfs(links, n, hd)
print(ans)
G - Jumping sequence
問題
2次元座標の原点 $(0,0)$ からちょうど $N$ 回、上下左右いずれかへのジャンプを繰り返し、$(A,B)$ に移動したい
$i$ 回目のジャンプで移動できる距離は $D_i$ と決まっている
可能か判定し、可能ならジャンプした方向の列を1つ求めよ
$1 \le N \le 2000$
$1 \le D_i \le 1800$
解法
数式で表現すると、2つの $\{-1,0,1\}$ をとる係数列 $k=(k_1,k_2,...,k_N)$ と $l=(l_1,l_2,...,l_N)$ を、
となるように決定したいのだが、
この時に「$k_i=0$ なら $l_i=-1 or 1$」「$k_i=-1 or 1$ なら $l_i=0$」でなければならないという制約がある。
このため独立に考えることが出来ず、DPを考えるにしても「$x$ が○○のときに $y$ は△△が可能」という2次元で行わなければならず、TLEとなる。
ここで(やや天啓的だが)2数の和と差を考えれば、
k +1 -1 0 0
l 0 0 +1 -1
-------------------
k+l +1 -1 +1 -1
k-l +1 -1 -1 +1
独立な2つの ${1,-1}$ の組合せで4通りの移動を区別できる。
この独立に割り当ててよくなるという点が大きい。
座標で和と差を取るといえば、45度回転に相当する。
座標を45度回転させて(かつ整数になるように $\sqrt{2}$ 倍して)考える。
それぞれの方向へのジャンプは以下のようになる
'R': $(X,Y)→(X+D,Y+D)$
'L': $(X,Y)→(X-D,Y-D)$
'U': $(X,Y)→(X-D,Y+D)$
'D': $(X,Y)→(X+D,Y-D)$
目標地点は $(A-B, A+B)$
2つの ${-1,1}$ をとる係数列 $k',l'$ があったとして、
となるようにすればよく、たとえば以下のような $k',l'$ があったら、組合せに 'RLUD' を対応づけ、操作列を構成することが出来る。
k' +1 -1 +1 +1 -1
l' +1 +1 -1 +1 -1
対応する操作列 R U D R L
従って、1次元のDPがそのまま使えて、
この情報を使って合計を $A-B$ にできるように $k'$ を決め、$A+B$ にできるように $l'$ を決めれば、答えが求められる。
高速化
これにて一件落着めでたしめでたし、とはならず、上のDPを計算しようとすると最悪 $O(N^2 D_{max})$ かかる。
制約の値を単純に当てはめると $7.2 \times 10^9$ となり、間に合いそうにない。
今回のDPでは管理する値が真偽値 $0/1$ だけでよいため、bitset化できることを利用する。
Pythonなど多倍長整数が可能な言語では、いつものように整数をbit列として扱うだけでもよい。
j 9876543210
値 0011001101 → 205 として扱う
64個の $j$ の値を1つの64bit整数にまとめることができるので、
(細かい部分は言語と実装に依るだろうけど)大体64倍くらいになることが期待される。
$7.2 \times 10^9 / 64 = 1.125 \times 10^8$ となり、
実行時間制限が長めなこともあってまぁ間に合わなくはないレベルになる。
さらなる高速化
このままのDPでも言語によっては間に合うかも知れないが、以下の改善の余地を残している。
オフセットの必要性
1回の遷移の演算回数
bitsetの桁数
$j$ は負の値も取り得るため、bitsetとして扱うなら $Sum(D)$ だけ前もってオフセットでずらす必要がある。
また、遷移は
j 10 9 8 7 6 5 4 3 2 1 0
DP[0] 0 0 0 0 0 1 0 0 0 0 0
↙ ↘ D1=3
DP[1] 0 0 1 0 0 0 0 0 1 0 0
↙ ↘ ↙ ↘ D2=2
DP[2] 1 0 0 0 1 0 1 0 0 0 1
のように、左右にbit-shiftしたそれぞれをbit-orしたものとなり、1回の遷移で3回のbit演算が行われる。
あと、これだと各時点で奇数indexか偶数indexのどちらは必ず0となるため、ちょっと空間の効率が悪い。
ここで、公式Editorialにあるように以下のように式変形する。
ただし、$k'',l''$ は ${0,1}$ をとる係数列で、$0$ が $k',l'$ の $-1$ に相当する。
こうすると、
左辺の合計値は負になり得ないのでオフセットの必要は無くなる
1回の遷移は1回のbit-shiftと1回のbit-orだけで済むので演算回数が2/3に減る
最大桁数も半分になるので1回ずつのbit演算も高速になる
これらの改善により、もうあと何倍か速くなる。
Python3
n, a, b = map(int, input().split())
ddd = list(map(int, input().split()))
s = sum(ddd)
if s + a + b < 0 or s + a - b < 0 or (s + a + b) % 2 == 1:
print('No')
exit()
dp = [1]
for d in ddd:
dp.append(dp[-1] | (dp[-1] << d))
x = 1 << ((s + a + b) // 2)
y = 1 << ((s + a - b) // 2)
if dp[-1] & x == 0 or dp[-1] & y == 0:
print('No')
exit()
operations = []
for i in range(n - 1, -1, -1):
d = ddd[i]
b = dp[i]
if b & x:
if b & y:
operations.append('L')
else:
operations.append('D')
y >>= d
else:
if b & y:
operations.append('U')
x >>= d
else:
operations.append('R')
x >>= d
y >>= d
operations.reverse()
print('Yes')
print(''.join(operations))
H - Count Multiset
問題
整数 $N,M$ が与えられる
$f(x)$ を、以下の条件を全て満たす多重集合の個数とする
条件
$x$ 個の正整数からなる
総和が $N$
同じ要素の個数は $M$ 個を超えない
$f(1),f(2),...,f(N)$ をそれぞれ $\mod{998244353}$ で求めよ
$1 \le M \le N \le 5000$
解法
$N$ を $x$ 個の自然数の和に分割する方法の個数を示す「分割数」に、同じ数の個数上限が加わった感じの問題となっている。
DPを行うのだが、前から1要素ずつ決めるとかの考え方からの転換が必要。
発想がわかれば、理解・実装はさほど難しくない。
以下、単に「多重集合」といったとき、正整数からなり、同じ要素の個数は $M$ 個を超えないという条件は満たしているとする。
$DP[i][j]$ を求めることを考える。
ここで唐突に、多重集合の全要素を1ずつ減らしてみる。
元の集合での'1'が'0'となり多重集合から除外されるが、他は1減らした値での集合と1対1対応する。
M=2
i=4 j=10
1が0個 {2, 2, 3, 3} → {1, 1, 2, 2} ─ DP[4][6] と一致
1が1個 {1, 2, 2, 5} → {1, 1, 4} ┐
{1, 2, 3, 4} → {1, 2, 3} ┴ DP[3][6] と一致
1が2個 {1, 1, 2, 6} {1, 5} ┐
{1, 1, 3, 5} → {2, 4} ┼ DP[2][6] と一致
{1, 1, 4, 4} {3, 3} ┘
(Mの条件と合致しないためDP[4][10]には計上されないが続きとして)
1が3個 {1, 1, 1, 7} → {6} DP[1][6] と一致
要は、以下のようになる。
'1'が0個: 要素数が $i$、総和が $j-i$ の多重集合の個数と一致する
'1'が1個: 要素数が $i-1$、総和が $j-i$ の多重集合の個数と一致する
'1'が2個: 要素数が $i-2$、総和が $j-i$ の多重集合の個数と一致する
…
'1'がM個: 要素数が $i-M$、総和が $j-i$ の多重集合の個数と一致する
つまり、自分より小さな $i,j$ について $DP[i][j]$ が埋まっていれば、答えはその和で求められる。
さらに参照する $i$ は連続しているので、$i$ 方向に累積和を取った状態で管理しておけば、
以下のように、累積和の差分で求められる。(indexが負の場合の例外処理はするとして)
DPを埋め終わったら、$DP[1][N]~DP[N][N]$ がそれぞれの答え。
初期値としては、要素数 $i=1$ のとき明らかに $j=1~N$ に対して $DP[1][j]=1$ なので
($\{j\}$ 1要素のみからなる多重集合の1通り)、それを埋めておけばよい。
Python3
import os
import sys
import numpy as np
def solve(n, m):
MOD = 998244353
dp = np.ones((n + 2, n + 1), np.int64)
dp[0, 1:] = 0
dp[-1, :] = 0
for i in range(2, n + 1):
for j in range(i):
dp[i, j] = dp[i - 1, j]
for j in range(i, n + 1):
dp[i, j] = (dp[i - 1, j] + dp[i, j - i] - dp[max(-1, i - m - 1), j - i]) % MOD
ans = np.zeros(n, np.int64)
ans[0] = dp[1, n]
for i in range(1, n):
ans[i] = (dp[i + 1, n] - dp[i, n]) % MOD
return ans
SIGNATURE = '(i8,i8)'
if sys.argv[-1] == 'ONLINE_JUDGE':
from numba.pycc import CC
cc = CC('my_module')
cc.export('solve', SIGNATURE)(solve)
cc.compile()
exit()
if os.name == 'posix':
# noinspection PyUnresolvedReferences
from my_module import solve
else:
from numba import njit
solve = njit(SIGNATURE, cache=True)(solve)
print('compiled', file=sys.stderr)
n, m = map(int, input().split())
ans = solve(n, m)
print(*ans)