Processing math: 51%

第一回日本最強プログラマー学生選手権 決勝 A,B,C,E問題メモ

A - Equal Weight

問題

  • N 個のネタと M 個のシャリを組み合わせて寿司を作る
  • ネタの重さは A0,A1,...,AN1、シャリの重さは B0,B1,...,BM1
  • ネタ i とシャリ j を組み合わせて作った寿司の重さは Ai+Bj
  • 同じ重さの寿司を2つ作れるか判定し、作れる場合はそれぞれ何番目のネタとシャリを組み合わせればよいか、一例を答えよ
  • 1N,M2×105
  • 1Ai,Bi106

解法

見かけ倒し問題。

一見、全てのネタとシャリの組み合わせ N×M 通りを調べなければならないように思うが、 「寿司の重さが取り得る値の範囲は 22×106」ということに着目すると、 多くとも 2×106 個の組み合わせを調べればどれか1組は絶対に被る。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
n, m = map(int, input().split())
aaa = list(map(int, input().split()))
bbb = list(map(int, input().split()))
 
memo = [-1] * (2 * 10 ** 6 + 1)
 
for i, a in enumerate(aaa):
    for j, b in enumerate(bbb):
        w = a + b
        if memo[w] == -1:
            memo[w] = (i, j)
            continue
        k, l = memo[w]
        print(i, j, k, l)
        exit()
 
print(-1)

B - Reachability

問題

  • 3N 頂点からなる有向グラフがある
  • 頂点は N 個ずつX,Y,Zグループに分けられ、x0,x1,...,xN1,y0,y1,...,yN1,z0,z1,...,zN1 と名前が付いている
  • 辺は全てX→Y、Y→Zに張られている
  • 2つの N×N の表 A0,0,...,AN1,N1B0,0,...,BN1,N1 が与えられる
  • xi から yj へ行ける時、Ai,j は'1'であり、行けない時'0'である
  • xi から zj へ行ける時、Bi,j は'1'であり、行けない時'0'である
  • これを満たすようなグラフが存在するか判定し、存在する場合は YZ への辺としてあり得るものを1つ構築せよ
  • 1N300

解法

xiyj に行けるのに、xizk には行けない場合、yjzk には辺があってはいけない。

これを満たすように構築すれば、ひとまず B のうち「行けない」条件は満たされる。

あとは「行ける」条件が全て満たされているかを確認すればよい。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
def solve(n, aaa, bbb):
    reachable_xy = [{y for y, f in enumerate(a) if f == 1} for a in aaa]
    not_reachable_xz = [{z for z, f in enumerate(b) if f == 0} for b in bbb]
    ans = [[1] * n for _ in range(n)]
 
    for x in range(n):
        for y in reachable_xy[x]:
            for z in not_reachable_xz[x]:
                ans[y][z] = 0
 
    reachable_yz = [{z for z, f in enumerate(a) if f == 1} for a in ans]
 
    for x in range(n):
        reachable = [0] * n
        for y in reachable_xy[x]:
            for z in reachable_yz[y]:
                reachable[z] = 1
        if bbb[x] != reachable:
            print(-1)
            return
 
    for row in ans:
        print(''.join(map(str, row)))
 
 
n = int(input())
aaa = [list(map(int, input())) for _ in range(n)]
bbb = [list(map(int, input())) for _ in range(n)]
 
solve(n, aaa, bbb)

C - Maximize Minimum

問題

  • 0L の数直線と見なせるロールケーキがあり、はじめ、座標 X にイチゴが置かれている
  • 今から N 回の操作を行う。i 回目の操作では、以下を行う
    • 座標 Ai にイチゴが置かれていたら取り除き、置かれていなければ新たに置く
  • 各操作後、ケーキの上にはイチゴが2個以上あることが保証される
  • ケーキの「美しさ」を以下で定義する
    • 座標 x にあるイチゴは、Lx にイチゴが無ければ、移動できる
    • 「最も近い2つのイチゴの距離」を最大化するように移動させたとき、その最大値をケーキの美しさとする
  • 美しさを求めるときのイチゴの移動は想像の中で行い、実際に移動させることはない
  • 全ての i(0iN1) について、i 回目の操作直後のケーキの美しさを求めよ
  • 1N2×105
  • 1L109

