目次
AtCoder Beginner Contest 157 D,E,F 問題メモ
D - Friend Suggestions
問題
- SNSに $N$ 人が登録していて、番号 $1,...,N$ が振られている
- SNSでは、「友達関係」「ブロック関係」の2種類の関係がある
- 友達関係は $M$ 個あり、$i$ 番目では人 $a_i$ と $b_i$ が友達関係である
- ブロック関係は $K$ 個あり、$i$ 番目では人 $c_i$ と $d_i$ がブロック関係である
- 同じ2人が友達関係でもブロック関係でもあるような入力は与えられない
- ある2人が「友達候補」とは、以下の条件に当てはまる関係を指す
- 直接の友達関係ではないが、友達関係を辿っていくと繋がっている
- ブロック関係ではない
- 各人 $1~N$ の、「友達候補」となる人数を求めよ
- $2 \le N \le 10^5$
- $0 \le M,K \le 10^5$
解法
まず、「直接の友達関係・ブロック関係ではない」条件を無視すると、友達を辿って繋がっている閉じた関係(連結成分)の中では、友達候補の人数は共通である。
この人数は、Union-Find木で管理できる。(根に、自身の連結成分の人数を持たせる)
そこから、問題文の条件通りの友達候補を求める。人 $i$ の友達候補を求めるには、以下のようにすればよい。
- $i$ の属する連結成分の人数 - 1(自分) - 直接の友達 - 連結成分の中でブロック関係にある人数
まず直接の友達を引く。これは $i$ の次数を数えればよい。つまり、$M$ 個の友達関係の中で $i$ が出てきた回数を引けばよい。
次にブロック関係だが、たとえブロック関係にあっても友達を介して繋がっていない人はそもそも引かれる数に含まれてないので、 あくまで「連結成分の中で」ブロック関係にある人数を求めなければならない。
これは $i$ とブロック関係にある $j$ のそれぞれについて、構築済みのUnion-Findで $i$ と $j$ が同一連結成分に属するかを調べればよい。
計算量は、Union-Findの構築、ブロック関係の調査でそれぞれ1回あたり $O(\alpha(N))$、合計で $O((M+K) \alpha(N))$ となる。
import sys class UnionFind: def __init__(self, n): self.table = [-1] * n def _root(self, x): if self.table[x] < 0: return x else: self.table[x] = self._root(self.table[x]) return self.table[x] def count(self, x): return -self.table[self._root(x)] def find(self, x, y): return self._root(x) == self._root(y) def union(self, x, y): r1 = self._root(x) r2 = self._root(y) if r1 == r2: return d1 = self.table[r1] d2 = self.table[r2] if d1 <= d2: self.table[r2] = r1 self.table[r1] += d2 else: self.table[r1] = r2 self.table[r2] += d1 n, m, k = map(int, sys.stdin.buffer.readline().split()) abcd = list(map(int, sys.stdin.buffer.read().split())) uft = UnionFind(n) friends = [0 for _ in range(n)] blocks = [set() for _ in range(n)] for a, b in zip(abcd[0:2 * m:2], abcd[1:2 * m:2]): a -= 1 b -= 1 uft.union(a, b) friends[a] += 1 friends[b] += 1 for c, d in zip(abcd[2 * m::2], abcd[2 * m + 1::2]): c -= 1 d -= 1 blocks[c].add(d) blocks[d].add(c) ans = [] for i in range(n): ans.append(uft.count(i) - 1 - friends[i] - sum(uft.find(i, j) for j in blocks[i])) print(*ans)
E - Simple String Queries
問題
- 長さ $N$ の英小文字からなる文字列 $S$ がある
- 以下の $Q$ 個のクエリに答えよ
- クエリは2種類からなる
- $1, i_q, c_q$
- $S$ の $i_q$ 番目の文字を $c_q$ に変更する
- $2, l_q, r_q$
- $S$ の $l_q~r_q$ 文字(両端含む)にある文字の種類数を出力する
- $1 \le N \le 200000$
- $1 \le Q \le 50000$
解法
やることは見えやすいけど、実装で方向を間違えた。
まず本番中に実装したものを下記に書き、その後により高速……もとい想定解通りの解法を概略する。
セグメント木に、使っている文字の種類とその回数を保持して更新していけばよい。
{c1: v1, c2: v2, …}
: 自身の管理する区間上に、文字 $c_1$ が $v_1$ 回、$c_2$ が $v_2$ 回…出現する
更新するときは、末端を見れば $i_q$ 番目の元の文字が何だったか分かる($d_q$ とする)ので、自身と全ての親の $c_q$ を+1、$d_q$ を-1する。
それによって $d_q$ が0になれば、削除しておく。
出現文字数を $k$ として、1クエリの更新、取得をそれぞれ $O(k \log{N})$ でおこなえるので、全体で $O(kQ\log{N})$ となり、単純に制約を代入すると約 $10^7$ となる。
更新の簡便さを考えるとノードに持たせるのはdefaultdict()がよかったのだが、いかんせんセグ木に乗せる最大約100万個のオブジェクトを作るのが重たい。
defaultdict() では、キーが無かったときにデフォルト値を生成する関数を設定するのだが、
整数を設定したいときは int
よりも lambda: 0
の方が若干速いという情報をどこかで耳にし、それが習慣づいていた。
しかし、今回のように大量のdefaultdictを生成するときは、その書き方では毎回異なる関数が生成・紐付けられ、無駄が増える。
× defaultdict(lambda: 0) ○ defaultdict(int) ○ lmd0 = lambda: 0 defaultdict(lmd0)
上記の改善により、ギリギリ間に合うようになる。
import sys from collections import defaultdict class SegTree: """ 今回の問題に特化したセグ木 """ def __init__(self, n, s): """ :param n: 文字列長 :param s: 初期文字列 """ n2 = 1 << (n - 1).bit_length() self.offset = n2 self.tree = [None] * n2 + [{} for _ in range(n2)] lmd0 = lambda: 0 for i, c in enumerate(s): x = ord(c) - 97 i += self.offset self.tree[i][x] = 1 for i in range(self.offset - 1, 0, -1): sti = self.tree[i] = defaultdict(lmd0, self.tree[i << 1]) stj = self.tree[(i << 1) + 1] for x in stj: sti[x] += stj[x] # print(*(dict(t) for t in self.tree), sep='\n') def update(self, i, x): """ i番目の値をxに変更 :param i: index(0-indexed) :param x: update value """ i += self.offset if x in self.tree[i]: return y, _ = self.tree[i].popitem() self.tree[i][x] = 1 i >>= 1 while i > 0: sti = self.tree[i] sti[x] += 1 if sti[y] == 1: del sti[y] else: sti[y] -= 1 i >>= 1 def get_types(self, a, b): """ [a, b)の文字種数を得る :param a: index(0-indexed) :param b: index(0-indexed) """ result = set() l = a + self.offset r = b + self.offset while l < r: if r & 1: result.update(self.tree[r - 1].keys()) if l & 1: result.update(self.tree[l].keys()) l += 1 l >>= 1 r >>= 1 return len(result) n = int(input()) s = input() st = SegTree(n, s) q = int(input()) ans = [] for line in sys.stdin: i, j, k = line.rstrip().split() if i == '1': j = int(j) - 1 k = ord(k) - 97 st.update(j, k) else: l = int(j) - 1 r = int(k) ans.append(st.get_types(l, r)) print('\n'.join(map(str, ans)))
より高速な解法1
上記ではセグ木に辞書を乗せて出現回数まで記憶したが、単純に「ある区間に文字 $c$ が存在するか」の26個のフラグでも十分に管理できる。
うん、言われれば確かに。 何故か文字が無くなったかどうかを判定するのに、出現回数が必要と思ってしまった。
それならノードに持たせるのは整数(bitフラグ)1つで済み、マージも bitwise_or 1本でできるので、十分に間に合う。
反省として、セグ木に何の情報を持たせなければいけないかを考えるときに、「参照できる情報の範囲」を勘違いしてしまうのをよくやる。
[ * ] [ ][ * ] [ ][ ][ * ][ ] [ ][ ][ ][ ][ ][*][ ][ ]
一番下の[*]を更新したら、更新する親ノードは“*”をつけた箇所になるが、 「更新の情報源となる」ノードもこれだけだと思い込んでしまうため、上記のような勘違いをする。
[ * ] [ o ][ * ] [ ][ ][ * ][ o ] [ ][ ][ ][ ][o][*][ ][ ]
実際には、親を更新する際には兄弟ノード“o”の情報も使っていいため、それを使えばより簡単な情報で更新が可能になるはず。
より高速な解法2
文字種ごとにBinary Indexed Treeを作ってやると、ある文字を区間 $[l, r)$ で使っているかの判定は単純な $O(\log{N})$ クエリ2回で済むので、この方法でもよい。
計算量は表記上は変わらないが、保持するデータが単純になるため更新も速くなり、1本のセグ木上で無理くり辞書で管理するより高速になる。
より高速な解法3
平衡二分探索木を26個作ることでも解ける。
各文字について、最初、出現位置のindexを登録しておく。
変更クエリが来たら、変更前の文字の探索木からindexを削除、変更後の探索木に挿入。
取得クエリが来たら、各文字、lower_bound(l) と lower_bound(r+1) を計算し、同じなら区間にその文字は存在しない。
F - Yakiniku Optimization Problem
問題
- 鉄板上に肉が $N$ 枚ある
- $i$ 番目の肉は座標 $(x_i,y_i)$ にあり「焼けにくさ」は $c_i$
- 熱源が1個だけあり、好きな位置における
- 熱源を $(X,Y)$ に置くと、$i$ 番目の肉は $c_i \times (熱源と肉の直線距離)$ 秒後に焼ける
- うまく熱源を置いて少なくとも $K$ 枚の肉が焼けるまでの時間を最小化したいとき、その最小値を求めよ
- $1 \le K \le N \le 60$
- $−1000 \le x_i, y_i \le 1000$
- $1 \le ci \le 100$
- 許容誤差は、絶対または相対誤差のいずれかで $10^{-6}$ 以内
解法
最小包含円の問題と似てるようで似てないようで似ている。
少なくとも $K$ 個の点を包含する円の中心点に熱源を置く感じだが、肉ごとに重みが付くので、単純な中心ではない。
しかしそれでも考え方は最小包含円と同じく、条件を満たす熱源は以下のいずれかとなるはずである。
- ① $K$ 個中、最も時間がかかる2個の点を直径に持つ円の中心
- ② $K$ 個中、最も時間がかかる3個の点からの距離が等しい点
ただし、ここでの「距離」は、$c_i$ による重みを付けた上での距離を指し、「円の中心」もそれを考慮した意味とする。
この証明は、上のいずれでも無い場合、熱源を暫定最も時間がかかる点にわずかに近づけることで、必ず改善することから言える。
従って、以下の要領で調べていけばよい。
- 最も時間がかかると仮定する2点の組み合わせを全探索する($i,j$ とする)
- 2点を $c_j:c_i$ に内分する点が①に合致する
- ①に熱源を置き、$i,j$ への“距離”を $t$ とすると、$t$ 以内に収まる点が $K$ 個以上あるか調べる
- ある場合、暫定解と比較し、小さければ暫定解を更新する → continue(次の2点の組み合わせへ)
- ない場合、$t$ 以内に収まらなかった点を候補として新たに加える1点を全探索する($k$ とする)
- 3点からの“距離”が等しい点があるか調べる。ある場合、②と合致する
- ②に熱源を置き、3点への“距離”を $t$ とすると $t$ 以内に収まる点が $K$ 個以上あるか調べる
- ある場合、暫定解と比較・更新
計算量は、最悪、全ての3点の組み合わせが探索され、それぞれの内部でさらに $N$ 点との距離を計算するため、$O(N^4)$ となる。
3点からの“距離”が等しい点は、$c_i,c_j,c_k$ の関係によって計算が異なる。
$c_i = c_j = c_k$ の場合、中心点は、外心となる。
それ以外の場合、2点からのユークリッド距離の比が一定である点の集合は、アポロニウスの円 を描く。 アポロニウスの円は少なくとも2個描かれるので、その2円の交点が、3点からの“距離”が等しい点となる。
ただし、2円の位置と半径によっては、3点からの距離が等しい点が存在しないこともある。その場合はスキップしてよい。
import sys from itertools import combinations from math import atan2, cos, sin def apollonius(p1, c1, p2, c2): m1 = p1 + (p2 - p1) * c2 / (c1 + c2) m2 = p1 + (p2 - p1) * c2 / (c2 - c1) m = (m1 + m2) / 2 r = abs(m1 - m) return m, r def apollonius_intersections(p1, c1, p2, c2, p3, c3): m12, r12 = apollonius(p1, c1, p2, c2) m13, r13 = apollonius(p1, c1, p3, c3) v = m13 - m12 d = abs(v) if d > r12 + r13 or d < abs(r12 - r13): return None, None theta = atan2(v.imag, v.real) xx = (r12 ** 2 - r13 ** 2 + d ** 2) / (2 * d) s = (r12 + r13 + d) / 2 yy = 2 * (s * (s - r12) * (s - r13) * (s - d)) ** 0.5 / d st = sin(theta) ct = cos(theta) e1 = (xx * ct - yy * st) + (xx * st + yy * ct) * 1j + m12 e2 = (xx * ct + yy * st) + (xx * st - yy * ct) * 1j + m12 t1 = abs(e1 - p1) * c1 t2 = abs(e2 - p1) * c1 if t1 < t2: return e1, t1 return e2, t2 def get_circumscribed_center(p1, p2, p3): A, B, C = p2 - p3, p3 - p1, p1 - p2 A = (A * A.conjugate()).real B = (B * B.conjugate()).real C = (C * C.conjugate()).real T = A * (B + C - A) U = B * (C + A - B) W = C * (A + B - C) if T + U + W != 0: return (T * p1 + U * p2 + W * p3) / (T + U + W) else: # 直線上に並んでいるなど、外心が定義できない return None def solve(n, k, xyc): if k == 1: return 0 xxx = xyc[0::3] yyy = xyc[1::3] ccc = xyc[2::3] beefs = [(i, x + y * 1j, c) for i, (x, y, c) in enumerate(zip(xxx, yyy, ccc))] ans = 1e18 for (i1, p1, c1), (i2, p2, c2) in combinations(beefs, 2): # 肉2枚の比重を考慮した等距離点に熱源を置いたときの時間tを算出 # →t以内に焼ける肉がK以上あれば、その2枚を含む場合の最小時間はt px = p1 + (p2 - p1) * (c2 / (c1 + c2)) t = abs(px - p1) * c1 + 1e-8 ok = 0 ng = [] for i, p, c in beefs: if abs(px - p) * c <= t: ok += 1 else: ng.append(i) if ok >= k: ans = min(ans, t) continue # p1,p2の2点を含む場合の時間は、これ以上小さくならない if ans <= t: continue # K枚以上ない場合、範囲外のどれかの肉をギリギリ含めることを考える # (重複を防ぐため、j > i2 の肉を調べる) # 2点からの比が等しい点の軌跡は、c1==c2なら垂直二等分線、それ以外はアポロニウスの円 # p1-p2, p2-p3, p3-p1 の3円(または直線)が1点で交わる箇所があるか # ある → そこに熱源を置き、時間tを算出、t以内に焼ける肉がK以上あるか調べる # ない → 他の2点または3点で考えた方がよい for i3 in ng: if i3 < i2: continue _, p3, c3 = beefs[i3] if c1 == c2: if c2 == c3: e = get_circumscribed_center(p1, p2, p3) t = abs(e - p1) * c1 if e is not None else None else: e, t = apollonius_intersections(p3, c3, p1, c1, p2, c2) elif c1 == c3: e, t = apollonius_intersections(p2, c2, p1, c1, p3, c3) else: e, t = apollonius_intersections(p1, c1, p2, c2, p3, c3) if e is None: continue t += 1e-8 ok2 = 0 for i, p, c in beefs: if abs(e - p) * c <= t: ok2 += 1 if ok2 >= k: ans = min(ans, t) return ans n, k = map(int, input().split()) xyc = list(map(int, sys.stdin.read().split())) print(solve(n, k, xyc))
この問題を解く過程で調べたところ、Pythonの複素数において絶対値の2乗を算出したい場合は、
- 素のPythonで行う場合は共役複素数をかけるのが速いし、精度も良い
- PyPyで行う場合は、素直に実部と虚部の2乗の和が速い
- NumPyで行う場合は、素直に実部と虚部の2乗の和が速い
まぁ複素数でxy座標を表現するような問題でこれを計算したい状況というのは、もっと煩雑な座標計算を伴うはずで、この部分の改善だけで差が付くほどでは無いと思うが。
import timeit loop = 1000 setup = 'import random; a = [random.randrange(100) + random.randrange(100) * 1j for _ in range(10000)]', t1 = timeit.timeit('[c.real ** 2 + c.imag ** 2 for c in a]', setup=setup, number=loop) print(t1) t2 = timeit.timeit('[(c * c.conjugate()).real for c in a]', setup=setup, number=loop) print(t2) t3 = timeit.timeit('[abs(c) ** 2 for c in a]', setup=setup, number=loop) print(t3) # => Python # 素直に2乗 4.3393597 # 共役複素数 1.3621795999999993 # 絶対値2乗 3.0097560000000003 精度悪 # => PyPy # 0.12356209754943848 # 0.12764811515808105 # 0.3994109630584717
import timeit import random import numpy as np a = [random.randrange(100) + random.randrange(100) * 1j for _ in range(10000)] a = np.array(a, dtype=np.complex128) # 素直に2乗 def each(a): return a.real ** 2 + a.imag ** 2 # 共役複素数を書けて実部を取る def conj(a): return (a * a.conjugate()).real # 絶対値の2乗 def abs2(a): return np.abs(a) ** 2 loop = 10000 t1 = timeit.timeit('each(a)', globals=globals(), number=loop) print(t1) t2 = timeit.timeit('conj(a)', globals=globals(), number=loop) print(t2) t3 = timeit.timeit('abs2(a)', globals=globals(), number=loop) print(t3) # => # 素直に2乗 0.2455292 # 共役複素数 0.309558 # 絶対値2乗 0.8267248999999999