并查集主要实现两个功能:
- 检测两个元素是否在同一个集合;
- 合并两个元素所在的集合。
本次实现使用的优化方法为路径压缩:在查找集合头的过程中,把查找路径上经过的每个元素直接挂到集合头上,使它们到集合头的距离变为 1。
合并时还会把较小的集合并入较大的集合,避免集合链路过长。
# -*- coding=utf-8 -*-
class UnionFind:
    """Disjoint-set (union-find) over arbitrary hashable elements.

    Supports two operations:

    - ``isSameSet``: test whether two elements belong to the same set;
    - ``union``: merge the sets that contain two elements.

    Head lookups compress the path they walk, and ``union`` attaches the
    smaller set under the head of the larger one, so parent chains stay
    short.

    Internal invariants:

    - ``_fatherMap[x]`` is the parent of ``x``; a head (root) is its own
      parent.
    - ``_sizeMap`` holds exactly one entry per head: the number of elements
      in that head's set. Path compression never changes it.
    """

    def __init__(self):
        self._fatherMap = dict()
        self._sizeMap = dict()

    def makeSets(self, nodes):
        """Reset the structure so every element is alone in its own set.

        Args:
            nodes (iterable): hashable elements forming the universe.
        """
        self._fatherMap.clear()
        self._sizeMap.clear()
        for node in nodes:
            self._fatherMap[node] = node
            self._sizeMap[node] = 1

    def findHead(self, node):
        """Return the head of the set containing ``node`` (recursive).

        Every element visited on the way up is re-parented directly to the
        head (path compression). The recursion depth equals the length of
        the parent chain above ``node``.

        Args:
            node: element to look up.

        Returns:
            The head element of the set.

        Raises:
            KeyError: if ``node`` is not part of the universe.
        """
        if not self.isValidNode(node):
            raise KeyError('node not in set')
        father = self._fatherMap[node]
        if father != node:
            father = self.findHead(father)
            self._fatherMap[node] = father
        return father

    def findHeadNotIter(self, node):
        """Return the head of the set containing ``node`` (non-recursive).

        Same contract as ``findHead``, implemented with an explicit list
        instead of recursion.

        Args:
            node: element to look up.

        Returns:
            The head element of the set.

        Raises:
            KeyError: if ``node`` is not part of the universe.
        """
        if not self.isValidNode(node):
            raise KeyError('node not in set')
        path = []
        father = self._fatherMap[node]
        # Climb until we reach the element that is its own parent.
        while father != node:
            path.append(node)
            node = father
            father = self._fatherMap[node]
        # Compress: point every visited element straight at the head.
        # Set sizes are untouched because membership does not change.
        for visited in path:
            self._fatherMap[visited] = father
        return father

    def isSameSet(self, n1, n2):
        """Tell whether two elements currently belong to the same set.

        Args:
            n1: first element.
            n2: second element.

        Returns:
            bool: True when both elements share the same head.

        Raises:
            KeyError: if either element is not part of the universe.
        """
        return self.findHead(n1) == self.findHead(n2)

    def isValidNode(self, node):
        """Return True when ``node`` was registered through ``makeSets``."""
        return node in self._fatherMap

    def union(self, n1, n2):
        """Merge the sets containing ``n1`` and ``n2``.

        The smaller set is attached under the head of the larger one (on a
        tie, the set of ``n1`` goes under the set of ``n2``). Nothing
        happens when both elements are already in the same set.

        Args:
            n1: first element.
            n2: second element.

        Raises:
            KeyError: if either element is not part of the universe.
        """
        # Validity is membership only: falsy elements such as 0 or "" are
        # legitimate as long as they were registered.
        if not self.isValidNode(n1) or not self.isValidNode(n2):
            raise KeyError('node not valid')
        head1 = self.findHead(n1)
        head2 = self.findHead(n2)
        if head1 == head2:
            return
        if self._sizeMap[head1] <= self._sizeMap[head2]:
            small, big = head1, head2
        else:
            small, big = head2, head1
        self._fatherMap[small] = big
        # The absorbed head is no longer a root, so its size entry is
        # folded into the surviving head and dropped.
        self._sizeMap[big] += self._sizeMap.pop(small)
def main():
    """Demonstrate UnionFind on six integers.

    Prints the internal parent and size maps right after initialisation
    and again after each batch of unions.
    """
    uf = UnionFind()
    uf.makeSets([1, 2, 3, 4, 5, 6])
    # Each batch is applied in order; the state is dumped after every batch.
    # The leading empty batch dumps the freshly initialised state.
    batches = (
        (),
        ((1, 2), (3, 4)),
        ((5, 1),),
    )
    for batch in batches:
        for left, right in batch:
            uf.union(left, right)
        print(uf._fatherMap)
        print(uf._sizeMap)


# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()