サントリープログラミングコンテスト2023(AtCoder Beginner Contest 321)F,G問題メモ
F - #(subset sum = K) with Add and Erase
問題
解法
減らす場合のナップサック。
問題を見るに、1回だけなら典型的なナップサック問題。
でも1回あたり O(QK) なので、毎回ゼロから答えを求めると、O(Q^2K) となり間に合わない。
通常のナップサックは荷物を入れる順番は関係ないので、計算済みの DP から「減らす」操作をしたい。
DP[i,j] に要素 x を追加すると DP[i+1,j+x] に遷移するのだから、
j の上限を K で切らず、全ての場合を持っておけば、DP[i,j+x] から DP[i+1,j] のように減らす場合も計算できるのでは?
だがそれは要素数が O(Qx_{max}) のレベルで増えていくので無理。
ところが、よく考えると、実は j の上限は K のままで、減らす場合も計算できる。
0 1 ... x ... j ... j+x ... K
i a b ... c d . k ... e ... z 通常のナップサックの更新:
a b ........ k ... xずらして加算
----------------------------------
i+1 a b ... a+c b+d ..... e+k ...
通常のナップサックでの x の追加は、DP[i] 同士を、x ずらして足し込むイメージ。
逆に x を減らす更新を行うとき、「直前で追加したのが x だったとして、x を追加する前の状態」を想像する。
すると、j の添字の小さい方から順に、DP[i,j+x] から DP[i,j] を引いてやると、元の状態が復元できることが分かる。
この操作に、j \gt K の情報は必要ない。
Python3
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
import os
import sys
import numpy as np
def solve(q, k, operators, xxx):
dp = np.zeros(k + 1 , np.int64)
dp[ 0 ] = 1
MOD = 998244353
ans = np.zeros(q, np.int64)
for i in range (q):
x = xxx[i]
if operators[i] = = 0 :
for j in range (k, x - 1 , - 1 ):
dp[j] + = dp[j - x]
else :
for j in range (k - x + 1 ):
dp[j + x] - = dp[j]
dp % = MOD
ans[i] = dp[k]
return ans
SIGNATURE = '(i8,i8,i8[:],i8[:])'
if sys.argv[ - 1 ] = = 'ONLINE_JUDGE' :
from numba.pycc import CC
cc = CC( 'my_module' )
cc.export( 'solve' , SIGNATURE)(solve)
cc. compile ()
exit()
if os.name = = 'posix' :
from my_module import solve
else :
from numba import njit
solve = njit(SIGNATURE, cache = True )(solve)
print ( 'compiled' , file = sys.stderr)
q, k = map ( int , input ().split())
operators = []
xxx = []
for _ in range (q):
op, x = input ().split()
op = '+-' .index(op)
x = int (x)
operators.append(op)
xxx.append(x)
ans = solve(q, k, np.array(operators, np.int64), np.array(xxx, np.int64))
print ( '\n' .join( map ( str , ans)))
|
G - Electric Circuit
問題
N 個の電子基板がある
赤の端子は電子基板 R_1,R_2,...,R_M に、青の端子は B_1,B_2,...,B_M についている(重複あり)
M 本のケーブルで、赤と青の端子を1つずつ M 組のペアにしてつなぐ
つなぎ方は M! 通りあるが、ランダムに選ぶときの連結成分数の期待値を \mod{998244353} で求めよ
1 \le N \le 17
1 \le M \le 10^5
解法
主客転倒とbitDP
期待値は、M! 通りのつなぎ方全ての連結成分数の総和を求めた上で M! で割ればよいので、総和を求めることを目指す。
また、これは主客転倒して、「ある基板の集合 S に対し、それが1つの連結成分を為すようなケーブルのつなぎ方の個数」
を求めて、その総和を取ることと一致する。
「S の中の赤(青)端子」を k 個、「S の中だけのケーブルのつなぎ方のうち、S が1つの連結成分をなすもの」を x 通りとすると、S の外の端子はどうつなごうと自由なので、S が1つの連結成分を為すつなぎ方は x(N-k)! で計算できる。
で、肝心の「S の中だけのケーブルのつなぎ方のうち、S が1つの連結成分をなすもの」は、
「S の中だけのケーブルの全てのつなぎ方 k!」から、「2つ以上の連結成分に分かれるようなつなぎ方」を除外するとよい。
DPの遷移
いま、1つ S を決めて DP[S] を求めるとする。S の部分集合のDPの値は、S 自身を除いて全て計算済みとする。
赤と青の端子の個数が一致していなければそもそも不可能なので DP[S]=0。以下、一致しているとする。
赤端子の個数を k とすると、まず全体では k! 通りのつなぎ方がある。
そのうち、2つ以上の連結成分に分かれるものを除去するにあたり、
「最も番号が小さい基板が属する連結成分 T」を、S の部分集合(S 自身以外)から総当たりする。
最も番号が小さい基板に限定するのは、重複を防ぐため。
T の端子数を l として、DP[T] \times (k-l)! 通りのつなぎ方が除去される。
k! から全てを除去した後に残ったのが、DP[S] となる。
計算量は O(3^N)
Python3
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
from collections import Counter
def precompute_factorials(n, MOD):
f = 1
factorials = [ 1 ]
for m in range ( 1 , n + 1 ):
f = f * m % MOD
factorials.append(f)
f = pow (f, MOD - 2 , MOD)
finvs = [ 1 ] * (n + 1 )
finvs[n] = f
for m in range (n, 1 , - 1 ):
f = f * m % MOD
finvs[m - 1 ] = f
return factorials, finvs
def iter_sub_bitset(bitset, include_self = True , include_empty = True ):
v = bitset
if not include_self:
v = (v - 1 ) & bitset
while v:
yield v
v = (v - 1 ) & bitset
if include_empty:
yield 0
n, m = map ( int , input ().split())
rrr = [r - 1 for r in map ( int , input ().split())]
bbb = [b - 1 for b in map ( int , input ().split())]
MOD = 998244353
facts, finvs = precompute_factorials(m, MOD)
r_cnt = Counter(rrr)
b_cnt = Counter(bbb)
r_terminal = [ 0 ] * ( 1 << n)
b_terminal = [ 0 ] * ( 1 << n)
for bitset in range ( 1 , 1 << n):
lsb = bitset & - bitset
i = lsb.bit_length() - 1
r_terminal[bitset] = r_terminal[bitset ^ lsb] + r_cnt[i]
b_terminal[bitset] = b_terminal[bitset ^ lsb] + b_cnt[i]
dp1 = [ 0 ] * ( 1 << n)
dp2 = [ 0 ] * ( 1 << n)
dp1[ 0 ] = dp2[ 0 ] = 1
for bitset in range ( 1 , 1 << n):
if r_terminal[bitset] ! = b_terminal[bitset]:
continue
tmp = dp2[bitset] = facts[r_terminal[bitset]]
lsb = bitset & - bitset
remain = bitset ^ lsb
for sub in iter_sub_bitset(remain, include_self = False , include_empty = True ):
target = sub | lsb
others = bitset ^ target
tmp - = dp1[target] * dp2[others]
dp1[bitset] = tmp % MOD
ans = 0
full = ( 1 << n) - 1
for bitset in range ( 1 , 1 << n):
ans + = dp1[bitset] * dp2[full ^ bitset]
ans % = MOD
ans * = finvs[m]
ans % = MOD
print (ans)
|