Skip to content

Latest commit

 

History

History
166 lines (131 loc) · 5.03 KB

union_find.md

File metadata and controls

166 lines (131 loc) · 5.03 KB

并查集

用于处理不相交集合 (disjoint sets) 合并及查找的问题,典型应用有连通分量检测,环路检测等。原理和复杂度分析等可以参考维基百科

class Solution:
    def findRedundantConnection(self, edges: List[List[int]]) -> List[int]:

        parent = list(range(len(edges) + 1))
        rank = [1] * (len(edges) + 1)
        
        def find(x):
            if parent[parent[x]] != parent[x]:
                parent[x] = find(parent[x]) # path compression
            return parent[x]
        
        def union(x, y):
            px, py = find(x), find(y)
            if px == py:
                return False
            # union by rank
            if rank[px] > rank[py]:
                parent[py] = px
            elif rank[px] < rank[py]:
                parent[px] = py
            else:
                parent[px] = py
                rank[py] += 1
            return True
        
        for edge in edges:
            if not union(edge[0], edge[1]):
                return edge
class Solution:
    def accountsMerge(self, accounts: List[List[str]]) -> List[List[str]]:
        """Merge accounts that share at least one email address.

        Each account is ``[name, email, email, ...]``. Every distinct email
        gets an integer id in a disjoint-set forest, and all emails of one
        account are joined with that account's first email. The result has
        one ``[name, *sorted_emails]`` list per connected group, in the
        order the groups' emails were first encountered. Accounts without
        any email contribute nothing.
        """
        root_of = []
        depth = []

        def find(node):
            # Locate the representative, then compress the path onto it.
            top = node
            while root_of[top] != top:
                top = root_of[top]
            while root_of[node] != top:
                following = root_of[node]
                root_of[node] = top
                node = following
            return top

        def union(a, b):
            root_a, root_b = find(a), find(b)
            if root_a == root_b:
                return
            # Union by rank; on a tie root_b wins and its rank increases.
            if depth[root_a] > depth[root_b]:
                root_of[root_b] = root_a
                return
            if depth[root_a] == depth[root_b]:
                depth[root_b] += 1
            root_of[root_a] = root_b

        owner = {}     # email -> account name (last account seen wins)
        index_of = {}  # email -> id inside the disjoint-set forest
        for account in accounts:
            emails = account[1:]
            for address in emails:
                owner[address] = account[0]
                if address not in index_of:
                    # Ids are handed out densely, so the next id is the
                    # current size of the forest.
                    index_of[address] = len(root_of)
                    root_of.append(len(root_of))
                    depth.append(1)
                union(index_of[emails[0]], index_of[address])

        # Bucket emails by representative, keeping first-seen order.
        groups = collections.defaultdict(list)
        for address, idx in index_of.items():
            groups[find(idx)].append(address)

        merged = []
        for members in groups.values():
            merged.append([owner[members[0]]] + sorted(members))
        return merged
class Solution:
    def longestConsecutive(self, nums: List[int]) -> int:
        """Return the length of the longest run of consecutive integers.

        Every distinct value is a node in a disjoint-set forest; a value is
        merged with ``value + 1`` and ``value - 1`` whenever those are also
        present. Set sizes are tracked at the roots and the largest size
        observed is returned (``0`` for empty input).
        """
        root_of = dict(zip(nums, nums))
        run_size = dict.fromkeys(nums, 1)

        def find(node):
            # Climb to the representative, then compress the path onto it.
            top = node
            while root_of[top] != top:
                top = root_of[top]
            while root_of[node] != top:
                following = root_of[node]
                root_of[node] = top
                node = following
            return top

        def union(a, b):
            root_a, root_b = find(a), find(b)
            if root_a == root_b:
                return
            # Union by size: the strictly larger set absorbs the other;
            # on a tie root_b is kept as the representative.
            if run_size[root_a] > run_size[root_b]:
                keeper, absorbed = root_a, root_b
            else:
                keeper, absorbed = root_b, root_a
            root_of[absorbed] = keeper
            run_size[keeper] += run_size[absorbed]

        best = 0
        for value in nums:
            for neighbour in (value + 1, value - 1):
                if neighbour in root_of:
                    union(neighbour, value)
            # root_of[value] can momentarily be a non-root right after a
            # merge, which only under-reports. The last visit to any member
            # of a run happens once the run is fully merged and find() has
            # compressed this entry, so the true size is always seen.
            best = max(best, run_size[root_of[value]])

        return best

Kruskal's algorithm

地图上有 m 条无向边,每条边 (x, y, w) 表示位置 x 到位置 y 的权值为 w。从位置 0 到 位置 n 可能有多条路径。我们定义一条路径的危险值为这条路径中所有的边的最大权值。请问从位置 0 到 位置 n 所有路径中最小的危险值为多少?

  • 最小危险值为最小生成树中 0 到 n 路径上的最大边权。
# Kruskal's algorithm
class Solution:
    def getMinRiskValue(self, N, M, X, Y, W):
        """Return the smallest possible maximum edge weight on a 0 -> N path.

        ``X[i]``, ``Y[i]`` and ``W[i]`` describe the endpoints and weight of
        edge ``i``; ``M`` is accepted but not used. Edges are added in
        ascending weight order (Kruskal's algorithm on a disjoint-set forest
        over nodes ``0..N``), and the weight of the edge that first puts
        nodes 0 and N in the same set is returned. If they never become
        connected the result is ``None``.
        """
        root_of = list(range(N + 1))
        depth = [1] * (N + 1)

        def find(node):
            # Climb to the representative, then compress the path onto it.
            top = node
            while root_of[top] != top:
                top = root_of[top]
            while root_of[node] != top:
                following = root_of[node]
                root_of[node] = top
                node = following
            return top

        def union(a, b):
            root_a, root_b = find(a), find(b)
            if root_a == root_b:
                return False
            # Union by rank; on a tie root_b wins and its rank increases.
            if depth[root_a] > depth[root_b]:
                root_of[root_b] = root_a
                return True
            if depth[root_a] == depth[root_b]:
                depth[root_b] += 1
            root_of[root_a] = root_b
            return True

        for weight, u, v in sorted(zip(W, X, Y)):
            if not union(u, v):
                # Edge lies inside an existing component; nothing changed.
                continue
            # Stop as soon as the endpoints meet - the full spanning tree
            # is never needed.
            if find(0) == find(N):
                return weight