区間DP。
$DP1[l][r]=P_l,P_{l+1},...,P_{r-1}$ の $[l,r)$ の区間で同じことを考えて、1つの部分木は何通りあるか?
$DP2[l][r]=P_l,P_{l+1},...,P_{r-1}$ の $[l,r)$ の区間で、木が複数本あってもよい場合は何通りあるか?
P[l,r) = (5 3 7 2 4 8)
DP1での管理対象例:
5 5
/ \ / \
3 7 3 4
/|\ | |
2 4 8 7 8
|
2
DP2での管理対象例: 上記に加え、
5 7 ←ただし根は番号昇順に並んでいること
| |
3 2
/ \
4 8
$P$ の添字を $P_0,P_1,...,P_{N-1}$ に置き換えて考えるとして、$DP1[0][N]$ が答えとなる。
$DP1[l][r],DP2[l][r]$ を考えるとき、それより短い区間は全て計算済みとして、遷移は
$P_l$ を根とした1つの木を作る場合:
複数の木からなる森を作る場合:
最初の木がどこまで続くか $k$ を全探索
$[l,k)$ まで1つの木、$[k,r)$ は複数の木
$P_l \lt P_k$ という条件ときのみ、そこで区切ってくっつけられる
$\displaystyle DP2[l][r]=\sum_{k=l+1}^{r-1}DP1[l][k] \times DP2[k][r] \ if \ P_l \lt P_k$
これで区間が短い方から計算していける。
初期値は、各 $i=0,...,N-1$ につき $DP1[i][i+1]=DP2[i][i+1]=1$ としておけばよい。
で、このまま解いても別にいいのだが、上記の遷移からもわかるとおり、$DP1[l][r]$ は $DP2[l+1][r]$ と同じとなる。
従って、実際に管理するのは $DP2$ のみでよく、$DP2[1][N]$ が答えとなる。
Python3
import os
import sys
import numpy as np
def solve(inp):
n = inp[0]
ppp = inp[1:]
MOD = 998244353
dp = np.zeros((n + 1, n + 1), dtype=np.int64) # [l,r) から構築できる木
for i in range(n):
dp[i, i] = dp[i, i + 1] = 1
for w in range(2, n + 1):
for l in range(n - w + 1):
r = l + w
tmp = dp[l + 1, r]
for k in range(l + 1, r):
if ppp[l] < ppp[k]:
tmp += dp[l + 1, k] * dp[k, r]
tmp %= MOD
dp[l, r] = tmp
return dp[1, n]
SIGNATURE = '(i8[:],)'
if sys.argv[-1] == 'ONLINE_JUDGE':
from numba.pycc import CC
cc = CC('my_module')
cc.export('solve', SIGNATURE)(solve)
cc.compile()
exit()
if os.name == 'posix':
# noinspection PyUnresolvedReferences
from my_module import solve
else:
from numba import njit
solve = njit(SIGNATURE, cache=True)(solve)
print('compiled', file=sys.stderr)
inp = np.fromstring(sys.stdin.read(), dtype=np.int64, sep=' ')
ans = solve(inp)
print(ans)
$N=70$ という、十分小さいが愚直には大きすぎる中途半端な制約。
もし選び方を全探索するとすると、色が $1,2,...$ である宝石の個数を $L_1,L_2,...$ として、
選び方の個数はこれらの総積となる。
しかし、$L_i$ の総和の上限が70である。
最大でもそこまで大きくならないのでは? ということで、上限を考える。
和が一定の場合の総積は、でかい値が1個あるよりは、2とか3とかがいっぱいあった方が、明らかにでかくなる。
具体的には、
$1$ は総積を増やすのに全く寄与しないため、無い方がよい
$L_i \ge 4$ のとき、$2$ と $L_i-2$ に分割した方がでかくなる(小さくならない)
$2$ が3個と $3$ が2個では、$3$ が2個の方が積がでかくなる
よって、$L_i$ の総和が一定の時、一番総積がでかくなるのは、$3$ がなるべくいっぱいあるとき
ということになる。
1とか2とかでなく、3という値に意味があるのって、なんか珍しい気がする。
(なお、整数でなく実数で、総和が一定の上で総積を最大化させるのは、ネイピア数 $e$ をたくさん作るのがいいらしい。参考)
よって、全探索の上限は $2^2 3^{22}=1.25 \times 10^{11}$ となる。
もちろん、これはまだ多すぎるが、半分こすれば $3.5 \times 10^5$ と、十分可能な量となる。
綺麗に半分こできるとは限らないが、そんなに大きくならないことが期待できる。
半分全列挙の考え方を用いて、
互いに同程度となるように色を2グループに振り分け、各グループだけで作れるXORを全て計算しておく。
2つの、要素数 $3.5 \times 10^5$ くらいの多重集合ができた。(グループA,Bとする)
あとは、互いから1つずつ取り出してXORして作れる値から、$K$ 番目に大きいものを求めればよい。
だが、この方法もなかなか難しい。
TLEとなるが、基本的な考え方をまず示す。
両グループの値をそれぞれBinary Trieに登録する。
各ノードに「自身の部分木に含まれる要素数」を持たせておく。
A: (0101, 0111, 1101, 1111) B: (0001, 0100, 1110)
④ ③
0/ \1 0/ \1
② ② ② ①
\1 \1 0/ \1 \1
② ② ① ① ①
0/\1 0/\1 0/ 0/ \1
① ① ① ① ① ① ①
\1 \1 \1 \1 \1 0/ 0/
① ① ① ① ① ① ①
答えを二分探索する。
「XORの結果、$X$ より大きい値はいくつあるか?」は、1回につき、2つのTrieの小さい方のノード数程度で抑えられる。
$f($ A側で着目中のノード $a$, B側で着目中のノード $b$, 着目中のbit $d$, 境界値 $X)$ として、
$a,b$ の各ノード以下に登録された値同士のXORで $X$ 以上になるものの個数を返す関数を考える。
$d$ は大きい方から考えるとして、$X$ で $d$ が立っているなら、、、
$X$ で $d$ が立っていないなら、、、
これで $f($ AのTrieの根, BのTrieの根, 1«59, X)$ とすると、全体での答えが求められる。
$X$ を二分探索し、「$X$ 以上になるXORが $K$ 個以上になる、最も小さい値」が答えとなる。
計算量は、まずBinaryTrieへの登録が、$O(3^{\frac{N}{6}}) \propto 3.5 \times 10^5$ 個の値で
$O(\log{V_{max}}) \propto 60$ 個のノードを作るので、だいたい $10^7$ に数倍をかけたくらい。
次に二分探索が、最悪の場合は全ノードが探索されうるので、1回あたり $O(3^{\frac{N}{6}}\log{V_{max}})$ なので、
答えまで行き着くには $O(3^{\frac{N}{6}}\log{V_{max}}\log{K})$ で、$10^8~10^{10}$ くらいとなり、ちょっと厳しい。
まぁ、実際にほぼ全ノードが探索されうるような $X$ はあったとしても限られるので、
全ての二分探索の試行がここまでかかるわけではないが、実際にTLEだった。
logを1つにする方法もある(公式Editorial参照)が、小手先の高速化を試みる。
結構、BinaryTrieの構築の時点で、それなりの計算量がかかっている。
最初にグループA,Bの全ての値をTrieに登録しきってしまうことはせず、遅延させることを考える。
ノードに個数で無く、値の多重集合自体を持たせておき、二分探索で必要とされたときに初めて(まだ作ってなければ)子を作る。
二分探索開始前のグループAのTrieは、根にグループAの多重集合だけ持った状態
(0101,0111,1101,1111) ←根
二分探索で仮に $X=0110$ などが試されたとして、$f(a,b,d,X)$ による再帰で次の $a,b$ として必要になる箇所だけ、作られてなければその時に作る。
(0101,0111,1101,1111)
0/ \1
(0101, 0111) (1101, 1111)
\1
(0101,0111)
0/ \1
(0101,) (0111,)
\1
(0111,)
$f(a,b,d,X)$ は、$a,b$ セットで再帰していくので、Bの側のTrieに次の $b$ に該当するノードが存在しなかったりした場合は、
$a$ の側でそれ以下を作っていることが無駄だったりする($X$ が変わったら改めて必要になることはありえるが)ので、
この無駄を省くことで一定の高速化になる。
また、$d$ が深く(下の桁に)なるにつれ、そのノード以下にある値の個数は少なくなり、
よほど人為的に作らない限りは $60$ bit目まで降りなくても早々に1などになる。
$f(a,b,d,X)$ を求めるとき、$|a| \times |b|$ が100~1000など、
全探索できる程度に十分に小さければ、その時点で全探索して、再帰せず直接求めてしまえばよい。
このような高速化を図ることで、2秒ギリギリではあるが、通る。
Python3
import os
import sys
import numpy as np
def solve(inp):
n, c, k = inp[:3]
ddd = inp[3::2] - 1
vvv = inp[4::2]
if c == 1:
vvv = np.sort(vvv)[::-1]
return vvv[k - 1]
int_list = [n]
int_list.clear()
gems = [int_list.copy() for _ in range(c)]
for i in range(n):
d = ddd[i]
v = vvv[i]
gems[d].append(v)
gems_counts = np.array([len(g) for g in gems], np.int64)
all_prod = 1
can_prods = [{np.int64(1)}]
for gc in gems_counts:
all_prod *= gc
can_prod = can_prods[-1].copy()
for p in can_prods[-1]:
can_prod.add(p * gc)
can_prods.append(can_prod)
can_prod_s = sorted(can_prods[-1])
tgt = can_prod_s[len(can_prod_s) // 2]
group0 = []
group1 = []
for i in range(c - 1, -1, -1):
g = gems[i]
gc = gems_counts[i]
if tgt % gc == 0 and tgt // gc in can_prods[i]:
group0.append(g)
tgt //= gc
else:
group1.append(g)
def all_pattern(group):
res = [0]
for g in group:
new_res = []
for v in g:
for u in res:
new_res.append(v ^ u)
res = new_res
return res
pat0 = all_pattern(group0)
pat1 = all_pattern(group1)
lll0 = [-1]
rrr0 = [-1]
prp0 = [False]
num0 = [pat0]
cnt0 = [len(pat0)]
lll1 = [-1]
rrr1 = [-1]
prp1 = [False]
num1 = [pat1]
cnt1 = [len(pat1)]
def propagate(lll, rrr, prp, num, cnt, i, d):
l_num = []
r_num = []
for v in num[i]:
if v & d:
r_num.append(v)
else:
l_num.append(v)
if len(l_num) > 0:
lll[i] = len(lll)
lll.append(-1)
rrr.append(-1)
prp.append(False)
num.append(l_num)
cnt.append(len(l_num))
if len(r_num) > 0:
rrr[i] = len(lll)
lll.append(-1)
rrr.append(-1)
prp.append(False)
num.append(r_num)
cnt.append(len(r_num))
prp[i] = True
def check(m):
q = [(0, 0, 1 << 59)]
res = 0
while q:
i0, i1, d = q.pop()
if d == 0:
res += cnt0[i0] * cnt1[i1]
continue
if cnt0[i0] * cnt1[i1] < 1000:
for v0 in num0[i0]:
for v1 in num1[i1]:
if v0 ^ v1 >= m:
res += 1
continue
if not prp0[i0]:
propagate(lll0, rrr0, prp0, num0, cnt0, i0, d)
if not prp1[i1]:
propagate(lll1, rrr1, prp1, num1, cnt1, i1, d)
if d & m:
if lll0[i0] != -1 and rrr1[i1] != -1:
q.append((lll0[i0], rrr1[i1], d >> 1))
if rrr0[i0] != -1 and lll1[i1] != -1:
q.append((rrr0[i0], lll1[i1], d >> 1))
else:
if lll0[i0] != -1 and rrr1[i1] != -1:
res += cnt0[lll0[i0]] * cnt1[rrr1[i1]]
if rrr0[i0] != -1 and lll1[i1] != -1:
res += cnt0[rrr0[i0]] * cnt1[lll1[i1]]
if lll0[i0] != -1 and lll1[i1] != -1:
q.append((lll0[i0], lll1[i1], d >> 1))
if rrr0[i0] != -1 and rrr1[i1] != -1:
q.append((rrr0[i0], rrr1[i1], d >> 1))
if res >= k:
return k
d >>= 1
return res
l = 0
r = 1 << 60
while l + 1 < r:
m = (l + r) // 2
if check(m) < k:
r = m
else:
l = m
return l
SIGNATURE = '(i8[:],)'
if sys.argv[-1] == 'ONLINE_JUDGE':
from numba.pycc import CC
cc = CC('my_module')
cc.export('solve', SIGNATURE)(solve)
cc.compile()
exit()
if os.name == 'posix':
# noinspection PyUnresolvedReferences
from my_module import solve
else:
from numba import njit
solve = njit(SIGNATURE, cache=True)(solve)
print('compiled', file=sys.stderr)
inp = np.fromstring(sys.stdin.read(), dtype=np.int64, sep=' ')
ans = solve(inp)
print(ans)