解法

平衡二分探索木問題。

最適な配置の考察

まず、各操作後、どこにイチゴが置かれているかは、その通り実装すれば簡単にわかる。

美しさを求める段階で xLx に相互に行き来できるので、全て左半分に(L/2 以下になるように)寄せてしまっておく。

これを小さい方から x1,x2,...,xk とする。

すると、なるべく2つのイチゴを遠ざけようとするなら、xi で隣接する2つのイチゴは左半分と右半分に別々に配置した方がよさそう。 つまり、交互に配置するのがよさそう。

0-------------------------------------------L
  x1 x2 x3 x4 x5 x6
↓
  x1    x3    x5          x6    x4    x2

これが最適な証明は、どこか隣接する2つが同じ側に位置していたとして(x3,x4)、 そこを切り離して x4 以降をまるっと反転させたときを考える。

0-------------------------------------------L
  x1    x3 x4             x6 x5       x2
↓
  x1    x3    x5 x6             x4    x2

変化があるのは x3x4x3x5x2x5x2x4 の2つのみで、このうち x3 の方は明らかに長くなっている。

x2 の方は短くなっているが、変化後の x2x4 の方が元にあった x3x4 より必ず長いか同じなので、改悪となることは無い。

よって、これを繰り返せば、交互に位置させるのがよいとわかる。

クエリ毎に答えを求められるデータ構造

後はこれを如何に高速に求めるか。

これには、イチゴ1個の追加や削除によって、“2つのイチゴの距離”の集合は、高々数個しか変化しないことを利用する。

イチゴが追加削除されることにより、以降の偶数奇数は変わってしまうが、 各イチゴにとって、「自分の近くのイチゴが追加削除されない限り、自分が関係する距離は、自分の2つ前後のイチゴとの距離」というのは変わらない。

「1個飛ばしの全てのイチゴ間距離の集合」を持っておく。

追加
  ,-----,             今、この4つだけに着目すると、
x5  x6  x8  x9        集合にある距離は x5-x8 と x6-x9
     `------'
↓
 ,------,,------,     x7が追加されると、
x5  x6  x7  x8  x9    元あった2つの距離は削除され
     `------'         x5-x7, x7-x9, x6-x8 が追加される
削除(追加の逆)
  ,-----,,------,     今、この5つだけに着目すると、
x5  x6  x7  x8  x9    集合にある距離は
     `------'         x5-x7, x7-x9, x6-x8
↓
 ,------,             x7が削除されると、
x5  x6  x8  x9        元あった3つの距離は削除され
     `------'         x5-x8, x6-x9 が追加される

端っこでイチゴが存在しない場合などを考慮する必要はあるが、最大でも5組の距離の追加削除を行えば更新できる。あとは集合のminを取ればよい。

これは、「(A) xi を管理する」「(B) イチゴの距離を管理する」2つの平衡二分探索木を利用すれば実装できる。

  • Ai にイチゴが無いとき(追加)
    • xk=min とする
    • (A)に x_k を挿入し、今存在するイチゴの中での小さい方からのインデックス k を得る
    • (A)から前後それぞれ2つのイチゴの位置 x_{k-2},x_{k-1},x_{k+1},x_{k+2} を得る
    • 削除される距離、追加される距離をそれぞれ計算し、(B)から追加削除する
    • (B)のminを得る
  • A_i にイチゴがある時(削除)も同様

ただし、最も中央に近い2つが距離最小となる場合もあるため、それは(A)の最大2つを得て確認する。

0-------------------------------------------L
  x1     x2     x3x4 
↓
  x1            x3   x4          x2

Pythonだと

