zenpachi blog

競技プログラミングでよく使いそうな処理をいくつかまとめておく

このエントリーをはてなブックマークに追加

最近初めて競技プログラミング(AtCoder)をやってみたのですがなかなか難しいですね。
時間内に解けない、もしくは計算量をなかなか削減できずTLEになってしまうことが多かったので、いくつか問題をやってみた中でよく使いそうだなというものをまとめておきます。
自分用の学習メモなので目新しさなどはありません。

使用言語はPython3.4.3です。

$ python --version
Python 3.4.3

二分探索 / Binary search

ソート済の配列に対する探索アルゴリズムです。これは知ってはいるものの毎回ググってる気がします。
配列の中央の値を見ることで探索したい値が配列の右側にあるか左側にあるかを判定し、探索が必要ない半分を切り捨てていきます。
それを繰り返すことで範囲が毎回半分になっていくので計算量 O(logN)O(\log N) で探索できます。

配列に対してある数nを配列のソート順を崩さないように挿入するためのindexを返す関数を実装してみます。同一の数が配列に複数含まれていた場合は最も小さいindexを返します。

whileループを使う

def binary_search(sorted_list, n):
    min, max = 0, len(sorted_list) - 1
    while min <= max:
        mid = min + (max - min) // 2
        if sorted_list[mid] == n:
            if sorted_list[mid - 1] == n:
                max = mid
            else:
                return mid
        if sorted_list[mid] < n:
            min = mid + 1
        else:
            max = mid - 1
    return min

bisectを使う

pythonの場合bisectを使用すれば簡単にできます。

import bisect

sorted_list = [1, 1, 2, 3, 5, 5, 8, 13, 21]
print(bisect.bisect_left(sorted_list, 5)) # 4

bisect_rightを使用すると探索したい数の次のindexを返します。

import bisect

sorted_list = [1, 1, 2, 3, 5, 5, 8, 13, 21]
print(bisect.bisect_right(sorted_list, 5)) # 6

AtCoderの例題

AtCoder Beginner Contest 143 D - Triangles

全パターンの組み合わせを計算すると計算量O(N3)O(N^ 3)となり間に合いませんが、1本目と2本目の組み合わせに対して3本目を二分探索で探すようにすると、計算量O(N2logN)O(N^ 2 \log N)で答えを求めることができます。

優先度付きキュー / Priority queue

雑な理解ですが優先度付きキューは抽象データ型のひとつで、各要素に優先度がついており要素の取り出しや追加を、通常の配列などよりも少ない計算量で行うことができるようです。
優先度付きキューの実装としてはヒープがよく使われるようで、Pythonにもheapqという実装があるのでコードを抜粋して読んでみます。

cpython/heapq.py at 3.7 · python/cpython · GitHub (こちらは3.7のコードですが)

最初のコメントを読むと、ヒープは全てのkについて、a[k] <= a[2*k+1]およびa[k] <= a[2*k+2]が成り立つ配列で、a[0]が最小となりますとあります。 つまり

        0
     /     \
   1         2 
 /   \     /   \
3     4   5     6

のような木構造で子よりも親の方が常に小さいということですね。 これにより最小値を取り出す場合はO(1)O(1)ですし、要素の追加と削除はO(logN)O(\log N)で可能です。 実際のコードは以下

