目次
AtCoder Beginner Contest 135 D,F問題メモ
構築面白そうだからE先に解こーっととEに飛びついたのが浅はか。
D - Digits Parade
問題
- '0'~'9'または'?'からなる文字列 $S$
- '?' をそれぞれ'0'~'9'に置き換えて得られる10進数の整数のうち、「13で割ると5余る数」がいくつあるか求めよ
- 先頭に'0'がついた数字も10進数として解釈する
- $1 \le |S| \le 10^5$
解法
桁DPっぽいことをやる。
$DP[i][m]=$上の桁から $i$ 文字目までを見て、13で割ると $m$ 余る数の個数
$DP[0][0]=1$、他は0で初期化して、以下のように遷移できる
$i$ 文字目が数字の時
$i$ 桁目の数字を $k$ とする。
DP上で、それぞれの $m$ につき、$DP[i][(10m+k)\%13]+=DP[i-1][m]$ で遷移する。
これは、割り算の筆算と同じように考えるとわかりやすい。
以下の筆算で、5732の3桁目の“3”を処理するとき、上の桁からの余りが“5”なので、53を13で割った余り“1”が次の桁に引き継がれる。
__________ ____4_____ ____4_4___ 13 ) 5 7 3 2 13 ) 5 7 3 2 13 ) 5 7 3 2 _5_2_ _5_2_ 前の桁の余りを→ 5 → 5 3 10倍して新しい桁を下ろす _5_2_ 1
$DP[3][(10 \times 5 + 3)\%13] = DP[3][1] += DP[2][5]$
実際には、上の桁からの余り $m$ は“?”が存在することによって複数の可能性が考えられるので、それぞれの $m$ について同様の処理を行う。
$i$ 文字目が'?'の時
'?'を0~9のそれぞれに置き換えて、$i$ 文字目が数字の時と同じことを10回繰り返す。
高速化
前の桁の余りを10倍して13で割った余りは、下のように0~12の $m$ に対して全射である。これが、新しい1の桁が0の場合の、DPにおける次の $m$ となる。
m 0 1 2 3 4 ... 0 10 7 4 1 ...
そして、これに新しい1の位の数 $k$ を足す操作は、この全射を丸ごと $k$ だけ回転させる操作に値する。
m 0 1 2 3 4 ... 6 3 0 10 7 k=2→→
なので、$i$ 文字目が数字の時はこの全射に従って配列の要素を並べ直すだけでよく、'?'の時は、並べ直した後で0~9までの10通りの回転した配列を足し合わせる操作に値する。
0~9までずらした配列を足し合わせるのは、numpy.convolve() が使えるが、端っこがループする場合の処理はしてくれないので、そこだけ注意。
この全射は、繰り返すと6回でループする。これを利用すると、更に高速化が可能。
import numpy as np def solve(s): MOD = 10 ** 9 + 7 dp = np.zeros(13, dtype=np.int64) dp[0] = 1 idx = np.zeros(13, dtype=np.int8) for i in range(13): idx[i * 10 % 13] = i window = np.ones(10, dtype=np.int8) for c in s: if c == '?': tdp = dp[idx] ndp = np.concatenate([tdp[4:], tdp]) dp = np.convolve(ndp, window, mode='valid') % MOD else: dp = np.roll(dp[idx], int(c)) return dp[5] s = input() print(solve(s))
E - Golf
問題
解法
解けはしたがめっちゃ冗長なコードだし解説pdfでこと済む。
F - Strings of Eternity
問題
- 英小文字からなる二つの文字列 $s,t$
- $s$ を無限個つなげてできる文字列を $s_\infty$ とする
- $t$ を $i$ 個つなげた文字列を $t_i$ とする
- 以下の条件を満たす $i$ が有限かどうか判定し、有限ならその最大値を求めよ
- 条件: $s_\infty$ の連続する部分文字列に $t_i$ が出現する
- $1 \le |s|, |t| \le 5 \times 10^5$
解法
文字列アルゴリズム。
正確な解法としてはKMPとかZ-Algorithmとかあるらしいが、シンプルなローリングハッシュで解く。 ローリングハッシュは文字列が同じかどうかを高速に判定できるが、ハッシュなので衝突の可能性があり、テストケースによってはWAになる。
ローリングハッシュは、今回は英小文字のみなので $a=0,b=1,...,z=25$ とした26進数として文字列を見て、文字列を1つの数値として表現する。大きくなりすぎるので適当な数でMODをとる。このMODの取り方によってはハッシュが衝突する。
文字列 b c d 対応 1 2 3 桁 26^2 26 1 676+52 +3 = 731 'bcd' を表す整数は731
こうすると、たとえば調べたい文字列の長さが3と決まっている時、枠を1つずらすのに、事前の計算結果を利用できる。
文字列 c d e 対応 2 3 4 桁 26^2 26 1 'bcd'=731 が計算済み 桁を1つ繰り上げるには26倍する bcd0 = 731 * 26 = 19006 新しい1の位'e'を足す bcde = 19006 + 4 = 19010 4桁目(26^3の桁)の'b'を引く cde = 19010 - 1*26^3 = 1434
さて、問題について考える。
まず、文字列 $s$ のそれぞれの箇所で、文字列 $t$ と一致するかを調べたい。
s abcabab t ab 一致 xoxxoxo (tの末尾文字を合わせるindexで考える)
$s$ を対象として $t$ を検索するのに、$s$ が短すぎては困る。 $i$ が有限である場合を考えると、$s$ の長さは、末尾と先頭がくっつくケースを考えても $t$ の2倍程度あれば十分なので、足りない場合は適当にそれ以上になるように $s$ をあらかじめ連結させておく。
これで、$t$ 全体のローリングハッシュと、$s$ の長さ $|t|$ のローリングハッシュを計算すると、ハッシュ値が一致する箇所が、そこを末尾として$s$ が $t$ に一致する箇所となる。
t ab |t| 2 tのハッシュ値: 1 s a b c a b a b a b c a b a b rs 0 1 28 52 1 26 1 26 1 28 52 1 26 1 一致 x o x x o x o
ただし、$s$ の最初の方は1つ前のループの文字を上手く反映できてないので、$s$ は長さを $t$ の2倍(以上)にした上で、後半の計算結果のみ利用する。
一致した配列から、連続を見ていく。
i 0 1 2 3 4 5 6 s a b c a b a b rs 26 1 28 52 1 26 1 一致 o o o
まず、$i=1$ で一致している。続けて次の文字列でも一致しているかどうか調べるには、$i$ に $|t|=2$ を足して、$i=3$ で一致しているかどうかを見ればよい。
この場合は一致していないので、$i=1$ の後には続けては $t$ は来ない。
次に、$i=4$ で一致している。この場合は次の $i=6$ でも一致し、さらにループした先の $i=1$ でも一致している。$i=4$ から始まる一致は、最大3個連続して $t$ が出現する。
こうして、一致箇所を順に見ていって(計算結果をメモしつつ)最大連続数を求めればよい。
さて、ここで無限の場合を考える。
$t$ が無限回繰り返せる場合、$i$ で一致し、次の $i+|t|$ で一致し、さらに次の $i+2|t|$ で一致し……、いつかは最初の $i$ に戻ってきてしまうような箇所がどこかに存在する。こうなると無限回繰り返せることがわかる。
(実際に $i$ に一致するまでチェックを回すとTLEしたので、$i$ 以上 $i+|t|$ 未満な数の範囲にループして戻ってきた場合、無限であると判定するようにしたらACしたが、これで本当に大丈夫かは確認できていない)
import sys sys.setrecursionlimit(10 ** 6) def rolling_hash(s, w, MOD): ret = [] tmp = 0 p = pow(26, w, MOD) ords = [ord(c) - 97 for c in s] for i, o in enumerate(ords): tmp = tmp * 26 + o if i >= w: tmp = (tmp - ords[i - w] * p) tmp %= MOD ret.append(tmp) return ret def solve(s, t): MOD = 10 ** 9 + 7 ls, lt = len(s), len(t) k = (lt - 1) // ls + 1 s *= k * 2 ls *= k rs, rt = rolling_hash(s, lt, MOD), rolling_hash(t, lt, MOD) print(rs) rs = rs[ls:] ht = rt[-1] checked = [-1] * ls def series(i, st): if st <= i < st + lt: return float('-inf') if checked[i] == -1: checked[i] = series((i + lt) % ls, st) + 1 if rs[i] == ht else 0 return checked[i] for i, hs in enumerate(rs): if hs != ht: continue ret = series((i + lt) % ls, i) if ret == float('-inf'): return -1 checked[i] = ret + 1 return max(0, max(checked)) s = input() t = input() print(solve(s, t))