サントリープログラミングコンテスト2023(AtCoder Beginner Contest 321)F,G問題メモ
F - #(subset sum = K) with Add and Erase
問題
解法
減らす場合のナップサック。
問題を見るに、1回だけなら典型的なナップサック問題。
でも1回あたり $O(QK)$ なので、毎回ゼロから答えを求めると、$O(Q^2K)$ となり間に合わない。
通常のナップサックは荷物を入れる順番は関係ないので、計算済みの $DP$ から「減らす」操作をしたい。
$DP[i,j]$ に要素 $x$ を追加すると $DP[i+1,j+x]$ に遷移するのだから、
$j$ の上限を $K$ で切らず、全ての場合を持っておけば、$DP[i,j+x]$ から $DP[i+1,j]$ のように減らす場合も計算できるのでは?
だがそれは要素数が $O(Qx_{max})$ のレベルで増えていくので無理。
ところが、よく考えると、実は $j$ の上限は $K$ のままで、減らす場合も計算できる。
0 1 ... x ... j ... j+x ... K
i a b ... c d . k ... e ... z 通常のナップサックの更新:
a b ........ k ... xずらして加算
----------------------------------
i+1 a b ... a+c b+d ..... e+k ...
通常のナップサックでの $x$ の追加は、$DP[i]$ 同士を、$x$ ずらして足し込むイメージ。
逆に $x$ を減らす更新を行うとき、「直前で追加したのが $x$ だったとして、$x$ を追加する前の状態」を想像する。
すると、$j$ の添字の小さい方から順に、$DP[i,j+x]$ から $DP[i,j]$ を引いてやると、元の状態が復元できることが分かる。
この操作に、$j \gt K$ の情報は必要ない。
Python3
import os
import sys
import numpy as np
def solve(q, k, operators, xxx):
dp = np.zeros(k + 1, np.int64)
dp[0] = 1
MOD = 998244353
ans = np.zeros(q, np.int64)
for i in range(q):
x = xxx[i]
if operators[i] == 0:
for j in range(k, x - 1, -1):
dp[j] += dp[j - x]
else:
for j in range(k - x + 1):
dp[j + x] -= dp[j]
dp %= MOD
ans[i] = dp[k]
return ans
SIGNATURE = '(i8,i8,i8[:],i8[:])'
if sys.argv[-1] == 'ONLINE_JUDGE':
from numba.pycc import CC
cc = CC('my_module')
cc.export('solve', SIGNATURE)(solve)
cc.compile()
exit()
if os.name == 'posix':
# noinspection PyUnresolvedReferences
from my_module import solve
else:
from numba import njit
solve = njit(SIGNATURE, cache=True)(solve)
print('compiled', file=sys.stderr)
q, k = map(int, input().split())
operators = []
xxx = []
for _ in range(q):
op, x = input().split()
op = '+-'.index(op)
x = int(x)
operators.append(op)
xxx.append(x)
ans = solve(q, k, np.array(operators, np.int64), np.array(xxx, np.int64))
print('\n'.join(map(str, ans)))
G - Electric Circuit
問題
$N$ 個の電子基板がある
赤の端子は電子基板 $R_1,R_2,...,R_M$ に、青の端子は $B_1,B_2,...,B_M$ についている(重複あり)
$M$ 本のケーブルで、赤と青の端子を1つずつ $M$ 組のペアにしてつなぐ
つなぎ方は $M!$ 通りあるが、ランダムに選ぶときの連結成分数の期待値を $\mod{998244353}$ で求めよ
$1 \le N \le 17$
$1 \le M \le 10^5$
解法
主客転倒とbitDP
期待値は、$M!$ 通りのつなぎ方全ての連結成分数の総和を求めた上で $M!$ で割ればよいので、総和を求めることを目指す。
また、これは主客転倒して、「ある基板の集合 $S$ に対し、それが1つの連結成分を為すようなケーブルのつなぎ方の個数」
を求めて、その総和を取ることと一致する。
「$S$ の中の赤(青)端子」を $k$ 個、「$S$ の中だけのケーブルのつなぎ方のうち、$S$ が1つの連結成分をなすもの」を $x$ 通りとすると、$S$ の外の端子はどうつなごうと自由なので、$S$ が1つの連結成分を為すつなぎ方は $x(N-k)!$ で計算できる。
で、肝心の「$S$ の中だけのケーブルのつなぎ方のうち、$S$ が1つの連結成分をなすもの」は、
「$S$ の中だけのケーブルの全てのつなぎ方 $k!$」から、「2つ以上の連結成分に分かれるようなつなぎ方」を除外するとよい。
DPの遷移
いま、1つ $S$ を決めて $DP[S]$ を求めるとする。$S$ の部分集合のDPの値は、$S$ 自身を除いて全て計算済みとする。
赤と青の端子の個数が一致していなければそもそも不可能なので $DP[S]=0$。以下、一致しているとする。
赤端子の個数を $k$ とすると、まず全体では $k!$ 通りのつなぎ方がある。
そのうち、2つ以上の連結成分に分かれるものを除去するにあたり、
「最も番号が小さい基板が属する連結成分 $T$」を、$S$ の部分集合($S$ 自身以外)から総当たりする。
最も番号が小さい基板に限定するのは、重複を防ぐため。
$T$ の端子数を $l$ として、$DP[T] \times (k-l)!$ 通りのつなぎ方が除去される。
$k!$ から全てを除去した後に残ったのが、$DP[S]$ となる。
計算量は $O(3^N)$
Python3
from collections import Counter
def precompute_factorials(n, MOD):
f = 1
factorials = [1]
for m in range(1, n + 1):
f = f * m % MOD
factorials.append(f)
f = pow(f, MOD - 2, MOD)
finvs = [1] * (n + 1)
finvs[n] = f
for m in range(n, 1, -1):
f = f * m % MOD
finvs[m - 1] = f
return factorials, finvs
def iter_sub_bitset(bitset, include_self=True, include_empty=True):
v = bitset
if not include_self:
v = (v - 1) & bitset
while v:
yield v
v = (v - 1) & bitset
if include_empty:
yield 0
n, m = map(int, input().split())
rrr = [r - 1 for r in map(int, input().split())]
bbb = [b - 1 for b in map(int, input().split())]
MOD = 998244353
facts, finvs = precompute_factorials(m, MOD)
r_cnt = Counter(rrr)
b_cnt = Counter(bbb)
# 集合毎の端子の個数
r_terminal = [0] * (1 << n)
b_terminal = [0] * (1 << n)
for bitset in range(1, 1 << n):
lsb = bitset & -bitset
i = lsb.bit_length() - 1
r_terminal[bitset] = r_terminal[bitset ^ lsb] + r_cnt[i]
b_terminal[bitset] = b_terminal[bitset ^ lsb] + b_cnt[i]
dp1 = [0] * (1 << n) # dp[S] = 集合 S の頂点だけで、端子を使い切り、ちょうど1つの連結成分をなす個数
dp2 = [0] * (1 << n) # dp[S] = 集合 S の頂点だけで、端子を使い切る個数(連結成分は1つとは限らない)
dp1[0] = dp2[0] = 1
for bitset in range(1, 1 << n):
if r_terminal[bitset] != b_terminal[bitset]:
continue
tmp = dp2[bitset] = facts[r_terminal[bitset]]
lsb = bitset & -bitset
remain = bitset ^ lsb
for sub in iter_sub_bitset(remain, include_self=False, include_empty=True):
target = sub | lsb
others = bitset ^ target
tmp -= dp1[target] * dp2[others]
# print(f'{bitset=:05b} {remain=:05b} {sub=:05b} {target=:05b} {others=:05b} {dp1[target] * dp2[others]}')
dp1[bitset] = tmp % MOD
ans = 0
full = (1 << n) - 1
for bitset in range(1, 1 << n):
ans += dp1[bitset] * dp2[full ^ bitset]
ans %= MOD
ans *= finvs[m]
ans %= MOD
print(ans)