ARC 090
C - Candies
問題
- 2列 $N$ 行のグリッドにアメが落ちている
- 左上から右下に、右と下のみに移動してアメを拾う
- 拾うアメを最大化せよ
解法
下への移動は1回のみなので、どこで下に移動するかを全探索。$i$ の位置で移動するとすると、
┌─┬─┬ ┬──┬─┬──┬ ┬──┬─┐ │a1│a2│ ... │ai-1│ai│ │ ... │ │ │ ├─┼─┼ ┼──┼─┼──┼ ┼──┼─┤ │ │ │ ... │ │bi│bi+1│ ... │bn-1│bn│ └─┴─┴ ┴──┴─┴──┴ ┴──┴─┘ = ┌─┬─┬ ┬──┬─┬──┬ ┬──┬─┐ │a1│a2│ ... │ai-1│ai│ │ ... │ │ │ ├─┼─┼ ┼──┼─┼──┼ ┼──┼─┤ │b1│b2│ ... │bi-1│bi│bi+1│ ... │bn-1│bn│=SUM(b) └─┴─┴ ┴──┴─┴──┴ ┴──┴─┘ ひく ┌─┬─┬ ┬──┬─┬──┬ ┬──┬─┐ │ │ │ ... │ │ │ │ ... │ │ │ ├─┼─┼ ┼──┼─┼──┼ ┼──┼─┤ │b1│b2│ ... │bi-1│ │ │ ... │ │ │ └─┴─┴ ┴──┴─┴──┴ ┴──┴─┘
この考え方で、$i$ で下に移動する時に得られる個数を、累積和で楽に求められる。$i$ を1~Nまで動かして最大値。
from itertools import accumulate
n = int(input())
a1 = list(accumulate(map(int, input().split())))
a2 = list(accumulate(map(int, input().split())))
s2 = a2[-1]
ans = a1[0]
for i in range(1, n):
ans = max(ans, a1[i] - a2[i - 1])
print(ans + s2)
D - People on a Line
問題
- $x$ 軸上の $0 \le x \le 10^9$ の区間に $N$ 人が立っている。
- 各人が立っている位置の $x$ 座標は整数。
- 「$R_i$ 番目の人は $L_i$ 番目の人より $D_i$ 右にいる」という条件が $M$ 個ある
- 条件が全て正しい並べ方はあるか判定せよ
- $1 \le N \le 10^5$
- $1 \le D_i \le 10^4$
解法
$L_i$ と $R_i$ の間に双方向に距離 $D_i$ の辺を張ったグラフとして考える。
グラフの各連結成分につき、適当に基準となる人を決めて仮位置 $x_{L_i}=0$ とし、辺が繋がる限り探索して
- まだ位置が未確定の人なら $x_{L_i}$ からの相対位置を確定
- 位置が確定済みの人なら $x_{L_i}$ からの相対位置に矛盾が無いか確認
を繰り返す。
連結成分内の座標の最大値と最小値の差が $10^9$ 以下なら、全体を適当にずらすことで条件を満たすように出来る。 (※ただし、よく見たら問題の制約条件で、$N=10^5$ 人が全て $D_i=10^4$ の距離で並んでいても最大の差は $10^9$ なので、不要な確認だった)
Pythonでは時間的に速くとも1秒くらいかかるので、考え方自体はあっていてもデータの持たせ方で下手なことをするとTLEになってしまう。
import sys
sys.setrecursionlimit(100000)
def check(i, xs, checked):
xi = xs[i]
children = set()
for j, d in links[i]:
if checked[j]:
continue
if j not in xs:
xs[j] = xi + d
elif xi + d != xs[j]:
return False
children.add(j)
checked[i] = True
for j in children:
if not check(j, xs, checked):
return False
return True
def solve():
checked = [False] * n
for i in range(n):
if not checked[i]:
xs = {}
xs[i] = 0
res = check(i, xs, checked)
if not res:
return False
mn, mx = min(xs.values()), max(xs.values())
if mx - mn > 1e9:
return False
return True
n, m = map(int, input().split())
links = [set() for _ in range(n)]
for _ in range(m):
l, r, d = map(int, input().split())
l -= 1
r -= 1
links[l].add((r, d))
links[r].add((l, -d))
print('Yes' if solve() else 'No')
E - Avoiding Collision
問題
- $N$ 頂点 $M$ 辺の重み付き双方向グラフ
- 辺 $(U_i,V_i,D_i)$ は、$U_i \rightarrow V_i$、$V_i \rightarrow U_i$ いずれの方向に通るにも $D_i$ 分かかる
- 2点 $S,T$ が指定される
- 青木君は頂点 $S$ にいて $T$ に最短時間で移動を開始
- 高橋君は頂点 $T$ にいて $S$ に最短時間で移動を開始
- 2人がぶつからない経路の組み合わせは何通りあるか、$\mod 10^9+7$ で答えよ
方針
Pythonだと計算量的にきつい。scipyとか駆使して少し可能性がある、というところか。 scipyの受け入れる形にデータを揃えるのがまず時間かかるし、この問題では経路探索の途中経過で同時に計算した方がいい要素があるので、scipy.sparse.csgraph は不向き。
- ポイント
- 双方向ダイクストラ
- 辺を広げる条件、終了条件の理解
- 入力の受け取り方
解法としては、双方向ダイクストラがいいと思う。ただ、双方向ダイクストラの途中の各段階で、キューに残っているノードと確定してないノードはどういうものか、というのを意識しないと、混乱する。
また、普段余り意識することは無かったが、入力が数十万行に及ぶ問題では、受け取り方で計算時間がかなり変わってくることも知った。
普段、1行ずつ読み込む時は、input()を使うが、これが数十万とか繰り返される場合は、sys.stdin.readlines()をイテレートして1行ずつ取得した方が1.5倍くらい速い。他の処理がどうしても遅くなるpythonでは、この差は結構でかい。あと少しでTLEが解消できない時は、ここを変えるのが良い。
具体的な処理手順としては、
visited_fwd[v]: ノード $v$ の順方向探索における $S$ からの距離visited_bwd[v]: ノード $v$ の逆方向探索における $T$ からの距離patterns_fwd[v], patterns_bwd[v]: 順逆方向探索それぞれにおけるノード $v$ までの最短経路のパターン数collision_nodes, collision_links: 衝突する「可能性のある」ノードとリンク
以上のデータ構造を用意し、双方向探索。ただし、
- 優先キューに保存する情報は、以下の5つ
- ノード $v$ までのコスト(順方向なら$S$から、逆方向なら$T$から)
- $v$ の1つ前のノード $a$ までのコスト
- ノード$v$
- ノード$a$
- 順方向探索か逆方向探索か
- 優先度の比較に使うのは「$v$までのコスト」のみでよい
- 2段階に分けて行う
- 1つめは、リンクを広げつつ、最短距離を確定する
- 2つめは、リンクは広げず、段階1でキューに残るノードから衝突の可能性のある箇所を調べる
- 第1段階は、探索中のノード$v$で、反対方向の探索で既に探索済みのものが出てきたら終了
- 最短距離は順方向からの距離+逆方向からの距離で確定
- 順逆方向距離が等しければ衝突ノード候補にvを、異なれば衝突リンク候補に(a,v)を加える
- この時点で、キューには衝突する可能性のあるノード、リンクは全て入っている
- 第2段階では、最短距離の半分と、各ノード $v,a$ への到達時間を比べて、衝突候補を求めていく
次に、数え上げを行う。
shortest_count: 最短経路数collision_count: 衝突する経路の組み合わせ数
- 衝突候補経路につき、
- 本当に最短経路か、コストの整合性をチェック。不適なら飛ばす
- shortest_countに、その経路を通る経路数を加算
- collision_countに、その経路を通る経路数の2乗を加算
全ての衝突経路について数え上げ、最終的に(shortest_count$^2$-collision_count)が、衝突しない経路の組み合わせ数となる。
反省点は、経路探索中にノードを広げる条件を誤ったため、TLEを連発してしまった。
- 最短距離を求めるだけなら探索済みのノードは問答無用で飛ばせばいい
- 経路数を求める際は、探索済みであっても距離が等しければ経路数の数え上げのために飛ばしてはいけない
- だがノードを広げる処理は、最初の1回目に訪れた時のみにしないと、同じノードが繰り返しキューに積まれることになる
Dijkstraのアルゴリズムは汎用性が高いが、ノードを広げる条件に注意しないと、無駄な計算をさせてしまう。
from heapq import heappop, heappush
import sys
MOD, INF = 1000000007, float('inf')
def solve(s, t, links):
q = [(0, 0, s, -1, 0), (0, 0, t, -1, 1)]
visited_fwd, visited_bwd = [INF] * n, [INF] * n
patterns_fwd, patterns_bwd = [0] * (n + 1), [0] * (n + 1)
patterns_fwd[-1] = patterns_bwd[-1] = 1
collision_nodes, collision_links = set(), set()
limit = 0
while q:
cost, cost_a, v, a, is_bwd = heappop(q)
if is_bwd:
visited_self = visited_bwd
visited_opp = visited_fwd
patterns_self = patterns_bwd
else:
visited_self = visited_fwd
visited_opp = visited_bwd
patterns_self = patterns_fwd
relax_flag = False
cost_preceding = visited_self[v]
if cost_preceding == INF:
visited_self[v] = cost
relax_flag = True
elif cost > cost_preceding:
continue
patterns_self[v] += patterns_self[a]
if relax_flag:
cost_opp = visited_opp[v]
if cost_opp != INF:
limit = cost + cost_opp
if cost == cost_opp:
collision_nodes.add(v)
else:
collision_links.add((v, a) if is_bwd else (a, v))
break
for u, du in links[v].items():
nc = cost + du
if visited_self[u] < nc:
continue
heappush(q, (nc, cost, u, v, is_bwd))
collision_time = limit / 2
while q:
cost, cost_a, v, a, is_bwd = heappop(q)
if cost > limit:
break
visited_self = visited_bwd if is_bwd else visited_fwd
if visited_self[v] == INF:
visited_self[v] = cost
if is_bwd:
if cost == collision_time:
patterns_bwd[v] += patterns_bwd[a]
continue
if cost_a == collision_time:
collision_nodes.add(a)
elif cost == collision_time:
collision_nodes.add(v)
patterns_fwd[v] += patterns_fwd[a]
else:
collision_links.add((a, v))
shortest_count = 0
collision_count = 0
for v in collision_nodes:
if visited_fwd[v] == visited_bwd[v]:
r = patterns_fwd[v] * patterns_bwd[v]
shortest_count += r
shortest_count %= MOD
collision_count += r * r
collision_count %= MOD
for u, v in collision_links:
if visited_fwd[u] + visited_bwd[v] + links[u][v] == limit:
r = patterns_fwd[u] * patterns_bwd[v]
shortest_count += r
shortest_count %= MOD
collision_count += r * r
collision_count %= MOD
return (shortest_count ** 2 - collision_count) % MOD
n, m = map(int, input().split())
s, t = map(int, input().split())
s -= 1
t -= 1
links = [{} for _ in range(n)]
for uvd in sys.stdin.readlines():
u, v, d = map(int, uvd.split())
# for _ in range(m):
# u, v, d = map(int, input().split())
u -= 1
v -= 1
links[u][v] = d
links[v][u] = d
print(solve(s, t, links))