リストアクセスの遅いPythonでは、平衡二分探索木を実装してもなかなかTLEが解消できない。

ここで必要なデータ構造をもう一度考える。

(A) は、ある値に基づく挿入・削除と、その前後のノードの取得が必要となるので、 常に全体が整列された状態を保つ必要があり、なかなか平衡二分探索木以外では代用しにくい(と思う)。

しかし(B)は、挿入と削除を必要とするものの、値を得るのは常に最小値でよい。

よって優先度キューと、その値が有効か否かを管理するdictの2つによって代用できる。これらは組み込みライブラリのため高速に処理できる。

平衡二分探索木は、木の回転が少なく済むTreapとして実装してみた。

細かい定数倍改善を思いつくまま未検証でいろいろ入れてるので長くなってしまっている。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
import sys
 
from collections import defaultdict
from heapq import heappush, heappop
 
 
class Treap:
    """ Merge-Split Model Treap """
 
    # Usage example:
    # trp = Treap()
 
    # trp.insert(x)
    # trp.delete(x)
    # trp.upper_bound(x)
    # trp.lower_bound(x)
    # trp.get_by_index(i)
    # trp.get_next(i, node)
    # trp.get_prev(i, node)
    # trp.debug_print()
 
    # A node is represented as common index of 5 arrays that is left, right, key, priority, count.
    # Dummy node is represented as 0.
    # Maybe it is faster than implementing node as class or as array. (unverified)
 
    # Many deletion after many insertion may increase the number of unused indices.
    # To reduce it, use deleted index on the next insertion preferentially.
 
    # Variables name
    #   vi: current node index on search
    #   li: left child node index
    #   ri: right child node index
    #   pi: parent node index
    #   i : the i-th smaller node. Not index of arrays.
    #   x : key to insert or to delete
 
    def __init__(self):
        self.root = 0
        self.left = [0]
        self.right = [0]
        self.children = [self.left, self.right]
        self.key = [0]
        self.priority = [0]
        self.count = [0]
        self.deleted = set()
        self.rand = self.xor128()
 
    def xor128(self):
        x = 123456789
        y = 362436069
        z = 521288629
        w = 88675123
 
        while True:
            t, x, y, z = x ^ ((x << 11) & 0xffffffff), y, z, w
            w = (w ^ (w >> 19)) ^ (t ^ (t >> 8))
            yield w
 
    def get_size(self):
        return self.count[self.root]
 
    def _merge(self, li, ri):
        left = self.left
        right = self.right
        children = self.children
        priority = self.priority
        count = self.count
 
        stack = []
        while li != 0 and ri != 0:
            if priority[li] > priority[ri]:
                stack.append((li, 1))
                li = right[li]
            else:
                stack.append((ri, 0))
                ri = left[ri]
 
        vi = li if li != 0 else ri
        for pi, d in reversed(stack):
            children[d][pi] = vi
            count[pi] = count[left[pi]] + count[right[pi]] + 1
            vi = pi
 
        return vi
 
    def _split_by_key(self, vi, x):
        """
        :return: (LeftRoot, RightRoot)
                  LeftRoot:  Root node of the split tree consisting of nodes with key < x.
                  RightNode: Root node of the split tree consisting of nodes with key >= x.
        """
        left = self.left
        right = self.right
        key = self.key
        count = self.count
 
        l_stack = []
        r_stack = []
        while vi != 0:
            if x < key[vi]:
                r_stack.append(vi)
                vi = left[vi]
            else:
                l_stack.append(vi)
                vi = right[vi]
 
        li, ri = 0, 0
        for pi in reversed(l_stack):
            right[pi] = li
            count[pi] = count[left[pi]] + count[li] + 1
            li = pi
        for pi in reversed(r_stack):
            left[pi] = ri
            count[pi] = count[ri] + count[right[pi]] + 1
            ri = pi
 
        return li, ri
 
    def insert(self, x):
        left = self.left
        right = self.right
        children = self.children
        key = self.key
        priority = self.priority
        count = self.count
 
        np = next(self.rand)
 
        if self.deleted:
            ni = self.deleted.pop()
            left[ni] = 0
            right[ni] = 0
            key[ni] = x
            priority[ni] = np
            count[ni] = 1
        else:
            ni = len(self.key)
            left.append(0)
            right.append(0)
            key.append(x)
            priority.append(np)
            count.append(1)
 
        vi = self.root
        pi = 0
        d = 0
 
        while vi != 0:
            if np > priority[vi]:
                li, ri = self._split_by_key(vi, x)
                left[ni] = li
                right[ni] = ri
                count[ni] = count[li] + count[ri] + 1
                break
            pi = vi
            d = int(x >= key[vi])
            count[vi] += 1
            vi = children[d][vi]
 
        if pi == 0:
            self.root = ni
        else:
            children[d][pi] = ni
 
    def delete(self, x):
        left = self.left
        right = self.right
        children = self.children
        key = self.key
        count = self.count
 
        vi = self.root
        pi = 0
        d = 0
 
        while vi != 0:
            if key[vi] == x:
                self.deleted.add(vi)
                vi = self._merge(left[vi], right[vi])
                break
            pi = vi
            d = int(x >= key[vi])
            count[vi] -= 1
            vi = children[d][vi]
 
        if pi == 0:
            self.root = vi
        else:
            children[d][pi] = vi
 
    def upper_bound(self, x):
        """
        :return (Node, i)
                 Node: with the smallest key y satisfying x < y.
                 i: 0-indexed order.
                 If same keys exist, return leftmost one.
                 If not exists, return (0, n).
        """
        left = self.left
        right = self.right
        key = self.key
        count = self.count
 
        vi = self.root
        ti = 0
        i = count[vi]
        j = 0
        while vi != 0:
            if x < key[vi]:
                ti = vi
                i = j + count[left[vi]]
                vi = left[vi]
            else:
                j += count[left[vi]] + 1
                vi = right[vi]
        return ti, i
 
    def lower_bound(self, x):
        """
        :return (Node, i)
                 Node: with the smallest key y satisfying x <= y.
                 i: 0-indexed order.
                 If same keys exist, return leftmost one.
                 If not exists, return (0, n).
        """
        left = self.left
        right = self.right
        key = self.key
        count = self.count
 
        vi = self.root
        ti = 0
        i = count[vi]
        j = 0
        while vi != 0:
            if x <= key[vi]:
                ti = vi
                i = j + count[left[vi]]
                vi = left[vi]
            else:
                j += count[left[vi]] + 1
                vi = right[vi]
        return ti, i
 
    def get_by_index(self, i):
        """
        :return (0-indexed) i-th smallest node.
                If i is greater than length, None will be returned.
        """
        left = self.left
        right = self.right
        count = self.count
 
        if i < 0 or self.get_size() <= i:
            return 0
        vi = self.root
        j = i
        while vi != 0:
            l_cnt = count[left[vi]]
            if l_cnt == j:
                return vi
            if j < l_cnt:
                vi = left[vi]
            else:
                j -= l_cnt + 1
                vi = right[vi]
 
        assert False, 'Unreachable'
 
    def get_max(self):
        # 多くの場合において処理が単純な分 get_by_index(get_size - 1) より速いが、
        # テストケースによっては途中で処理を打ち切れる get_by_index の方が速いことがある
        right = self.right
        vi = self.root
        if vi == 0:
            return 0
        while right[vi] != 0:
            vi = right[vi]
        return vi
 
    def get_next(self, i, vi):
        """
        :return: next node of i-th "node". (= (i+1)th node)
        """
        # If node has right child, the root-node search can be omitted.
        # Otherwise, get_by_index(i+1).
        left = self.left
        right = self.right
 
        if vi == 0:
            return 0
        if right[vi] == 0:
            return self.get_by_index(i + 1)
        vi = right[vi]
        while left[vi] != 0:
            vi = left[vi]
        return vi
 
    def get_prev(self, i, vi):
        left = self.left
        right = self.right
 
        if vi == 0:
            return 0
        if left[vi] == 0:
            return self.get_by_index(i - 1)
        vi = left[vi]
        while right[vi] != 0:
            vi = right[vi]
        return vi
 
    def debug_print(self):
        self._debug_print(self.root, 0)
 
    def _debug_print(self, vi, depth):
        if vi != 0:
            self._debug_print(self.left[vi], depth + 1)
            print('      ' * depth, self.key[vi], self.priority[vi], self.count[vi])
            self._debug_print(self.right[vi], depth + 1)
 
 
