본문 바로가기
KraftonJungle2기/Today I Learned

[TIL] Algorithm - 최소 신장 트리 + 그래프 기본 문제 풀이

by SooooooooS 2023. 4. 21.
728x90

1. 그래프 기본 문제 풀이

1. 5639번 - 이진 검색 트리

import sys
sys.setrecursionlimit(10**6)

nums = []

# 입력이 더이상 들어오지 않을 때 오류 발생
# 오류가 나타나면 while문 종료
while True :
    try :
        nums.append(int(sys.stdin.readline()))
    except :
        break

# nums의 시작 인덱스 = start
# nums의 끝 인덱스  = end
def post(start, end) :
    if start == end : # start와 end가 같은 경우
        # 한가지 경우만 있으니 출력
        print(nums[start])
        return
    root = nums[start] # 현재 루트 노드
    mid = start
    for i in range(start+1, end+1) :
        if nums[i] > root : # 왼쪽 서브트리와 오른쪽 서브트리 구분을 위한 작업
            mid = i
            break
    if mid == start + 1 or mid == start: # 왼쪽 자식이 없거나 오른쪽 자식이 없으면 
        post(start+1, end)
    else :
        post(start+1, mid-1) # 왼쪽 서브 트리 실행
        post(mid, end) # 오른쪽 서브트리 실행
    print(root) # 마지막으로 루트 출력
        
post(0, len(nums)-1)

처음에는 어떻게 풀어야할 지 모르겠어서 입력 예제를 내가 알기 편한 배열로 놓아봤다.
그랬더니 위의 이미지와 같이 덩어리로 보면 전위 순회는 root → left → right 순으로 나타난다.
후위 순회는 left → right → root 순으로 나타난다.

그래서 배열의 start, end 인덱스를 주면서 현재 어느 범위까지 보는지 확인하고
각각 root, left, right로 나눠준다. 그 후 후위 순회의 순서로 print 해준다.

어렵게 생각했던것과 달리 예상외로 간단하게 문제가 풀렸다. 
직접 푸니까 역시 뿌듯하다.

2. 1197번 - 최소 스패닝 트리

