無向Functional Graphなので、各連結成分は1つだけ閉路を持つ。
逆に言うと「閉路の個数=連結成分の個数」なので閉路を数えることとしてよい。この問題はその方が考えやすい。
まず、-1 以外の辺をUnionFind等で結合する。(ここでの暫定的な連結成分を、“準連結成分”と呼ぶことにする)
各準連結成分は「①:-1の頂点を含まない、閉路を1つ含む準連結成分(俗称なもりグラフ)」と
「②:-1を1つだけ含む、木状の準連結成分」に分類される。
①の個数を $X$、②の個数を $Y$ として、②をつなぎに行く先で $N^Y$ 通りのパターンがある。
①同士はどうやっても最終的に同じ連結成分になることはないので、$X \times N^Y$ 個はもう確定している。
②同士のみをつないだ連結成分が追加で何個できますか、という問題となる。
「1個できるのが○パターン、2個できるのが○パターン、、、」を求めるのは難しいので、主客転倒する。
「閉路を構成するような準連結成分の組」を1つ固定し、それが何個のパターンで計上されるかを、全ての組について足し合わせる。
ここで、②のみからなる「連結成分を構成する準連結成分の組」を固定しても、その個数をうまく計算するのは難しい。
②のみからなる「閉路を構成する組」を固定すると、計算しやすくなる。
後からその閉路に他の②がつなぎに来てもよいが、それは固定の対象としない。
②の準連結成分に順番を付け、頂点数をそれぞれ $B_1,B_2,...,B_Y$ とする。
また、閉路を構成する組 $I=(i_1,i_2,...,i_k)$ を1つ固定する。
これらをかけあわせた結果が $I$ に対する答えで、考えられる全ての $I$ について総和を取ると全体の答えとなる。
ここで、ⅰとⅲはサイズ $k$ にしか依存しないので、これを基準にまとめることを考えると、
ⅱの総和を $k$ ごとにまとめて計算できればよい。
で、これはDPで $O(N^2)$ だったり、畳み込みを使って $O(N \log^2{N})$ などで求められる。
Python3
class UnionFindWithUnitedCount:
def __init__(self, n):
self.table = [-1] * n
self.count = [0] * n
def root(self, x):
stack = []
tbl = self.table
while tbl[x] >= 0:
stack.append(x)
x = tbl[x]
for y in stack:
tbl[y] = x
return x
def find(self, x, y):
return self.root(x) == self.root(y)
def unite(self, x, y):
r1 = self.root(x)
r2 = self.root(y)
if r1 == r2:
self.count[r1] += 1
return False
d1 = self.table[r1]
d2 = self.table[r2]
if d1 <= d2:
self.table[r2] = r1
self.table[r1] += d2
self.count[r1] += self.count[r2] + 1
else:
self.table[r1] = r2
self.table[r2] += d1
self.count[r2] += self.count[r1] + 1
return True
def get_size(self, x):
return -self.table[self.root(x)]
n = int(input())
aaa = list(map(int, input().split()))
MOD = 998244353
uft = UnionFindWithUnitedCount(n)
for i in range(n):
if aaa[i] == -1:
continue
a = aaa[i] - 1
uft.unite(i, a)
out0_component_count = 0
out1_vertex_counts = []
for i in range(n):
if uft.table[i] < 0:
size = -uft.table[i]
if size > uft.count[i]:
out1_vertex_counts.append(size)
else:
out0_component_count += 1
m = len(out1_vertex_counts)
dp = [0] * (m + 1)
dp[0] = 1
for i, size in enumerate(out1_vertex_counts):
for j in range(i, -1, -1):
dp[j + 1] += dp[j] * size
dp[j + 1] %= MOD
facts = [1, 1]
for i in range(2, m):
facts.append(facts[-1] * i % MOD)
ans = pow(n, m, MOD) * out0_component_count % MOD
for i in range(1, m + 1):
ans += facts[i - 1] * dp[i] * pow(n, m - i, MOD)
ans %= MOD
print(ans)