ある問題が解けるかどうかは、直近に似た解法の問題を解いたかどうかに大きく依存する
数字1個1個が小さいと、たとえば(1,2,6)と(4,5)のように「3個選んだ結果と2個選んだ結果が一致してしまう」ことがあるが、今回の場合は考えなくていい。
10^100は長すぎるので、便宜的に10^7で表す 3個の数の和の例 10000001 + 10000002 + 10000006 = 30000009 これを2個の数字で作ろうと思えば、 10000??? + 10000??? = 30000009 ↓ ??? + ??? = 10000009 ???の部分は0~N(≦10^5)の数だが、その範囲の2つの数を足して、10^100以上の数は作れない。 他の個数で考えても、0~10^5の数を1~10^5個足し合わせることによって、10^100は作れない。
なので、選ぶ個数を固定して、それ毎に結果を足し合わせればよい。
$k$ 個の数を選ぶとき、各項に含まれる $10^{100}$ は除いて考えても和の個数を求める上では問題ない。
数字は1刻みで、特に選び方に制約も無いので、最小値~最大値の間の整数は全て作れる。
# 累積和解法 from itertools import accumulate n, k = map(int, input().split()) MOD = 10 ** 9 + 7 f_acc = [0] + list(accumulate(list(range(n + 1)))) b_acc = [0] + list(accumulate(reversed(list(range(n + 1))))) ans = 0 for i in range(k, n + 2): min_ = f_acc[i] max_ = b_acc[i] ans = (ans + max_ - min_ + 1) % MOD print(ans)
活発度の大きい子をなるべく移動させたいので、大きい順に、左詰または右詰で遠い方に置いていくのがよさそう。
しかし、サンプル2などでは、
初期 5 5 6 1 1 1 結果 並び① 1 1 1 5 5 6 → 84 並び② 6 1 1 1 5 5 → 85
となり、最大である“6”は右詰にした方が移動距離が多いにもかかわらず、左詰にして2つの“5”の移動距離を1ずつ伸ばした方が、結果的に答えは大きくなる。
貪欲は無理そうなので、DPを使う。
戦略としては当初の想定通り、大きい方から、今置ける限り左詰、右詰のどちらかに置いていく、というのが最適である。(証明は解説pdf参照。この証明ができなかった)
ただ、後続の活発度や初期位置によって、左右どちらを選べばよいかが分からないので、両方試した結果をDPで管理していきましょう、ということ。
活発度が大きい方から処理するとして、$k$ 番目の子供のスコアは「今、左に何個、右に何個置いたか」を持っておけば、 既に置いた具体的な数字の中身がどうであれ、関係なく計算することができる。
def solve(n, aaa): dp = [0] aaa_with_idx = list(zip(aaa, range(n))) aaa_with_idx.sort(reverse=True) for k in range(1, n + 1): ndp = [0] * (k + 1) a, i = aaa_with_idx[k - 1] for l in range(k): ndp[l + 1] = max(ndp[l + 1], dp[l] + a * abs(i - l)) r = n - (k - l) ndp[l] = max(ndp[l], dp[l] + a * abs(i - r)) dp = ndp return max(dp) n = int(input()) aaa = list(map(int, input().split())) print(solve(n, aaa))
何をすればよいかは分かったが、計算量を $O(N^2)$ から減らす方法がわからなかった。
まず、こういった問題は、全体から条件を満たさないものを引くことで考えるとやりやすい。
こんな木があって、色が●である頂点に着目するとき、 ○--●-, ,-○ ○--●--○ ○--○-' `-○ ●を消して、森に分割すると、 ○ ,-○ ○ ○ ○--○-' `-○ 同じ連結成分の中でのパスは必ず●は通らず、 異なる連結成分(または●)間のパスは必ず●を通る。
同じ連結成分の中での2頂点ペアの個数は、そのサイズを $s$ として、$\dfrac{s(s+1)}{2}$ 個なので、各連結成分毎にこれを全体 $\dfrac{N(N+1)}{2}$ から引けばよい。
で、どのように数え挙げるか。
1色($c$ とする)だけを考えればよいのであれば、適当な頂点 $r$ を根として、以下のような木DPを行うことが思いつく。
すると、DFSで葉から順にこれは求まっていき、
節の頂点 $v$ の子が $u_1,u_2,...$ とすると、
根 $r$ まで遡って $DP[r]$ が求まったら $r$ を含む連結成分は上記で考慮できてないので、最後に $\dfrac{DP[r](DP[r]+1)}{2}$ を引く。
これで色 $c$ についての答えが求まった。
しかしこれには1色で $O(N)$ かかり、全ての色ごとにやったり、$DP[v][c]$ のようにDPの次元を増やすと、$O(N^2)$ となり間に合わない。
$DP[v][c]$ と次元を増やすのは増やすが、計算量を $O(N \log{N})$ 程度に抑える。
ボトルネックとなる場所は、$v$ において子 $u_1,u_2,...$ の情報をマージする箇所である。
単純にDPの次元を増やした場合の擬似コード:(以下をvごとに行う必要がある) for u in (u1, u2, ...): | for c in 1..N: | (1) 各子・各色について結果をマージ DP[v][c] += DP[u][c] | for c in 1..N: | (2) 各色についてv自身の頂点数を加算 DP[v][c] += 1 | DP[v][color[v]] = 0
(1)の処理について、木の末端に近いところでは色の種類数もそこまで多くないので、部分木以下に出てくる必要な色だけを持たせたい。
そのため、DPの2次元目($c$ の部分)は辞書で持ち、必要な色だけを管理する。
辞書などの集合をマージする際、サイズの大きい方をベースに小さい方を加えることで計算量がならし $O(N \log{N})$ に抑えられる(「マージテク」というらしい)。
しかしそれを使っても、(2)で $color[v]$ 以外の色について、頂点 $v$ の分を1ずつ加算する処理は避けられない。
これは、DPで管理する情報を「$v$ から色 $c$ を通らずに行ける頂点数」にしているから毎回全ての色について更新の必要があるのであって、 2つに分けて「部分木の頂点数」と「色 $c$ を通らずには行けない頂点数」として持てばよい。
$旧DP[v][c]=size[v]-新DP[v][c]$ となる。こうすれば size[v]
に対してのみ+1すればよくなる。
擬似コード: for u in (u1, u2, ...): if DP[v].size < DP[u].size: # マージテクによる入れ替え swap(DP[v], DP[u]) for c in DP[u].keys(): # 辞書で持ち、必要な色のみ更新 DP[v][c] += DP[u][c] size[v] += size[u] size[v] += 1 DP[v][color[v]] = size[v]
import sys from collections import defaultdict sys.setrecursionlimit(10 ** 6) def dfs(v, p, ccc, links, ans): ret_colors = defaultdict(int) ret_count = 1 cv = ccc[v] for u in links[v]: if u == p: continue sub_colors, sub_count = dfs(u, v, ccc, links, ans) cc = sub_count - sub_colors[cv] ans[cv] -= cc * (cc + 1) // 2 ret_count += sub_count if len(ret_colors) < len(sub_colors): ret_colors, sub_colors = sub_colors, ret_colors for c, cnt in sub_colors.items(): ret_colors[c] += cnt ret_colors[cv] = ret_count return ret_colors, ret_count def solve(n, ccc, links): if n == 1: return [1] all_pair = n * (n + 1) // 2 ans = [all_pair] * (n + 1) colors, count = dfs(0, -1, ccc, links, ans) assert count == n for c in range(1, n + 1): cc = n - colors[c] ans[c] -= cc * (cc + 1) // 2 return ans[1:] n, *cab = map(int, sys.stdin.buffer.read().split()) ccc = cab[:n] links = [set() for _ in range(n)] for a, b in zip(cab[n + 0::2], cab[n + 1::2]): a -= 1 b -= 1 links[a].add(b) links[b].add(a) print('\n'.join(map(str, solve(n, ccc, links))))