サントリープログラミングコンテスト2023(AtCoder Beginner Contest 321)F,G問題メモ
F - #(subset sum = K) with Add and Erase
問題
解法
減らす場合のナップサック。
問題を見るに、1回だけなら典型的なナップサック問題。
でも1回あたり O(QK) なので、毎回ゼロから答えを求めると、O(Q2K) となり間に合わない。
通常のナップサックは荷物を入れる順番は関係ないので、計算済みの DP から「減らす」操作をしたい。
DP[i,j] に要素 x を追加すると DP[i+1,j+x] に遷移するのだから、
j の上限を K で切らず、全ての場合を持っておけば、DP[i,j+x] から DP[i+1,j] のように減らす場合も計算できるのでは?
だがそれは要素数が O(Qxmax) のレベルで増えていくので無理。
ところが、よく考えると、実は j の上限は K のままで、減らす場合も計算できる。
0 1 ... x ... j ... j+x ... K
i a b ... c d . k ... e ... z 通常のナップサックの更新:
a b ........ k ... xずらして加算
----------------------------------
i+1 a b ... a+c b+d ..... e+k ...
通常のナップサックでの x の追加は、DP[i] 同士を、x ずらして足し込むイメージ。
逆に x を減らす更新を行うとき、「直前で追加したのが x だったとして、x を追加する前の状態」を想像する。
すると、j の添字の小さい方から順に、DP[i,j+x] から DP[i,j] を引いてやると、元の状態が復元できることが分かる。
この操作に、j>K の情報は必要ない。
Python3
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
import os
import sys
import numpy as np
def solve(q, k, operators, xxx):
dp = np.zeros(k + 1, np.int64)
dp[0] = 1
MOD = 998244353
ans = np.zeros(q, np.int64)
for i in range(q):
x = xxx[i]
if operators[i] == 0:
for j in range(k, x - 1, -1):
dp[j] += dp[j - x]
else:
for j in range(k - x + 1):
dp[j + x] -= dp[j]
dp %= MOD
ans[i] = dp[k]
return ans
SIGNATURE = '(i8,i8,i8[:],i8[:])'
if sys.argv[-1] == 'ONLINE_JUDGE':
from numba.pycc import CC
cc = CC('my_module')
cc.export('solve', SIGNATURE)(solve)
cc.compile()
exit()
if os.name == 'posix':
from my_module import solve
else:
from numba import njit
solve = njit(SIGNATURE, cache=True)(solve)
print('compiled', file=sys.stderr)
q, k = map(int, input().split())
operators = []
xxx = []
for _ in range(q):
op, x = input().split()
op = '+-'.index(op)
x = int(x)
operators.append(op)
xxx.append(x)
ans = solve(q, k, np.array(operators, np.int64), np.array(xxx, np.int64))
print('\n'.join(map(str, ans)))
|
G - Electric Circuit
問題
N 個の電子基板がある
赤の端子は電子基板 R1,R2,...,RM に、青の端子は B1,B2,...,BM についている(重複あり)
M 本のケーブルで、赤と青の端子を1つずつ M 組のペアにしてつなぐ
つなぎ方は M! 通りあるが、ランダムに選ぶときの連結成分数の期待値を mod998244353 で求めよ
1≤N≤17
1≤M≤105
解法
主客転倒とbitDP
期待値は、M! 通りのつなぎ方全ての連結成分数の総和を求めた上で M! で割ればよいので、総和を求めることを目指す。
また、これは主客転倒して、「ある基板の集合 S に対し、それが1つの連結成分を為すようなケーブルのつなぎ方の個数」
を求めて、その総和を取ることと一致する。
「S の中の赤(青)端子」を k 個、「S の中だけのケーブルのつなぎ方のうち、S が1つの連結成分をなすもの」を x 通りとすると、S の外の端子はどうつなごうと自由なので、S が1つの連結成分を為すつなぎ方は x(N−k)! で計算できる。
で、肝心の「S の中だけのケーブルのつなぎ方のうち、S が1つの連結成分をなすもの」は、
「S の中だけのケーブルの全てのつなぎ方 k!」から、「2つ以上の連結成分に分かれるようなつなぎ方」を除外するとよい。
DPの遷移
いま、1つ S を決めて DP[S] を求めるとする。S の部分集合のDPの値は、S 自身を除いて全て計算済みとする。
赤と青の端子の個数が一致していなければそもそも不可能なので DP[S]=0。以下、一致しているとする。
赤端子の個数を k とすると、まず全体では k! 通りのつなぎ方がある。
そのうち、2つ以上の連結成分に分かれるものを除去するにあたり、
「最も番号が小さい基板が属する連結成分 T」を、S の部分集合(S 自身以外)から総当たりする。
最も番号が小さい基板に限定するのは、重複を防ぐため。
T の端子数を l として、DP[T]×(k−l)! 通りのつなぎ方が除去される。
k! から全てを除去した後に残ったのが、DP[S] となる。
計算量は O(3N)
Python3
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
from collections import Counter
def precompute_factorials(n, MOD):
f = 1
factorials = [1]
for m in range(1, n + 1):
f = f * m % MOD
factorials.append(f)
f = pow(f, MOD - 2, MOD)
finvs = [1] * (n + 1)
finvs[n] = f
for m in range(n, 1, -1):
f = f * m % MOD
finvs[m - 1] = f
return factorials, finvs
def iter_sub_bitset(bitset, include_self=True, include_empty=True):
v = bitset
if not include_self:
v = (v - 1) & bitset
while v:
yield v
v = (v - 1) & bitset
if include_empty:
yield 0
n, m = map(int, input().split())
rrr = [r - 1 for r in map(int, input().split())]
bbb = [b - 1 for b in map(int, input().split())]
MOD = 998244353
facts, finvs = precompute_factorials(m, MOD)
r_cnt = Counter(rrr)
b_cnt = Counter(bbb)
r_terminal = [0] * (1 << n)
b_terminal = [0] * (1 << n)
for bitset in range(1, 1 << n):
lsb = bitset & -bitset
i = lsb.bit_length() - 1
r_terminal[bitset] = r_terminal[bitset ^ lsb] + r_cnt[i]
b_terminal[bitset] = b_terminal[bitset ^ lsb] + b_cnt[i]
dp1 = [0] * (1 << n)
dp2 = [0] * (1 << n)
dp1[0] = dp2[0] = 1
for bitset in range(1, 1 << n):
if r_terminal[bitset] != b_terminal[bitset]:
continue
tmp = dp2[bitset] = facts[r_terminal[bitset]]
lsb = bitset & -bitset
remain = bitset ^ lsb
for sub in iter_sub_bitset(remain, include_self=False, include_empty=True):
target = sub | lsb
others = bitset ^ target
tmp -= dp1[target] * dp2[others]
dp1[bitset] = tmp % MOD
ans = 0
full = (1 << n) - 1
for bitset in range(1, 1 << n):
ans += dp1[bitset] * dp2[full ^ bitset]
ans %= MOD
ans *= finvs[m]
ans %= MOD
print(ans)
|