def dist_insert(x):
    heappush(dists, x)
    available_dists[x] += 1
 
 
def dist_delete(x):
    available_dists[x] -= 1
 
 
def dist_get_min():
    while dists and available_dists[dists[0]] == 0:
        heappop(dists)
    if dists:
        return dists[0]
    return 0xffffffff
 
 
n, l, x = map(int, input().split())
aaa = list(map(int, sys.stdin))
 
trp1 = Treap()
dists = []
available_dists = defaultdict(lambda: 0)
trp1key = trp1.key
strawberry = {x}
trp1.insert(min(x, l - x))
 
buf = []
for a in aaa:
    b = min(a, l - a)
    if a in strawberry:
        vi, i = trp1.lower_bound(b)
        li1 = trp1.get_prev(i, vi)
        li2 = trp1.get_prev(i - 1, li1)
        ri1 = trp1.get_next(i, vi)
        ri2 = trp1.get_next(i + 1, ri1)
        l1, l2, r1, r2 = trp1key[li1], trp1key[li2], trp1key[ri1], trp1key[ri2]
        if li2 != 0:
            dist_delete(b - l2)
            if ri1 != 0:
                dist_insert(r1 - l2)
        if ri2 != 0:
            dist_delete(r2 - b)
            if li1 != 0:
                dist_insert(r2 - l1)
        if li1 != 0 and ri1 != 0:
            dist_delete(r1 - l1)
        strawberry.remove(a)
        trp1.delete(b)
    else:
        strawberry.add(a)
        trp1.insert(b)
        vi, i = trp1.lower_bound(b)
        li1 = trp1.get_prev(i, vi)
        li2 = trp1.get_prev(i - 1, li1)
        ri1 = trp1.get_next(i, vi)
        ri2 = trp1.get_next(i + 1, ri1)
        l1, l2, r1, r2 = trp1key[li1], trp1key[li2], trp1key[ri1], trp1key[ri2]
        if li2 != 0:
            dist_insert(b - l2)
            if ri1 != 0:
                dist_delete(r1 - l2)
        if ri2 != 0:
            dist_insert(r2 - b)
            if li1 != 0:
                dist_delete(r2 - l1)
        if li1 != 0 and ri1 != 0:
            dist_insert(r1 - l1)
 
    size = trp1.get_size()
    ci2 = trp1.get_by_index(size - 1)
    ci1 = trp1.get_prev(size - 1, ci2)
    buf.append(min(dist_get_min(), l - trp1key[ci1] - trp1key[ci2]))
 
