kumilog.net

データ分析やプログラミングの話などを書いています。

Union Find

はじめに

素集合データ構造を表すUnion FindアルゴリズムをPythonで実装して、例題を問いてみます。

追記

  • 例題に ABC 049 D を追加しました。(2018-07-28)
  • グループの要素数を数える機能と、例題に ABC 120 D を追加しました。(2019-04-28)

Union Find

Union Findは、すべての要素がいずれかのグループに属しており、ある要素が属するグループを管理するものです。

例えば、以下のような4つの要素があり、要素2と3は同じグループだが、要素2と要素4は別々のグループ、といったようなことを管理します。

- elem1: groupA
- elem2: groupB
- elem3: groupB
- elem4: groupC

これを、木を用いて

  • 2つのグループの結合
  • 要素が属するグループを探索

を行うアルゴリズムがUnion Findです。(グループの分割はできません。)

詳細については、以下のスライドがとても分かりやすいです。

www.slideshare.net

Pythonで実装

まずは素直に実装を行います。

初期値として、すべての要素が異なるグループ(要素番号=グループ番号)に属するようにします。すべての要素が木の根になります。

class UnionFind(object):
    def __init__(self, n=1):
        self.par = [i for i in range(n)]

再帰的に木の根を求めることで、ある要素が属するグループを探索します。

    def find(self, x):
        if self.par[x] == x:
            return x
        else:
            return self.find(self.par[x])

2つの要素が属するグループを結合します。xの子ノードとしてyを連結します。

    def union(self, x, y):
        x = self.find(x)
        y = self.find(y)
        if x != y:
            self.par[y] = x

2つの要素が同じグループに属するか否かは、findを用いると簡単に求まります。

    def is_same(self, x, y):
        return self.find(x) == self.find(y)

経路圧縮

木の経路を圧縮することで計算量を削減することができます。

findで再帰的に探索するときに、根に直接つなぎます。

    def find(self, x):
        # ...    
            self.par[x] = self.find(self.par[x])
            return self.par[x]

グループを結合するときに、小さい木に結合するようにします。木の高さを表すrankは初期値0にしておきます。

    def union(self, x, y):
        # ...    
        if x != y:
            if self.rank[x] < self.rank[y]:
                x, y = y, x
            if self.rank[x] == self.rank[y]:
                self.rank[x] += 1
            self.par[y] = x

グループの要素数

ノードを結合するときに、そのノードが属するグループの要素数も結合させることで、グループの要素数を数えることができます。

    def __init__(self, n=1):
        # ...
        self.size = [1 for _ in range(n)]

    def union(self, x, y):
        # ...
        if x != y:
            # ...
            self.size[x] += self.size[y]

    def get_size(self, x):
        """
        x が属するグループの要素数
        """
        x = self.find(x)
        return self.size[x]

最終的なコード

まとめると以下のようなコードになります。

class UnionFind(object):
    def __init__(self, n=1):
        self.par = [i for i in range(n)]
        self.rank = [0 for _ in range(n)]
        self.size = [1 for _ in range(n)]

    def find(self, x):
        """
        x が属するグループを探索
        """
        if self.par[x] == x:
            return x
        else:
            self.par[x] = self.find(self.par[x])
            return self.par[x]

    def union(self, x, y):
        """
        x と y のグループを結合
        """
        x = self.find(x)
        y = self.find(y)
        if x != y:
            if self.rank[x] < self.rank[y]:
                x, y = y, x
            if self.rank[x] == self.rank[y]:
                self.rank[x] += 1
            self.par[y] = x
            self.size[x] += self.size[y]

    def is_same(self, x, y):
        """
        x と y が同じグループか否か
        """
        return self.find(x) == self.find(y)

    def get_size(self, x):
        """
        x が属するグループの要素数
        """
        x = self.find(x)
        return self.size[x]

例題

ATC 001 B - Union Find

問題

N 頂点の、単純とは限らない無向グラフを考えます。 初期状態では、頂点のみが存在し、辺は全く存在せず、全ての頂点が孤立しているとします。 以下の 2 種類のクエリが、Q 回与えられます。

  • 連結クエリ: 頂点 A と、頂点 B を繋ぐ辺を追加します。
  • 判定クエリ: 頂点 A と、頂点 B が、連結であるかどうか判定します。連結であれば Yes、そうでなければ No を出力します。

クエリを順番に処理し、判定クエリへの回答を出力して下さい。 この際、同じ辺が何度も追加されることや、自分自身への辺が追加されることもある事に注意してください。

連結であるとは、頂点 A から頂点 B まで辺をたどって到達可能であることを意味します。 A と B が同じ頂点の場合、連結であると定義します。 グラフは無向であるため、連結クエリによって頂点 A,B 間の辺が追加されると、A から B へも B から A へも辿れるようになります。

出典: B: Union Find - AtCoder Typical Contest 001 | AtCoder

解法

Union Findそのまま使うと解くことができます。

N, Q = map(int, input().split())

uf = UnionFind(N)
for _ in range(Q):
    P, A, B = map(int, input().split())
    # 連結クエリ
    if P == 0:
        uf.union(A, B)
    # 判定クエリ
    else:
        if uf.is_same(A, B):
            print('Yes')
        else:
            print('No')

ABC 049 D - 連結 / Connectivity

問題

 N 個の都市があり、 K 本の道路と  L 本の鉄道が都市の間に伸びています。  i 番目の道路は  p_i 番目と  q_i 番目の都市を双方向に結び、  i 番目の鉄道は  r_i 番目と  s_i 番目の都市を双方向に結びます。 異なる道路が同じ 2 つの都市を結ぶことはありません。同様に、異なる鉄道が同じ 2 つの都市を結ぶことはありません。

ある都市から別の都市に何本かの道路を通って到達できるとき、それらの都市は道路で連結しているとします。また、すべての都市はそれ自身と道路で連結しているとみなします。鉄道についても同様に定めます。

全ての都市について、その都市と道路・鉄道のどちらでも連結している都市の数を求めてください。

出典: D - 連結 / Connectivity

解法

ある都市  i から道路で連結している都市をUnionFindを用いて求めます。同じ根であるとき連結している都市となります。同様に、鉄道で連結している都市もUnionFindを利用して求めます。

道路・鉄道のどちらでも連結している都市の数は、根のペアが同じである数となるので、ペアのカウントを行います。

from collections import defaultdict
 
N, K, L = map(int, input().split())

# 道路のUnionFind
uf1 = UnionFind(N)
for i in range(K):
    p, q = map(int, input().split())
    uf1.union(p - 1, q - 1)
 
# 鉄道のUnionFind
uf2 = UnionFind(N)
for i in range(L):
    r, s = map(int, input().split())
    uf2.union(r - 1, s - 1)
 
# 道路と鉄道の連結の組み合わせ(根のペアのカウント)
c = defaultdict(int)
for i in range(N):
    c[(uf1.find(i), uf2.find(i))] += 1

# 回答を出力
for i in range(N):
    print(c[(uf1.find(i), uf2.find(i))])

ABC 120 D - Decayed Bridges

解法

N, M = map(int, input().split())
A, B = [None] * M, [None] * M
for i in range(M):
    A[i], B[i] = map(int, input().split())

uf = UnionFind(N)
ans = [None] * M
ans[M - 1] = N * (N - 1) // 2
for i in range(M - 1, 0, -1):
    x, y = A[i] - 1, B[i] - 1
    if uf.is_same(x, y):
        ans[i - 1] = ans[i]
    else:
        ans[i - 1] = ans[i] - (uf.get_size(x) * uf.get_size(y))
    uf.union(x, y)
for a in ans:
    print(a)