AGC019 の C 問題を python で解いてみる。

本記事では先日開催された AtCoder Grand Contest 019 の C 問題の解説を試みます.
ちなみに私は本番中には解けず, 解説や他の人のコードを参考にさせていただきました.

問題はこちら
agc019.contest.atcoder.jp


スタート地点からゴール地点までの最短ルートの長さを求める問題ですね. 例からはなんとなく,
1. 最短ルートはスタート地点とゴール地点を向かい合う角とする長方形の内側(と周辺)を通る
2. ゴールが北東にある場合は, 南, 西方向に進むことはない
3. なるべく噴水に沿って90度回転する
ということが読み取れそうです. また,
4. 直線ルートよりも遠くなってしまうにも関わらず噴水に沿って180度回転する場合もある
ということも言えそうです.
これらの検証はやらないことにして, コアになるコードを作っていきます.
詳しくロジックを理解したい方は,
https://atcoder.jp/img/agc019/editorial.pdf
を参考にしてください.

まずは入力編

import math 
x1, y1, x2, y2= map(int, input().split())
N = int(input())
W = abs(x2 - x1)
H = abs(y2 - y1)
xsgn = 2*(x2>x1) - 1
ysgn = 2*(y2>y1) - 1
XY = []
for i in range(N):
    x, y=map(int, input().split())
    if (x-x1)*(x-x2)<=0 and (y-y1)*(y-y2)<=0:
        XY+=[[xsgn*(x-x1), ysgn*(y-y1)]]

こんな感じでやってみました.
スタート地点が原点, ゴール地点が第一象限にくるように整形しています.
ついでに, 通り得る噴水として考察1. の条件を満たすものだけをリストに加えます, それ以外は通らないので.

次に計算編

short = 20 - 5 * math.pi
long = 10 * math.pi - 20
straight = 100*(W + H)
XY.sort()
Y = [y for x,y in XY]
 
fountain = LIS(Y)

short は噴水に沿って90度回転した場合, 直線ルートと比べてどれだけ短くなるか
long は噴水に沿って180度回転した場合, 直線ルートと比べてどれだけ長くなるか
を表します. straight は噴水が一切ない場合の最短ルートの長さを表しています.
考察2. により, 最短ルートにおいては噴水 A を通った後に噴水 B を通るためには B の x, y 座標はともにAのそれより大きくなくてはならないという性質があります. 一方で, なるべく多くの噴水に沿って90度で回転するのが最短ルートの候補ですから, その性質をみたすような噴水の組み合わせの中で, 最も噴水の数が多いものを見つけようということになります.
上記の性質において噴水は全順序にはならないため, x 座標について昇順に並べ替え, そのリストの中で y 座標について昇順になり得る個数はどれだけあるかということを考えてます.
また, ここでLISとは前回記事で作った, リストに渡すと最長増加部分列を返す関数です.
つまり, 今の場合, 最長増加部分列の長さが最短ルートにおいて通る噴水の数となっています.

あとは出力編

if fountain < min(W, H)+1:
    print(straight - short * fountain)
else:
    print(straight - short * (fountain - 1) + long)

実は場合分けが必要です. 同じ行, 列には高々1つしか噴水がないことから, 最短ルートにおいて
 \text{噴水の数}\leqq \min{(|x_2-x_1|, |y_2-y_1|)}+1
が成り立っており, 特に等号が成立する時は1度は180度回転する必要があります.
逆にそうでない場合はそれらの噴水を90度曲がるように選ぶことが出来, これが最短であることはわかるでしょう.

前回記事
wakabame.hatenablog.com
で紹介した最長増加部分列を O(n\log{n}) で求めるアルゴリズムとあわせて,

import math


def LIS(L):
    from bisect import bisect_left

    seq = []
    for i in L:
        pos = bisect_left(seq, i)
        if len(seq) <= pos:
            seq.append(i)
        else:
            seq[pos] = i
    return len(seq)


x1, y1, x2, y2 = map(int, input().split())
N = int(input())
W = abs(x2 - x1)
H = abs(y2 - y1)
xsgn = 2 * (x2 > x1) - 1
ysgn = 2 * (y2 > y1) - 1
XY = []
for i in range(N):
    x, y = map(int, input().split())
    if (x - x1) * (x - x2) <= 0 and (y - y1) * (y - y2) <= 0:
        XY += [[xsgn * (x - x1), ysgn * (y - y1)]]
short = 20 - 5 * math.pi
long = 10 * math.pi - 20
straight = 100 * (W + H)

XY.sort()
Y = [y for x, y in XY]

fountain = LIS(Y)

if fountain < min(W, H) + 1:
    print(straight - short * fountain)
else:
    print(straight - short * (fountain - 1) + long)

を submit すればめでたく AC となります.
Submission #23288743 - AtCoder Grand Contest 019

また, 最長増加部分列を O(n^2) で求めてしまうと時間内に解き終わることができません.
Submission #1548728 - AtCoder Grand Contest 019