はじめに
素集合データ構造を表すUnion FindアルゴリズムをPythonで実装して、例題を問いてみます。
追記
- 例題に ABC 049 D を追加しました。(2018-07-28)
- グループの要素数を数える機能と、例題に ABC 120 D を追加しました。(2019-04-28)
Union Find
Union Findは、すべての要素がいずれかのグループに属しており、ある要素が属するグループを管理するものです。
例えば、以下のような4つの要素があり、要素2と3は同じグループだが、要素2と要素4は別々のグループ、といったようなことを管理します。
- elem1: groupA - elem2: groupB - elem3: groupB - elem4: groupC
これを、木を用いて
- 2つのグループの結合
- 要素が属するグループを探索
を行うアルゴリズムがUnion Findです。(グループの分割はできません。)
詳細については、以下のスライドがとても分かりやすいです。
www.slideshare.net
Pythonで実装
まずは素直に実装を行います。
初期値として、すべての要素が異なるグループ(要素番号=グループ番号)に属するようにします。すべての要素が木の根になります。
class UnionFind(object): def __init__(self, n=1): self.par = [i for i in range(n)]
再帰的に木の根を求めることで、ある要素が属するグループを探索します。
def find(self, x): if self.par[x] == x: return x else: return self.find(self.par[x])
2つの要素が属するグループを結合します。x
の子ノードとしてy
を連結します。
def union(self, x, y): x = self.find(x) y = self.find(y) if x != y: self.par[y] = x
2つの要素が同じグループに属するか否かは、find
を用いると簡単に求まります。
def is_same(self, x, y): return self.find(x) == self.find(y)
経路圧縮
木の経路を圧縮することで計算量を削減することができます。
find
で再帰的に探索するときに、根に直接つなぎます。
def find(self, x): # ... self.par[x] = self.find(self.par[x]) return self.par[x]
グループを結合するときに、小さい木に結合するようにします。木の高さを表すrank
は初期値0にしておきます。
def union(self, x, y): # ... if x != y: if self.rank[x] < self.rank[y]: x, y = y, x if self.rank[x] == self.rank[y]: self.rank[x] += 1 self.par[y] = x
グループの要素数
ノードを結合するときに、そのノードが属するグループの要素数も結合させることで、グループの要素数を数えることができます。
def __init__(self, n=1): # ... self.size = [1 for _ in range(n)] def union(self, x, y): # ... if x != y: # ... self.size[x] += self.size[y] def get_size(self, x): """ x が属するグループの要素数 """ x = self.find(x) return self.size[x]
最終的なコード
まとめると以下のようなコードになります。
class UnionFind(object): def __init__(self, n=1): self.par = [i for i in range(n)] self.rank = [0 for _ in range(n)] self.size = [1 for _ in range(n)] def find(self, x): """ x が属するグループを探索 """ if self.par[x] == x: return x else: self.par[x] = self.find(self.par[x]) return self.par[x] def union(self, x, y): """ x と y のグループを結合 """ x = self.find(x) y = self.find(y) if x != y: if self.rank[x] < self.rank[y]: x, y = y, x if self.rank[x] == self.rank[y]: self.rank[x] += 1 self.par[y] = x self.size[x] += self.size[y] def is_same(self, x, y): """ x と y が同じグループか否か """ return self.find(x) == self.find(y) def get_size(self, x): """ x が属するグループの要素数 """ x = self.find(x) return self.size[x]
例題
ATC 001 B - Union Find
問題
N 頂点の、単純とは限らない無向グラフを考えます。 初期状態では、頂点のみが存在し、辺は全く存在せず、全ての頂点が孤立しているとします。 以下の 2 種類のクエリが、Q 回与えられます。
- 連結クエリ: 頂点 A と、頂点 B を繋ぐ辺を追加します。
- 判定クエリ: 頂点 A と、頂点 B が、連結であるかどうか判定します。連結であれば
Yes
、そうでなければNo
を出力します。クエリを順番に処理し、判定クエリへの回答を出力して下さい。 この際、同じ辺が何度も追加されることや、自分自身への辺が追加されることもある事に注意してください。
連結であるとは、頂点 A から頂点 B まで辺をたどって到達可能であることを意味します。 A と B が同じ頂点の場合、連結であると定義します。 グラフは無向であるため、連結クエリによって頂点 A,B 間の辺が追加されると、A から B へも B から A へも辿れるようになります。
出典: B: Union Find - AtCoder Typical Contest 001 | AtCoder
解法
Union Findそのまま使うと解くことができます。
N, Q = map(int, input().split()) uf = UnionFind(N) for _ in range(Q): P, A, B = map(int, input().split()) # 連結クエリ if P == 0: uf.union(A, B) # 判定クエリ else: if uf.is_same(A, B): print('Yes') else: print('No')
ABC 049 D - 連結 / Connectivity
問題
個の都市があり、 本の道路と 本の鉄道が都市の間に伸びています。 番目の道路は 番目と 番目の都市を双方向に結び、 番目の鉄道は 番目と 番目の都市を双方向に結びます。 異なる道路が同じ 2 つの都市を結ぶことはありません。同様に、異なる鉄道が同じ 2 つの都市を結ぶことはありません。
ある都市から別の都市に何本かの道路を通って到達できるとき、それらの都市は道路で連結しているとします。また、すべての都市はそれ自身と道路で連結しているとみなします。鉄道についても同様に定めます。
全ての都市について、その都市と道路・鉄道のどちらでも連結している都市の数を求めてください。
解法
ある都市 から道路で連結している都市をUnionFindを用いて求めます。同じ根であるとき連結している都市となります。同様に、鉄道で連結している都市もUnionFindを利用して求めます。
道路・鉄道のどちらでも連結している都市の数は、根のペアが同じである数となるので、ペアのカウントを行います。
from collections import defaultdict N, K, L = map(int, input().split()) # 道路のUnionFind uf1 = UnionFind(N) for i in range(K): p, q = map(int, input().split()) uf1.union(p - 1, q - 1) # 鉄道のUnionFind uf2 = UnionFind(N) for i in range(L): r, s = map(int, input().split()) uf2.union(r - 1, s - 1) # 道路と鉄道の連結の組み合わせ(根のペアのカウント) c = defaultdict(int) for i in range(N): c[(uf1.find(i), uf2.find(i))] += 1 # 回答を出力 for i in range(N): print(c[(uf1.find(i), uf2.find(i))])
ABC 120 D - Decayed Bridges
解法
N, M = map(int, input().split()) A, B = [None] * M, [None] * M for i in range(M): A[i], B[i] = map(int, input().split()) uf = UnionFind(N) ans = [None] * M ans[M - 1] = N * (N - 1) // 2 for i in range(M - 1, 0, -1): x, y = A[i] - 1, B[i] - 1 if uf.is_same(x, y): ans[i - 1] = ans[i] else: ans[i - 1] = ans[i] - (uf.get_size(x) * uf.get_size(y)) uf.union(x, y) for a in ans: print(a)