print(*buf)

E - Nearest String

問題

  • N 個の文字列 S_0,S_1,...,S_{N-1} がある
  • Q 個の文字列 T_0,T_1,...,T_{Q-1} がある
  • T_i について、以下の操作を繰り返して S のいずれかに一致させる時の、最小コストを求めよ
    • 先頭または末尾から1文字削除する。コスト X かかる
    • 末尾に好きな文字を1文字追加する。コスト Y かかる
  • 1 \le N,Q \le 10^5
  • 1 \le X,Y \le 10^9
  • S の全文字数の合計、T の全文字数の合計はそれぞれ 5 \times 10^5 を越えない

解法

最悪、T_i を全て消してから S_i に一致するように1文字ずつ追加すれば任意の文字列を作れるのだが、 何文字か共通文字列を残すことで、コストが削減できそう。

その場合、削除は両方からできるが、追加は末尾にしかできないことに注意。

つまり、T_i を何文字か残して S_i に一致させられるとしたら、残す文字列は、そのまま S_i の先頭部分に一致している必要がある。

Si  abcdefg
Ti  xyzabcdhij
(Tiに、Siの先頭4文字abcdが存在している)
→  abcd を残す  →  abcdefg  にする

Si  abcdefg
Ti  xyzbcdefghij
(Tiに、Siの先頭からの共通部分は1文字も存在しない)
→  bcdefgは一致しているが使えない。全て消して、1文字ずつ追加するしか無い