1. 개념 정리

  • 최소신장 트리(minimum spanning tree, MST)
    • 그래프의 모든 정점을 포함하고
    • 정점 간 서로 연결이 되며
    • 순환이 존재하지 않는다.
    • 신장 트리에서의 정점 n개일 경우 간선 = n-1 개이다.
  • 프림 알고리즘(Prim'salgorithm)
    • https://ko.wikipedia.org/wiki/프림_알고리즘
    • 그래프 내의 간선이 많을 경우 유리 = O(ElogN)
    • 임의의 시작점에서 현재까지 연결된 정점들에서 연결되지 않은 정점들에 대해 가중치가 가장 작은 정점 연결

2. 프림 알고리즘으로 구현한 코드

import sys
import heapq
import collections

V, E = map(int, sys.stdin.readline().split())
matrix = collections.defaultdict(list) # 빈 리스트 생성
for _ in range(E) :
    x, y, weight = map(int, sys.stdin.readline().split())
    # 무방향 그래프 표현
    # 0번 부터 V-1번까지 노드라고 생각한다.(문제상에는 1 ~ V 까지)
    matrix[x-1].append([weight, x-1, y-1])
    matrix[y-1].append([weight, y-1, x-1])

def primMST(v) :
    node = [0] * V # 방문 여부를 저장할 배열
    # 시작 정점 방문
    node[v] = 1
    # 시작노드에서 갈 수 있는 모든 간선 리스트
    candidate = matrix[v] 
    heapq.heapify(candidate) # 우선순위 큐로 변환
    sum = 0
    while candidate :
        weight, x, y = heapq.heappop(candidate)
        # 정점 y를 방문한 적이 없다면
        if node[y] == 0 :
            node[y] = 1
            sum += weight
            
            # y에서 갈 수 있는 모든 간선에 대한 정보 중에서
            # 방문한 적이 없는 노드들만 candidate에 저장
            for edge in matrix[y] :
                if node[edge[2]] == 0 :
                    heapq.heappush(candidate, edge)
    print(sum)

primMST(0)
  • 크루스칼 알고리즘(Kruskal's algorithm)
    • https://ko.wikipedia.org/wiki/크러스컬_알고리즘
    • greedy 알고리즘의 일종으로 그래프
      간선들을 가중치의 오름차순으로 정렬한 후 순환을 형성하지 않는 선에서 정렬된 순서대로 간선 선택
    • 그래프 내의 간선이 적을 경우 유리 = O(ElogE)
    • Union & Find
      • disjoint set(서로소 집합)을 표현한 자료구조
        • 서로 중복되지 않는 부분 집합들로 나눠진 원소들에 대한 정보를 저장하고 조작하는 자료구조
      • 서로 다른 두 집합을 병합 = Union / 집합 원소가 어떤 집합에 속해있는지 찾는 = Find
      • 연산
        • make-set(x) : x를 유일한 원소로 하는 새로운 집합 생성
        • union(x,y) : x가 속한 집합과 y가 속한 집합을 합친다.
        • find(x) x가 속한 집합의 대표값을 반환

3. 크루스칼 알고리즘으로 구현한 코드

import sys

V, E = map(int, sys.stdin.readline().split())

matrix = [list(map(int, sys.stdin.readline().split())) for _ in range(E)]
matrix.sort(key= lambda x : x[2])

# 각 정점의 부모노드를 자기 자신으로 설정
# make_set() 동작이라고 생각하면된다.
disjoint = [i for i in range(V+1)]

# x의 부모노드 찾기
# 주의할 점은 부모노드가 자기 자신이 아니면 그 부모 노드의 부모 노드를 찾아가야 한다.!!
def find_set(x) :
    if disjoint[x] == x :
        return x
    else :
        return find_set(disjoint[x])

# x, y의 부모 노드를 통일시킨다.
# 단, 작은 값으로 통일
def union_set(x, y) :
    a = find_set(x)
    b = find_set(y)
    if a < b :
        disjoint[b] = a
    else :
        disjoint[a] = b
    
    
sum = 0
# 크루스칼 알고리즘 사용!
for x, y, weight in matrix :
    if find_set(x) != find_set(y) :
        union_set(x, y)
        sum += weight

print(sum)
최소 신장 트리... 분명히 배웠는데 막상 다시 공부해보니 너무 어려웠다. 
크루스칼 알고리즘은 이해하면서 직접 구현할 수 있었는데 프림 알고리즘은 좀 많이 시간을 들여 구현하려고 했지만
다른 블로그 코드를 참고하며 구현했다..ㅜㅠ

< 개념 참고 블로그 >

https://ongveloper.tistory.com/376

◆ 크루스칼 알고리즘 https://chanhuiseok.github.io/posts/algo-33/

◆ union-find https://gmlwjd9405.github.io/2018/08/31/algorithm-union-find.html

◆ 프림 알고리즘 https://4legs-study.tistory.com/112


3. 1260번 - DFS와 BFS

import sys
from collections import deque

N, M, V = map(int, sys.stdin.readline().split())
matrix = [[0]*N for _ in range(N)]

for _ in range(M) :
    x, y = map(int, sys.stdin.readline().split())
    matrix[x-1][y-1] = 1
    matrix[y-1][x-1] = 1

def dfs(v) :
    visited = [0] * N
    stack = []
    result = []
    
    stack.append(v)
    visited[v] = 1
    result.append(v+1)
    
    while stack :
        cur = stack[-1]
        flag = False
        for i in range(N) :
            if matrix[cur][i] == 1 and visited[i] == 0 :
                stack.append(i)
                visited[i] = 1
                result.append(i+1)
                flag = True
                break
        if not flag :
            stack.pop()
    print(*result)
    
def bfs(v) :
    visited = [0] * N
    queue = deque()
    result = []
    
    queue.append(v)
    visited[v] = 1
    result.append(v+1)
    
    while queue :
        cur = queue.popleft()

        for i in range(N) :
            if matrix[i][cur] == 1 and visited[i] == 0 :
                queue.append(i)
                visited[i] = 1
                result.append(i+1)
    print(*result)
    
dfs(V-1)
bfs(V-1)
이전에 풀어봤던 문제였다. 분명히 정글 오기 전까지 그래프를 열심히 공부해서 잘 알고 있었는데
한동안 다른 문제에만 집중하다 보니까 살짝 헷갈렸다. 꾸준히가 정말 중요하다.

4. 11724번 - 연결 요소의 개수

import sys
from collections import deque

N, M = map(int, sys.stdin.readline().split())

matrix = [[0] * N for _ in range(N)]
visited = [0] * N
count = 0

for _ in range(M) :
    x, y = map(int, sys.stdin.readline().split())
    # 양방향 그래프
    matrix[x-1][y-1] = 1
    matrix[y-1][x-1] = 1

# BFS로 탐색
def connection(v) :
    global count
    queue = deque()
    
    queue.append(v)
    visited[v] = 1
    
    while queue :
        cur = queue.popleft()

        for i in range(N) :
            if matrix[i][cur] == 1 and visited[i] == 0 :
                queue.append(i)
                visited[i] = 1
    # 시작 정점 v에서 연결된 모든 정점 탐색이 끝나면 coun++
    count += 1

for i in range(N) :
    if visited[i] == 0 :
        connection(i)

print(count)
이 문제 역시 지난 번에 풀었었다. 그때는 DFS로 구현했길래 BFS로 구현해 보았다.
(실은 이번에는 뭔가 BFS로 하면 더 나을 것 같았다....ㅎ)

5. 2606번 - 바이러스

import sys
from collections import deque

computer = int(sys.stdin.readline())
net = int(sys.stdin.readline())

network = [[0] * (computer+1) for _ in range(computer+1)]

for _ in range(net) :
    x, y = map(int, sys.stdin.readline().split())
    network[x][y] = 1
    network[y][x] = 1

def virus(v) :
    visited = [0] * (computer+1)
    queue = deque()
    count = 0
    
    visited[v] = 1
    queue.append(v)
    
    while queue :
        cur = queue.popleft()
        for i in range(computer+1) :
            if network[cur][i] == 1 and visited[i] == 0 :
                queue.append(i)
                visited[i] = 1
                count += 1
    return  count

print(virus(1))
너무나 당연하게 BFS로 푼 문제

6. 11725번 - 트리의 부모 찾기

import sys

N = int(sys.stdin.readline())

matrix = [[] for _ in range(N)]

# 연결 상태만 저장하는 방식
for _ in range(N-1) :
    x, y = map(int, sys.stdin.readline().split())
    matrix[x-1].append(y-1)
    matrix[y-1].append(x-1)

def dfs(v) :
    # 방문 여부 확인 겸 부모 노드의 번호 저장
    visited = [0] * N
    stack = []
    
    stack.append(v)
    visited[v] = 1
    
    while stack :
        flag = False
        cur = stack[-1] # 현재 노드
        # 현재 노드와 연결되어 있는 요소들의 리스트 반복
        for i in matrix[cur] : 
            if visited[i] == 0 :
                stack.append(i)
                visited[i] = cur+1 # 부모노드 저장
                flag = True # 아직 탐색할 노드 존재
                break
        if not flag : # 더이상 탐색할 노드가 없으면
            stack.pop()
    for i in visited[1:] :
        print(i)

dfs(0)
아무생각 없이 인접행렬을 사용하여 구현했는데 메모리 초과...
보니까 노드 N (2 ≤ N ≤ 100,000) 조건으로 인해 최대 10억개의 행렬을 사용하게 되었다.
그래서 이번에 사용한 방법은 각 노드마다 자신과 연결된 노드의 리스트를 갖는 것이다.
ex) 백준 예시 1
7
1 6
6 3
3 5
4 1
2 4
4 7
--------------------------------
matrix[1] = [6, 4]
matrix[2] = [4]
matrix[3] = [6, 5]
matrix[4] = [1, 2, 7]
matrix[5] = [3]
matrix[6] = [1, 3]
matrix[7] = [4]
--------------------------------
이와 같이 저장하면 메모리 사용량을 줄일 수 있다.

7. 2178번 - 미로 탐색

import sys
from collections import deque

N, M = map(int ,sys.stdin.readline().split())

miro = [list(map(int, sys.stdin.readline().strip())) for _ in range(N)]

def bfs(miro) :
    queue = deque()
    visited = [[0]*M for _ in range(N)]
    
    queue.append((0,0))
    visited[0][0] = 1
    
    x = [-1, 1, 0, 0]
    y = [0, 0, -1, 1]
    
    while queue :
        curX, curY = queue.popleft()
        for i in range(4) :
            nX, nY = curX + x[i], curY + y[i]
            if 0 <= nX < N and 0 <= nY < M :
                if visited[nX][nY] == 0 and miro[nX][nY] == 1 :
                    miro[nX][nY] = miro[curX][curY] + 1
                    visited[nX][nY] = 1
                    queue.append((nX, nY))
    print(miro[N-1][M-1])

bfs(miro)
728x90