使用 Python 实现并查集

2018.10.22

并查集主要实现两个功能:

  1. 检测两个元素是否在同一个集合;
  2. 合并两个元素所在集合;

本次实现使用的优化方法为路径压缩,路径压缩的主要原理就是,在合并过程中将每个集合头和其他集合元素之间的距离调整为1

# -*- coding=utf-8 -*-


class UnionFind():
    """并查集实现
    """

    def __init__(self):
        self._fatherMap = dict()
        self._sizeMap = dict()

    def makeSets(self, nodes):
        """利用可迭代对象初始化并查集

        Args:
            nodes (iterable): 元素集合
        """

        self._fatherMap.clear()
        self._sizeMap.clear()
        for node in nodes:
            self._fatherMap[node] = node
            self._sizeMap[node] = 1

    def findHead(self, node):
        """找集合头,其中有一个对长链的集合的平整操作,
        平整操作执行的频率其实还需要进一步规整

        Raises:
            KeyError: 如果本身传入的 node 就是不在全集里的就直接抛错
        """

        if not self.isValidNode(node):
            raise KeyError('node not in set')

        father = self._fatherMap.get(node)
        if father != node:
            father = self.findHead(father)
        self._fatherMap[node] = father
        return father

    def findHeadNotIter(self, node):
        """上面函数的非递归版
        """

        if not self.isValidNode(node):
            raise KeyError('node not in set')

        stack = []
        father = self._fatherMap.get(node)
        while father != node:
            stack.append(node)
        while stack:
            temp = stack.pop()
            self._sizeMap[self._fatherMap[temp]] -= 1
            self._fatherMap[temp] = father
            self._sizeMap[father] += 1
        return father

    def isSameSet(self, n1, n2):
        """检测两个元素是否在一个集合里

        Args:
            n1 (nodeType): 元素
            n2 (nodeType): 元素

        Returns:
            bool: True or False
        """

        return self.findHead(n1) == self.findHead(n2)

    def isValidNode(self, node):
        return node in self._fatherMap

    def union(self, n1, n2):
        """合并两个元素所在集合,小的集合并入到大的集合里,目的还是不想让集合链路过长
        """

        if not n1 or not n2:
            raise KeyError('node not valid')

        if not self.isValidNode(n1) or not self.isValidNode(n2):
            raise KeyError('node not valid')

        head1 = self.findHead(n1)
        head2 = self.findHead(n2)
        if head1 != head2:
            size1 = self._sizeMap[head1]
            size2 = self._sizeMap[head2]
            if size1 <= size2:
                self._sizeMap[self._fatherMap[head1]] -= 1
                self._fatherMap[head1] = head2
                self._sizeMap[head2] += size1
            else:
                self._sizeMap[self._fatherMap[head2]] -= 1
                self._fatherMap[head2] = head1
                self._sizeMap[head1] += size2


def main():
    sets = [1, 2, 3, 4, 5, 6]
    ex = UnionFind()
    ex.makeSets(sets)
    print(ex._fatherMap)
    print(ex._sizeMap)
    ex.union(1, 2)
    ex.union(3, 4)
    print(ex._fatherMap)
    print(ex._sizeMap)
    ex.union(5, 1)
    print(ex._fatherMap)
    print(ex._sizeMap)


if __name__ == "__main__":
    main()