AtCoder Regular Contest 140 D問題メモ
D - One to One
問題文
全ての要素が 1 以上 N 以下である長さ N の整数列 X=(X1,X2,…,XN) に対して次の問題を考え、その答えを f(X) とします。
長さ N の整数列 A=(A1,A2,…,AN) が与えられます。各 Ai は 1 以上 N 以下の整数あるいは −1 です。
全ての要素が 1 以上 N 以下である長さ N の整数列 X=(X1,X2,…,XN) であって、Ai≠−1⇒Ai=Xi を満たすものを考えます。そのような全ての X に対する f(X) の総和を 998244353 で割ったあまりを求めてください。
制約
解法
無向Functional Graphなので、各連結成分は1つだけ閉路を持つ。
逆に言うと「閉路の個数=連結成分の個数」なので閉路を数えることとしてよい。この問題はその方が考えやすい。
まず、-1 以外の辺をUnionFind等で結合する。(ここでの暫定的な連結成分を、“準連結成分”と呼ぶことにする)
各準連結成分は「①:-1の頂点を含まない、閉路を1つ含む準連結成分(俗称なもりグラフ)」と
「②:-1を1つだけ含む、木状の準連結成分」に分類される。
①の個数を X、②の個数を Y として、②をつなぎに行く先で NY 通りのパターンがある。
①同士はどうやっても最終的に同じ連結成分になることはないので、X×NY 個はもう確定している。
②同士のみをつないだ連結成分が追加で何個できますか、という問題となる。
「1個できるのが○パターン、2個できるのが○パターン、、、」を求めるのは難しいので、主客転倒する。
「閉路を構成するような準連結成分の組」を1つ固定し、それが何個のパターンで計上されるかを、全ての組について足し合わせる。
ここで、②のみからなる「連結成分を構成する準連結成分の組」を固定しても、その個数をうまく計算するのは難しい。
②のみからなる「閉路を構成する組」を固定すると、計算しやすくなる。
後からその閉路に他の②がつなぎに来てもよいが、それは固定の対象としない。
②の準連結成分に順番を付け、頂点数をそれぞれ B1,B2,...,BY とする。
また、閉路を構成する組 I=(i1,i2,...,ik) を1つ固定する。
ⅰ) どの順番でつなぐか: (k−1)!
ⅱ) どの頂点につなぐか: ∏i∈IBi
ⅲ) その他の頂点の決め方: どこにつないでもよい。NY−k
これらをかけあわせた結果が I に対する答えで、考えられる全ての I について総和を取ると全体の答えとなる。
ここで、ⅰとⅲはサイズ k にしか依存しないので、これを基準にまとめることを考えると、
ⅱの総和を k ごとにまとめて計算できればよい。
で、これはDPで O(N2) だったり、畳み込みを使って O(Nlog2N) などで求められる。
Python3
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
class UnionFindWithUnitedCount:
def __init__(self, n):
self.table = [-1] * n
self.count = [0] * n
def root(self, x):
stack = []
tbl = self.table
while tbl[x] >= 0:
stack.append(x)
x = tbl[x]
for y in stack:
tbl[y] = x
return x
def find(self, x, y):
return self.root(x) == self.root(y)
def unite(self, x, y):
r1 = self.root(x)
r2 = self.root(y)
if r1 == r2:
self.count[r1] += 1
return False
d1 = self.table[r1]
d2 = self.table[r2]
if d1 <= d2:
self.table[r2] = r1
self.table[r1] += d2
self.count[r1] += self.count[r2] + 1
else:
self.table[r1] = r2
self.table[r2] += d1
self.count[r2] += self.count[r1] + 1
return True
def get_size(self, x):
return -self.table[self.root(x)]
n = int(input())
aaa = list(map(int, input().split()))
MOD = 998244353
uft = UnionFindWithUnitedCount(n)
for i in range(n):
if aaa[i] == -1:
continue
a = aaa[i] - 1
uft.unite(i, a)
out0_component_count = 0
out1_vertex_counts = []
for i in range(n):
if uft.table[i] < 0:
size = -uft.table[i]
if size > uft.count[i]:
out1_vertex_counts.append(size)
else:
out0_component_count += 1
m = len(out1_vertex_counts)
dp = [0] * (m + 1)
dp[0] = 1
for i, size in enumerate(out1_vertex_counts):
for j in range(i, -1, -1):
dp[j + 1] += dp[j] * size
dp[j + 1] %= MOD
facts = [1, 1]
for i in range(2, m):
facts.append(facts[-1] * i % MOD)
ans = pow(n, m, MOD) * out0_component_count % MOD
for i in range(1, m + 1):
ans += facts[i - 1] * dp[i] * pow(n, m - i, MOD)
ans %= MOD
print(ans)
|