むずかった。前半数問は、大雑把なオーダー記法による計算量見積もりだけでなく、それをどこまで削減できるかの感覚が問われているように感じた。
$j$ を決め打つ。選べる $i$ と $k$ の個数は $j$ よりそれぞれ前と後ろにある $A_j$ より小さい数の個数である。
そしてこの2つは独立に選んでよいので、3つの組としては単純に掛け合わせた分だけ存在する。
各 $j$ につき前と後ろにある自分より小さい数を掛け合わせた数を求め、それらを合計したものが答え。
転倒数を数える要領でBinary Indexed Tree等を使ってもいいし(ただし座標圧縮が必要)、$N$ が小さいので各 $j$ の両側を毎回探索してもいい。
class Bit: def __init__(self, n): self.size = n self.tree = [0] * (n + 1) def sum(self, i): s = 0 while i > 0: s += self.tree[i] i -= i & -i return s def add(self, i, x): while i <= self.size: self.tree[i] += x i += i & -i n = int(input()) aaa = list(map(int, input().split())) bbb = sorted(set(aaa)) ccc = {b: c for c, b in enumerate(bbb)} m = len(bbb) bit_f = Bit(m) bit_b = Bit(m) smaller_f = [] smaller_b = [] for a in aaa: c = ccc[a] bit_f.add(c + 1, 1) smaller_f.append(bit_f.sum(c)) for a in reversed(aaa): c = ccc[a] bit_b.add(c + 1, 1) smaller_b.append(bit_b.sum(c)) smaller_b.reverse() ans = 0 for i in range(1, n - 1): ans += smaller_f[i] * smaller_b[i] print(ans)
いろいろな解き方があるっぽいが、問題文より条件を1つ1つプログラムに落とし込んでいく。
まず $S_2=S_6$ のため、$S_6$ の末尾文字が $S_2$ の末尾文字としてそれより4文字以上前に含まれていないといけないので、それを全探索で決め打つ。
これにより、文字を $S_1~S_2$ と $S_3~S_6$ に分割する箇所が固定される。
aabbbbccbb ↓ aabbb bccbb aabb bbccbb aab bbbccbb aabbbbccb b は不可。S2とS6の間に3文字以上なく、S3,S4,S5を取れる場所が無い
次に $S_6$ の長さを決め打ち、$S_2=S_6$ となるか確認する。この時、$S_1$ が取れるよう左側に1文字、$S_3~S_5$ が取れるよう右側に3文字以上残るところまで探索する。
これにより、$S_1/S_2/S_3~S_5/S_6$ の分割箇所が固定される。
aabb bbccbb ↓ aab b bbccb b aa bb bbcc bb a abb bbc cbb までいくとS2とS6が異なり、これ以上伸ばしても同じになることは無いため、終了
最後に $S_3$ の長さを決め打つ。それに続けて $S_4$ を同じ文字数だけとり、同じになるか確認する。$S_5$ が1文字以上取れるよう注意する。
aa bb bbcc bb ↓ aa bb b b cc bb
これでNIKKEI型の条件を列挙できたので、これらを全て満たすものの個数が答え。 計算量は $O(N^4)$ だが、$S_6$ と $S_3$ で一方の長さが長くなれば一方が短くなるし、文字列比較という単純な操作なので間に合う。
前もって各長さでハッシュ値を計算しておくと $O(N^3)$ になる。
文字列問題は添字の管理で混乱しがち。
s = input() n = len(s) ans = 0 s6t = s[-1] for s2t_pos in range(n - 5, 0, -1): if s[s2t_pos] != s6t: continue # print(s2t_pos, n-s2t_pos-3, (n - s2t_pos - 3) // 2 + 1) for s6len in range(1, min(n - s2t_pos - 4, s2t_pos) + 1): s2h_pos = s2t_pos - s6len + 1 s6h_pos = n - s6len if s[s2h_pos:s2t_pos + 1] != s[s6h_pos:n]: break for s3len in range(1, (n - s2t_pos - s6len - 2) // 2 + 1): s3h_pos = s2t_pos + 1 s3t_pos = s3h_pos + s3len - 1 s4h_pos = s3t_pos + 1 s4t_pos = s4h_pos + s3len - 1 if s[s3h_pos:s3t_pos + 1] != s[s4h_pos:s4t_pos + 1]: continue ans += 1 # print('', ans) print(ans)