def heapify(x):
    """Transform list into a heap, in-place, in O(len(x)) time."""
    n = len(x)
    # Transform bottom-up.  The largest index there's any point to looking at
    # is the largest with a child index in-range, so must have 2*i + 1 < n,
    # or i < (n-1)/2.  If n is even = 2*j, this is (2*j-1)/2 = j-1/2 so
    # j-1 is the largest, which is n//2 - 1.  If n is odd = 2*j+1, this is
    # (2*j+1-1)/2 = j so j-1 is the largest, and that's again n//2-1.
    for i in reversed(range(n//2)):
        _siftup(x, i)


def _siftdown(heap, startpos, pos):
    newitem = heap[pos]
    # Follow the path to the root, moving parents down until finding a place
    # newitem fits.
    while pos > startpos:
        parentpos = (pos - 1) >> 1
        parent = heap[parentpos]
        if newitem < parent:
            heap[pos] = parent
            pos = parentpos
            continue
        break
    heap[pos] = newitem


def _siftup(heap, pos):
    endpos = len(heap)
    startpos = pos
    newitem = heap[pos]
    # Bubble up the smaller child until hitting a leaf.
    childpos = 2*pos + 1    # leftmost child position
    while childpos < endpos:
        # Set childpos to index of smaller child.
        rightpos = childpos + 1
        if rightpos < endpos and not heap[childpos] < heap[rightpos]:
            childpos = rightpos
        # Move the smaller child up.
        heap[pos] = heap[childpos]
        pos = childpos
        childpos = 2*pos + 1
    # The leaf at pos is empty now.  Put newitem there, and bubble it up
    # to its final resting place (by sifting its parents down).
    heap[pos] = newitem
    _siftdown(heap, startpos, pos)

上記の木構造おける、子要素を持つ各要素a[k]に対して、shiftupとshiftdownという処理を行っています。 shiftupはa[2*k+1]およびa[2*k+2](つまりa[k]に対する子要素)のうち小さい方をa[k]と入れ替え、それを木構造の末端に達するまで繰り返し行います。 shiftdownはshiftupで達した末端から、今度は親に向かって遡りながら子が親よりも小さかった場合に親と子を入れ替える処理を繰り返します。 これにより全てのkについて、a[k] <= a[2*k+1]およびa[k] <= a[2*k+2]が成り立つことになります。

使用してみるとこんな感じ

A = [5, 7, 1, 4, 3, 3, 2]
heapify(A)
print(A) # [1, 3, 2, 4, 7, 3, 5]

#         1
#       /   \
#     3       2
#   /  \    /  \
#  4    7  3    5

要素を追加する場合は、heapの最後に追加してshiftdownを行うことで計算量O(logN)O(\log N)で処理が完了します。 実際の実装は以下

def heappush(heap, item):
    """Push item onto heap, maintaining the heap invariant."""
    heap.append(item)
    _siftdown(heap, 0, len(heap)-1)

AtCoderの例題

AtCoder Beginner Contest 141 D - Powerful Discount Tickets

割引券は常にそのときの最も値段が高い品物に使用するべきですので、割引券を一枚使用するたびにその時点での最も値段が高い品物を判定していけばいいのですが、最大値の取得を愚直に計算すると計算量O(N)O(N)となってしまうため結果を得るためにはO(NM)O(NM)となり、TLEとなることが予想されます。 そこで優先度付きキューを使用することで計算量を削減して処理を行います。

逆元 / Inverse element

以下の記事が非常にわかりやすかったです。 記事を読めば全て事足りますし、そちらの方がよほど正確かつ詳細かとは思いましたが一応学習メモとして自分なりに理解した部分をまとめます。
「1000000007 で割ったあまり」の求め方を総特集! 〜 逆元から離散対数まで 〜 - Qiita

組み合わせなどの解が膨大な数になり得るような場合、「109+710^ 9+7で割った余りを求めよ」のようになっているケースがよくあります。 そのような場合、解が足し算や掛け算の場合であればオーバーフローや桁数が増加しすぎることによる計算時間の増大をを避けるために、一度計算を行うたびにあまりを取っていけばいいでしょう。

仮に13で割ったあまりを求める問題の場合、

res = 100 * 100 * 100 % 13 # 1

res = 100 * (100 * 100 % 13) % 13 # 1

にするようなイメージです。 しかし割算の場合は途中で計算を入れることができません

500 / (5 * 10) % 13 # 10.0
500 / (5 * 10 % 13) % 13 # 6.454545454545453

そこで、割算を行う際には「逆元」を計算する必要があるようです。
たとえば「9 / 3 を 13で割った余り」は 9÷3(mod13)9 ÷ 3 \pmod{13} と表せ、変形すると 9×(1÷3)(mod13)9 × (1 ÷ 3) \pmod{13} となります。

この (1÷3)(1 ÷ 3)mod13\mod{13} の場合における3の逆元となり313^{-1}と表します。
逆元を xx とすると 31x(mod13)3^ {-1} \equiv x \pmod{13} 、変形して 3x1(mod13)3x \equiv 1 \pmod{13} となります。 つまり逆元の意味としては「3を掛けて1になる数(mod13\mod{13}において)」です。

フェルマーの小定理によると、

pp が素数、かつ bbpp で割り切れない整数の場合 aa を整数として
bxa(modp)bx \equiv a \pmod{p}
を満たすxが存在する

ということなので、3x1(mod13)3x \equiv 1 \pmod{13} を満たすxxが存在することがわかります。 mod13\mod{13}において3の逆元(313^ {-1})は9なので、

9 * 9 % 13   # 3 (9 / 3 % 13 と等しい)

となります。

逆元の求め方

フェルマーの小定理を使用する方法と拡張ユークリッドの互除法を使用する方法があるようです。
以下はこちらの記事で紹介されている拡張ユークリッドの互除法を使用したコード。
Pythonでモジュラ逆数を求める

def xgcd(a, b):
    x0, y0, x1, y1 = 1, 0, 0, 1
    while b != 0:
        q, a, b = a // b, b, a % b
        x0, x1 = x1, x0 - q * x1
        y0, y1 = y1, y0 - q * y1
    return a, x0, y0

def modinv(a, m):
    g, x, y = xgcd(a, m)
    if g != 1:
        raise Exception('modular inverse does not exist')
    else:
        return x % m


print(modinv(3, 13)) # 9
print(modinv(4, 13)) # 10
print(modinv(5, 13)) # 8
print(modinv(6, 13)) # 11
print(modinv(7, 13)) # 2
print(modinv(8, 13)) # 5
print(modinv(9, 13)) # 3

二項係数 nCr(mod1000000007)_{n}{\rm C}_{r} \pmod{1000000007}

nCr_{n}{\rm C}_{r}nCr=n!r!(nr)!=(n!)(r!)1((nr)!)1{}_{n}{\rm C}_{r} = \frac{n!}{r!(n-r)!} = (n!)(r!)^{-1}((n-r)!)^{-1} ですので r!r! の逆元と (nr)!(n-r)! の逆元が求められれば良いことになります。

実装はこちらの記事を参考にさせていただきました。
逆元テーブルを予め作成しておくことで高速に処理が行えるようです。
よくやる二項係数 (nCk mod. p)、逆元 (a^-1 mod. p) の求め方 - けんちょんの競プロ精進記録

mod = 10**9+7
fac = [1, 1]
finv = [1, 1]
inv = [0, 1]


def init(n):
    for i in range(2, n + 1):
        fac.append(fac[-1] * i % mod)
        inv.append(-inv[mod % i] * (mod // i) % mod)
        finv.append(finv[-1] * inv[-1] % mod)


def com(n, k, mod):
    if n < 0 or k < 0 or n < k:
        return 0
    return fac[n] * (finv[k] * finv[n - k] % mod) % mod

init(100)
print(com(100, 3, mod)) # 161700

AtCoderの例題

AtCoder Beginner Contest 145 D - Knight

連立方程式でどちらの移動方法を何回行ったのかまではすんなり求められるのですが、そのあとの組み合わせの計算でTLEになりました。
上記のような方法を使うことで計算量を削減し、TLEを避けることができます。

参考文献

このエントリーをはてなブックマークに追加