GAMEFREAK Programming Contest 2023 (AtCoder Beginner Contest 317) F,G問題メモ
F - Nim
問題
正整数 $N$ と、$A_1,A_2,A_3$ が与えられる
以下の条件を全て満たす3数 $(X_1,X_2,X_3)$ の組の個数を $\mod{998244353}$ で求めよ
$1 \le N \le 10^{18}$
$1 \le A_i \le 10$
解法
3数の排他的論理和が0なので、2進数で表したときに全ての桁についてbitが「2数でだけ立ってる」か「どれも立ってない」かである。
$A_i$ の制約が小さいので、あまりを持った桁DPができる。
3数同時にあまりを管理しても、$O(\log{N} A_{max}^3)$ で、$6 \times 10^4$ 程度に収まる。
次の桁にbitを3つ中2つだけ立てた場合、どれにも立てなかった場合、に遷移していき、最終的にあまりが全て0のパターン数が答え、という解法が取れそう。
$N$ 以下であることが求められるので、よくある桁DPなら
みたいな方法を取るが今回は3数同時に面倒を見る必要があり、どれか1個だけが $N$ ちょうどとかいう場合も考慮するので、
$N$ ちょうども状態の中に混ぜ込む必要がある。
$N$ を2進数で表して上から $k$ 桁を取ったものを $f(k)$ とするとして、
(例)N = 43 = 101011(2) ⇒ f(4) = 1010(2) = 10
以下のDPを定義する。
$DP[0,A_1,A_2,A_3]=1$ からはじめて、
$N$ を超えないように、また $N$ ちょうどの状態に注意しながら、
3数のうち2個にビットを立てるか、どれにも立てないかでパターンを遷移させていく。
最終的に $DP[N,0,0,0]$ が答えだが、$N$ 自身が $A_1$ で割りきれる場合、$DP[N,A_1,0,0]$ も答えに加算する。
(もし $A_2,A_3$ も割り切れるなら、その組み合わせだけ参照箇所が増える)
最後に、答えには $X_i=0$ の場合も含まれてしまっている。1以上で無ければいけないので、これを除外する。
以上で答えとなる。
Python3
雑に見積もって $10^{54}$ 以下にしかならないので、これくらいならPythonなら逐一割るより、最後に割った方が高速。
from math import gcd
from pprint import pprint
n, a, b, c = map(int, input().split())
l = n.bit_length()
dp = [[[0] * (c + 1) for _ in range(b + 1)] for _ in range(a + 1)]
dp[a][b][c] = 1
da = 0
db = 0
dc = 0
for d in range(l - 1, -1, -1):
ndp = [[[0] * (c + 1) for _ in range(b + 1)] for _ in range(a + 1)]
if n & (1 << d):
da0 = (da << 1) % a
db0 = (db << 1) % b
dc0 = (dc << 1) % c
for i in range(a + 1):
use_a = a if i == a else ((i << 1) + 1) % a
not_a = da0 if i == a else (i << 1) % a
for j in range(b + 1):
use_b = b if j == b else ((j << 1) + 1) % b
not_b = db0 if j == b else (j << 1) % b
for k in range(c + 1):
use_c = c if k == c else ((k << 1) + 1) % c
not_c = dc0 if k == c else (k << 1) % c
ndp[use_a][use_b][not_c] += dp[i][j][k]
ndp[use_a][not_b][use_c] += dp[i][j][k]
ndp[not_a][use_b][use_c] += dp[i][j][k]
ndp[not_a][not_b][not_c] += dp[i][j][k]
da = ((da << 1) + 1) % a
db = ((db << 1) + 1) % b
dc = ((dc << 1) + 1) % c
else:
for i in range(a + 1):
use_a = -1 if i == a else ((i << 1) + 1) % a
not_a = a if i == a else (i << 1) % a
for j in range(b + 1):
use_b = -1 if j == b else ((j << 1) + 1) % b
not_b = b if j == b else (j << 1) % b
for k in range(c + 1):
use_c = -1 if k == c else ((k << 1) + 1) % c
not_c = c if k == c else (k << 1) % c
if use_a != -1 and use_b != -1:
ndp[use_a][use_b][not_c] += dp[i][j][k]
if use_a != -1 and use_c != -1:
ndp[use_a][not_b][use_c] += dp[i][j][k]
if use_b != -1 and use_c != -1:
ndp[not_a][use_b][use_c] += dp[i][j][k]
ndp[not_a][not_b][not_c] += dp[i][j][k]
da = (da << 1) % a
db = (db << 1) % b
dc = (dc << 1) % c
dp = ndp
ans = 0
for i in (0, a) if n % a == 0 else (0,):
for j in (0, b) if n % b == 0 else (0,):
for k in (0, c) if n % c == 0 else (0,):
ans += dp[i][j][k]
ans -= n // (b * c // gcd(b, c))
ans -= n // (a * c // gcd(a, c))
ans -= n // (a * b // gcd(a, b))
ans -= 1
print(ans % 998244353)
G - Rearranging
問題
解法
全てを一度に決めようとすると上手くいかない。
実は、「$1~N$ の数字が $M$ 個ずつ配置されている」という条件を満たしていれば必ず可能であり、
その証明があれば、1列ずつ貪欲に決めていってもいい、という考え方となる。
(入力サンプルにも可能な例しか無かった時点で少し怪しかったね)
いま、一番左端の1列だけをとりあえず決めるとする。
どうやっても $1~N$ を揃えるのは不可能というのはどんな状態か?
二部マッチングにおけるホールの定理を利用すれば、
場合に不可能であり、そうでなければ何らかの完全マッチングは存在する、ということになる。
→ 3 3 3 3 5 4行に存在する数の種類が、あわせても3種類しかない
→ 3 3 5 5 6 みたいな行の集合がある場合に不可能となる
o o o o o
→ 3 3 3 6 6
→ 5 5 6 6 6
だが、これは少し考えると、数のどれかが $M$ 個より多く存在していないと不可能なので、あり得ない。
よって、左端の列に $1~N$ は必ず揃えることができ、1つ列が減った状態に帰結できる。
1つ列が減った状態も、$N$ 行 $m$ 列に対して「$1~N$ が $m$ 個ずつ」という条件は満たされているので、
同じ理屈により再帰的に完全マッチングが存在することが証明できる。
完全マッチングは最大流で求められる。
$N,M$ の制約が小さいので、各行に残っている数を管理しながら $M$ 回、グラフ構築から最大流を流しても十分間に合う。
Python3
from collections import deque
from typing import Tuple, List
class Dinic:
"""
Usage:
mf = Dinic(n)
-> mf.add_link(from, to, capacity)
-> mf.max_flow(source, target)
"""
def __init__(self, n: int):
self.n = n
self.links: List[List[List[int]]] = [[] for _ in range(n)]
# if exists an edge (v→u, capacity)...
# links[v] = [ [ capacity, u, index of rev-edge in links[u], is_original_edge ], ]
def add_link(self, from_: int, to: int, capacity: int) -> None:
self.links[from_].append([capacity, to, len(self.links[to]), 1])
self.links[to].append([0, from_, len(self.links[from_]) - 1, 0])
def bfs(self, s: int) -> List[int]:
depth = [-1] * self.n
depth[s] = 0
q = deque([s])
while q:
v = q.popleft()
for cap, to, rev, _ in self.links[v]:
if cap > 0 and depth[to] < 0:
depth[to] = depth[v] + 1
q.append(to)
return depth
def dfs(self, s: int, t: int, depth: List[int], progress: List[int], link_counts: List[int]) -> int:
links = self.links
stack = [s]
while stack:
v = stack[-1]
if v == t:
break
for i in range(progress[v], link_counts[v]):
progress[v] = i
cap, to, rev, _ = links[v][i]
if cap == 0 or depth[v] >= depth[to] or progress[to] >= link_counts[to]:
continue
stack.append(to)
break
else:
progress[v] += 1
stack.pop()
else:
return 0
f = 1 << 60
fwd_links = []
bwd_links = []
for v in stack[:-1]:
cap, to, rev, _ = link = links[v][progress[v]]
f = min(f, cap)
fwd_links.append(link)
bwd_links.append(links[to][rev])
for link in fwd_links:
link[0] -= f
for link in bwd_links:
link[0] += f
return f
def max_flow(self, s: int, t: int) -> int:
link_counts = list(map(len, self.links))
flow = 0
while True:
depth = self.bfs(s)
if depth[t] < 0:
break
progress = [0] * self.n
current_flow = self.dfs(s, t, depth, progress, link_counts)
while current_flow > 0:
flow += current_flow
current_flow = self.dfs(s, t, depth, progress, link_counts)
return flow
def cut_edges(self, s: int) -> List[Tuple[int, int]]:
""" max_flowしたあと、最小カットにおいてカットすべき辺を復元する """
q = [s]
reachable = [0] * self.n
reachable[s] = 1
while q:
v = q.pop()
for cap, u, li, _ in self.links[v]:
if cap == 0 or reachable[u]:
continue
reachable[u] = 1
q.append(u)
edges = []
for v in range(self.n):
if reachable[v] == 0:
continue
for cap, u, li, orig in self.links[v]:
if orig == 1 and reachable[u] == 0:
edges.append((v, u))
return edges
n, m = map(int, input().split())
s = 2 * n
t = s + 1
ans = [[0] * m for _ in range(n)]
count_by_col = [[0] * n for _ in range(n)] # i 行目にある数字 k の個数
fixed_by_col = [[0] * n for _ in range(n)] # 1列ずつ決める過程で、i 行目に数字 k が埋まった個数
field = []
for i in range(n):
row = list(map(int, input().split()))
for j in range(m):
row[j] -= 1
count_by_col[i][row[j]] += 1
field.append(row)
for j in range(m):
dnc = Dinic(t + 1)
for i in range(n):
for k in range(n):
rem = count_by_col[i][k] - fixed_by_col[i][k]
if rem:
dnc.add_link(i, k + n, 1)
dnc.add_link(s, i, 1)
dnc.add_link(i + n, t, 1)
result = dnc.max_flow(s, t)
if result < n:
print('No')
exit()
for i in range(n):
for link in dnc.links[i]:
if link[3] == 1 and link[0] == 0:
k = link[1] - n
ans[i][j] = k + 1
fixed_by_col[i][k] += 1
print('Yes')
for row in ans:
print(' '.join(map(str, row)))