ある T_i の中に、複数の文字列 S_0,S_1,... のいずれかが出現するかどうかを効率的に調べるには、Aho-Corasick法が使えるが、 その過程で、本問題を解くのに必要な「複数の文字列のprefixがどこまで出現するか」も調べることができる。

trie木のノードには、以下の情報を持たせる。まずは S_0,S_1,... からこれを構築する。

  • 末尾に1文字くっつけたノードへの各リンク(trie木なら当然存在)
  • 最長suffixノードへのリンク(Aho-Corasick法なら当然存在)
  • 自身の表す文字列の長さ
  • 自身の表す文字列から末尾に付け足して S_i のいずれかに一致させる時、必要な最小残り文字数

T_i を1文字ずつ見る。注目中の文字を c とする。「現在ノード」を根からスタートして、通常のAho-Corasick法と同様、以下の要領でノードを移動していく。

  • 現在ノードから伸びる c のリンクがあれば、それを辿って次のノードに移動
  • なければ、c のリンクが存在するノードまでsuffixリンクを辿って戻り、c のリンクを辿ったノードに移動
  • 根まで辿っても無ければ、根とする

各局面で、今いるノードは「T_i の先頭~ c までの文字列のsuffix」と「S_i のいずれかのprefix」の、最も長く共通する文字列を示す。

S1  abcac
S2  cad
S3  dad
T   dabcad

Tをここまで見た  現在ノード  説明
                 (根)
d                d           Tの1文字目と S3の1文字目が一致
da               da          Tの1~2文字目と S3の1~2文字目が一致
dab              ab          S3とは一致しなくなった。Tの2~3文字目とS1の1~2文字目が一致
dabc             abc         Tの2~4文字目と S1の1~3文字目が一致
dabca            abca        Tの2~5文字目と S1の1~4文字目が一致
dabcad           cad         S1とは一致しなくなった。Tの4~6文字目とS2の1~3文字目が一致

各現在ノードで、「T_i から削除した際に残すのが自身だった時、必要なコスト」を計算する。

たとえば'abc'なら、削除する文字数は 「T_i の文字数 6 - 自身の文字数 3 = 3」、追加する文字数は、S_1 に一致させるのに残り2文字と事前計算しているので、2文字。 よって、3X+2Y がコストとなる。

これの最小値をとればいい……かというと、見落としているパターンがある。

X = 1,  Y = 100
S1  abcd
S2  abcabcdef
T   abcabcd

これの最小コストは、3文字削って S_1 に一致させる「3」である。 しかし、ノードを辿っていった時、S_2 に最後まで一致してしまうので、「abcd」を表すノードには訪れられず、abcdを残すパターンのコストは計算されない。

このノードをきちんと訪れるには、各現在ノードに付き「そこからsuffixリンクを根まで辿るまでに訪れるノードも、コスト計算する」とよい。

上記の例なら、'abcabcd' のsuffixリンクが 'abcd' に張られているので、そこでコスト計算することができる。

まとめる

しかしそうすると、今度は S によっては、毎回多くのノードを遡る必要が生じ、計算量的に間に合わなくなる。

