kumilog.net

データ分析やプログラミングの話などを書いています。

ダイクストラ法

はじめに

最短経路を求めるアルゴリズムであるダイクストラ法をPythonで実装して、例題を問いてみます。

ダイクストラ法

ダイクストラ法は最短経路を効率的に求めるアルゴリズムで、辺の重みが非負のときに使うことができます。詳細は以下の記事を参考にしてください。

計算量

優先度付きキューを使うことで、O((V+E)logV)で求めることができます*1

コード

Pythonで優先度付きキューはheapqを用いて実装できます。

from collections import defaultdict
from heapq import heappop, heappush


class Graph(object):
    """
    隣接リストによる有向グラフ
    """

    def __init__(self):
        self.graph = defaultdict(list)

    def __len__(self):
        return len(self.graph)

    def add_edge(self, src, dst, weight=1):
        self.graph[src].append((dst, weight))

    def get_nodes(self):
        return self.graph.keys()


class Dijkstra(object):
    """
    ダイクストラ法(二分ヒープ)による最短経路探索
    計算量: O((E+V)logV)
    """

    def __init__(self, graph, start):
        self.g = graph.graph

        # startノードからの最短距離
        # startノードは0, それ以外は無限大で初期化
        self.dist = defaultdict(lambda: float('inf'))
        self.dist[start] = 0

        # 最短経路での1つ前のノード
        self.prev = defaultdict(lambda: None)

        # startノードをキューに入れる
        self.Q = []
        heappush(self.Q, (self.dist[start], start))

        while self.Q:
            # 優先度(距離)が最小であるキューを取り出す
            dist_u, u = heappop(self.Q)
            if self.dist[u] < dist_u:
                continue
            for v, weight in self.g[u]:
                alt = dist_u + weight
                if self.dist[v] > alt:
                    self.dist[v] = alt
                    self.prev[v] = u
                    heappush(self.Q, (alt, v))

    def shortest_distance(self, goal):
        """
        startノードからgoalノードまでの最短距離
        """
        return self.dist[goal]

    def shortest_path(self, goal):
        """
        startノードからgoalノードまでの最短経路
        """
        path = []
        node = goal
        while node is not None:
            path.append(node)
            node = self.prev[node]
        return path[::-1]

実行例

f:id:xkumiyu:20180708195858p:plain

出典: ダイクストラ法(最短経路問題)

こちらのグラフのノード sからノード gまでの最短経路を求めてみます。

# (src, dst, weight)
inputs = [(0, 1, 5), (0, 2, 4), (0, 3, 2), (1, 2, 2), (1, 5, 6), (2, 3, 3),
         (2, 4, 2), (3, 4, 6), (4, 5, 4)]

g = Graph()
for src, dst, weight in inputs:
    g.add_edge(src, dst, weight)
    g.add_edge(dst, src, weight)

d = Dijkstra(g, 0)
print('最短経路: {}'.format(d.shortest_path(5)))
print('最短距離: {}'.format(d.shortest_distance(5)))
最短経路: [0, 2, 4, 5]
最短距離: 10

例題

問題

kenkoooo さんはすぬけ国での旅行の計画を立てています。 すぬけ国には n個の都市があり、m本の電車が走っています。 都市には 1から nの番号がつけられていて、 i番目の電車は都市uiとviの間を両方向に走っています。 どの都市からどの都市へも電車を乗り継ぐことで到達できます。

すぬけ国で使える通貨には、円とスヌークの2種類があります。 どの電車の運賃も円とスヌークのどちらの通貨でも支払え、i番目の電車の運賃は、 円で支払う場合ai円、 スヌークで払う場合biスヌークです。

両替所のある都市に行くと、1円を1スヌークに両替することができます。 ただし、 両替をするときには持っている円すべてをスヌークに両替しなければなりません。 つまり、kenkoooo さんの所持金がX円であるときに両替をすると、 kenkoooo さんの所持金はXスヌークになります。 現在、両替所はn個の都市すべてに存在しますが、i番目の都市の両替所は今年からi年後に閉鎖されてしまい、i年後とそれ以降は使うことができません。

kenkoooo さんは 10**15円を持って都市sから旅に出て、 都市tへ向かおうと思っています。 移動中、kenkoooo さんは両替所のある都市のいずれかで円をスヌークに両替しようと考えています。 ただし、都市sまたは都市tの両替所で両替をしてもよいものとします。

kenkoooo さんは移動の経路と両替をする都市を適切に選ぶことで、できるだけ多くのスヌークを持っている状態で 都市tに辿り着きたいと考えています。 i=0,...,n−1のそれぞれについて、i年後に都市 s から都市 t へ移動した際に kenkoooo さんが所持しているスヌークの最大額を求めてください。 ただし、旅行中に年をまたぐことは無いとします。

出典: D - Saving Snuuk

解法

両替を行う都市を  i として、都市  s から都市  i までの最小コスト(円)をダイクストラ法で求め、同様に都市  i から都市  t までの最小コスト(スヌーク)をダイクストラ法で求めます。

n, m, s, t = map(int, input().split())
y = 10**15

yen_g = Graph()
sno_g = Graph()
for _ in range(m):
    u, v, a, b = map(int, input().split())
    yen_g.add_edge(u, v, a)
    yen_g.add_edge(v, u, a)
    sno_g.add_edge(u, v, b)
    sno_g.add_edge(v, u, b)

cost = [0] * n
yen_d = Dijkstra(yen_g, s)
sno_d = Dijkstra(sno_g, t)
for i in range(1, n + 1):
    cost[i - 1] = y - (yen_d.dist[i] + sno_d.dist[i])

ans = [y] * n
ans[n - 1] = cost[n - 1]
for i in range(n - 2, -1, -1):
    ans[i] = max(ans[i + 1], cost[i])

for a in ans:
    print(a)

*1:Vはノード数、Eはエッジ数