まず、頂点数1の木は条件を満たすが、ちょっと特殊なので分けて考える。とりあえず N 個ある。
頂点数2以上の条件を満たす木は、以下で作れる。
④ (例) ④
/|\ /|\
④○④ → ④○④
/| /\ \ \
④○○④ ○ ④
それ以外の頂点を繋げると、色④以外の葉ができてしまうので、条件を満たさない。
上記の方法で作れる木の個数を数えればよい。
「じゃあ、④が5個あったら、そこから2個以上選ぶ場合の数は 2^5-1-5=26 通りだ!」というと、そう単純でもない。
元の木で一直線上に並ぶ頂点は、端の頂点などを2つ選んだ場合、間の頂点は選んでも選ばなくても、
できあがる木としては同じになってしまう。
「木としてユニークな」個数を数えなければならない。
ある同じ色の頂点の選び方において、その頂点を選ばないとできあがる木が変わってしまうような頂点を「有効頂点」とする。
④ ←④は選んでも選ばなくても、他が一緒ならできあがる木は一緒
/|\
④○❹ ←❹は有効頂点(できあがる木において葉になる頂点ともいえる)
/ \
❹ ❹
木DPをする。
更新を考える。
DP[v] は v の親に繋がる分のみを管理する。
v 以下で完結するような選び方は、適宜答えを加算していくための変数 ans に足していき、DPには含めない。
:
①←v
/ | \ DP[左の子] = {①: 2}
① ② ② DP[中の子] = {①: 1, ②: 1}
| | DP[右の子] = {②: 1}
① ①
(a) v が次数1の有効頂点となり、子からつながってきたグラフを終了させる
(b) v が次数1の有効頂点となって開始し、そこから親に繋げていく
(c) v は単なる通過点で、1つの子のみからの状態をそのまま親に繋げる。
(d) v の2つ以上の子を結び、親へは繋げず終了する
(e) v の2つ以上の子を結び、また親へも繋げる
(a) v が次数1となり、子からつながるグラフを終了させる場合。
これは、v の色 A_v について、各子の DP[\cdot,A_v] を合計すればよい。
上記例では v の色は①なので、左2 + 中1 = 3通り。
ansに加算し、DP[v] には含めない。
(b) これは単に DP[v,A_v]+=1
(c) は、子を合成すればよい。同じ色同士は加算する。上記例では {①:3, ②:2} となる。
(d),(e) は、子を1つずつマージしていくことで求められる。いずれも同じ値になるので、答えに加算しつつ DP[v] にも加算する。
まとめて、DP[v] および v で完結する分のansへの加算は、以下のアルゴリズムで求められる。
DP[v]=\{\}(空の辞書)で初期化する
左の子(w とする)をマージする。
中の子とのマージは、左の子とマージ済みの DP[v] をもって同じことをすれば求まる
右の子とのマージは、中の子とマージ済みの DP[v] をもって同じことをすれば求まる
最後に(b)より、v から始める分を加算 DP[v,A_v]+=1
これをそのままやるとTLEだが、マージテクにより、
DP[v] と DP[w] の大きい方に小さい方を追加していくと O(N \log{N}) で収まるようになる。
Python3
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
from collections import defaultdict
n = int ( input ())
aaa = list ( map ( int , input ().split()))
links = [ set () for _ in range (n)]
for _ in range (n - 1 ):
u, v = map ( int , input ().split())
u - = 1
v - = 1
links[u].add(v)
links[v].add(u)
q = [ 0 ]
dp = [ None ] * n
status = [ 0 ] * n
ans = n
MOD = 998244353
while q:
u = q[ - 1 ]
if status[u] = = 0 :
status[u] = 1
for v in links[u]:
links[v].remove(u)
q.append(v)
continue
q.pop()
a = aaa[u]
dpu = defaultdict( int )
for v in links[u]:
dpv = dp[v]
ans + = dpv[a]
ans % = MOD
if len (dpu) < len (dpv):
dpu, dpv = dpv, dpu
for b, c in dpv.items():
if b in dpu:
tmp = dpu[b] * c
tmp % = MOD
ans + = tmp
ans % = MOD
dpu[b] + = tmp
dpu[b] % = MOD
dpu[b] + = c
dpu[b] % = MOD
dpu[a] + = 1
dp[u] = dpu
ans % = MOD
print (ans)
|