ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • [Python] 백준 1197 : 최소 스패닝 트리
    코딩테스트 2024. 6. 17. 22:54

    문제 링크

     

    Kruskal 알고리즘

     

    Kruskal 알고리즘 자체에 대한 설명은 여기로.

     

    최소 스패닝 트리(MST) 문제에서, 코드로 구현하기 더 쉬운 것은 Prim 알고리즘이지만, 때때로 Kruskal 알고리즘이 필요할 때가 있다. Kruskal 알고리즘은 사실 사람 입장에서 더 쉬운 방법이라고 생각한다. 왜냐하면 "간선"을 고를 때, Cycle이 되는지 안 되는지 파악하는 코드가 꽤나 복잡하기 때문이다. 

     

    그렇기에 다음의 코드를 알아두면 좋다.

    #시작은, 모든 노드의 부모 노드는 자기 자신
    #0번 노드의 부모는 0, 1번 노드의 부모는 1...
    parent = list(range(V))
    
    #부모 노드를 찾는 코드
    #부모가 자기 자신이 아니라면 부모의 부모를 재귀적으로 찾는 것
    def find(parent, x):
        if parent[x] != x:
            parent[x] = find(parent, parent[x])
        return parent[x]
        
    #두 "간선그룹"이 합쳐질 때, 부모를 하나로 병합
    #이때 Rank(depth)가 낮은 부모로 합쳐지는데, 이는 시간복잡도를 낮추기 위함임. 
    def union(parent, rank, x, y):
        root_x = find(parent, x)
        root_y = find(parent, y)
    
        if root_x != root_y:
            if rank[root_x] > rank[root_y]:
                parent[root_y] = root_x
            elif rank[root_y] > rank[root_x]:
                parent[root_x] = root_y
            else:
                parent[root_y] = root_x
                rank[root_x] += 1

     

    정답 코드

    import heapq
    
    V, E = map(int, input().split())
    
    graph = []
    for _ in range(E) :
        s, e, w = map(int, input().split())
        heapq.heappush(graph, (w, s-1, e-1))
    
    parent = list(range(V))
    rank = [0] * V
    
    total_weight = 0
    edges_used = 0
    
    def find(parent, x) :
        if parent[x] != x :
            parent[x] = find(parent, parent[x])
        return parent[x]
    
    def union(parent, rank, x, y) :
        root_x = find(parent, x)
        root_y = find(parent, y)
    
        if root_x != root_y :
            if rank[root_x] > rank[root_y] :
                parent[root_y] = root_x
            elif rank[root_y] > rank[root_x] :
                parent[root_x] = root_y
            else :
                parent[root_y] = root_x
                rank[root_x] += 1
    
    while graph and edges_used < V-1 :
        cw, cs, ce = heapq.heappop(graph)
    
        root_cs = find(parent, cs)
        root_ce = find(parent, ce)
    
        if root_cs != root_ce :
            union(parent, rank, cs, ce)
            total_weight += cw
            edges_used += 1
    
    print(total_weight)
Designed by Tistory.