ここで、コストの計算式をもう一度よく見ると、

  • (T_i の文字列長 - ノードの文字列長) \times X + ノードに追加する文字列長 \times Y
  • =T_i の文字列長 \times X + (ノードに追加する文字列長 \times Y - ノードの文字列長 \times X)

となり、「T_i の文字列長 \times X」 はノードに依存せず、後ろの括弧でくくった部分は T_i に依存しない。

k 番目のノードの、括弧でくくった部分の値を C_k で表すとすると、 ノードの優劣(そのノードの示す文字列を残した時にコストが低くなるかどうか)は T_i に関係なく C_k の比較だけで行える。

そこで、trie木構築の段階で、各ノードの“総合コスト”を、「自身と、自身からsuffixリンクを辿った中で、C_k の最小値」として事前計算しておく。

すると、T_i を1文字ずつ見る段階で毎回suffixを辿る必要は無くなり、訪れるノードの総合コストの比較のみで、最小コストを求めることができる。

実装メモ

trie木は、ノードをクラスや配列で表し、そこに子ノードやsuffixノードへのリンク、各種情報を格納することもできるが、 各情報をそれぞれ独立した配列でもって、indexで管理した方が速い。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import sys
from collections import deque
 
 
class AhoCorasick:
    def __init__(self, needles, x, y):
        self.INF = 10 ** 18
        self.children = [{}]
        self.depth = [0]
        self.append_cost = [self.INF]
 
        for needle in needles:
            self._register(needle)
 
        self.append_cost = [a * y - d * x for d, a in zip(self.depth, self.append_cost)]
 
        self.suffix, self.calc_cost = self._create_failure()
 
    def _register(self, needle):
        k = 0
        stack = [k]
        for i, c in enumerate(needle, start=1):
            if c in self.children[k]:
                k = self.children[k][c]
            else:
                j = len(self.children)
                self.children[k][c] = j
                self.children.append({})
                self.depth.append(i)
                self.append_cost.append(self.INF)
                k = j
            stack.append(k)
        stack.reverse()
        for d, k in enumerate(stack):
            if self.append_cost[k] > d:
                self.append_cost[k] = d
            else:
                break
 
    def _create_failure(self):
        children = self.children
        append_cost = self.append_cost
 
        suffix = [0] * len(children)
        min_cost = [0] * len(children)
 
        min_cost[0] = append_cost[0]
 
        q = deque()
        for k in children[0].values():
            suffix[k] = 0
            min_cost[k] = min(append_cost[k], min_cost[0])
            q.append(k)
 
        while q:
            k = q.popleft()
            for c, j in children[k].items():
                b = suffix[k]
                while True:
                    if c in children[b]:
                        suffix[j] = children[b][c]
                        min_cost[j] = min(append_cost[j], min_cost[suffix[j]])
                        break
                    if b == 0:
                        suffix[j] = 0
                        min_cost[j] = min(append_cost[j], min_cost[0])
                        break
                    b = suffix[b]
                q.append(j)
 
        return suffix, min_cost
 
    def search_cost(self, haystacks, x):
        children = self.children
        suffix = self.suffix
        calc_cost = self.calc_cost
 
        buf = []
        for haystack in haystacks:
            k = 0
            cost = calc_cost[0]
            for c in haystack:
                while True:
                    if c in children[k]:
                        k = children[k][c]
                        break
                    if k == 0:
                        break
                    k = suffix[k]
 
                cost = min(cost, calc_cost[k])
 
            buf.append(len(haystack.rstrip()) * x + cost)
 
        return buf
 
 
n, q, x, y = map(int, input().split())
lines = sys.stdin.readlines()
sss = [s.rstrip() for s in lines[:n]]
ahc = AhoCorasick(sss, x, y)
 
print(*ahc.search_cost(lines[n:], x))

programming_algorithm/contest_history/atcoder/2019/0929_jsc2019_final.txt · 最終更新: 2019/10/16 by ikatakos
CC Attribution 4.0 International
Driven by DokuWiki Recent changes RSS feed Valid CSS Valid XHTML 1.0