Post

[BOJ] 특정 거리의 도시 찾기 - 18352 (S2)

[BOJ] 특정 거리의 도시 찾기 - 18352 (S2)
시간 제한메모리 제한
2 초256 MB

문제

어떤 나라에는 1번부터 N번까지의 도시와 M개의 단방향 도로가 존재한다. 모든 도로의 거리는 1이다.

특정한 도시 X로부터 출발하여 도달할 수 있는 모든 도시 중에서, 최단 거리가 정확히 K인 모든 도시들의 번호를 출력하는 프로그램을 작성하시오.

입력

첫째 줄에 도시의 개수 N, 도로의 개수 M, 거리 정보 K, 출발 도시의 번호 X가 주어진다. (2 ≤ N ≤ 300,000, 1 ≤ M ≤ 1,000,000, 1 ≤ K ≤ 300,000, 1 ≤ X ≤ N)

둘째 줄부터 M개의 줄에 걸쳐서 두 개의 자연수 A, B가 공백을 기준으로 구분되어 주어진다. 이는 A번 도시에서 B번 도시로 이동하는 단방향 도로가 존재한다는 의미다. (1 ≤ A, B ≤ N)

출력

X로부터 출발하여 도달할 수 있는 도시 중에서, 최단 거리가 K인 모든 도시의 번호를 한 줄에 하나씩 오름차순으로 출력한다.

이 때 도달할 수 있는 도시 중에서, 최단 거리가 K인 도시가 하나도 존재하지 않으면 -1을 출력한다.

풀이

단방향 그래프에서 특정 노드로부터의 최단 거리를 구하는 문제이다. 모든 간선의 가중치가 1이므로 BFS 또는 다익스트라로 해결할 수 있다.

코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import heapq
from collections import defaultdict

INF = float("inf")

N, M, K, X = map(int, input().split())

graph = defaultdict(list)

for _ in range(M):
    a, b = map(int, input().split())
    graph[a].append(b)


def dijkstra(start):
    dist_list = [INF] * (N + 1)
    dist_list[start] = 0
    q = [(0, start)]

    while q:
        dist, cur_node = heapq.heappop(q)
        if dist_list[cur_node] < dist:
            continue

        for n_node in graph[cur_node]:
            if dist + 1 < dist_list[n_node]:
                dist_list[n_node] = dist + 1
                heapq.heappush(q, (dist + 1, n_node))

    return dist_list


inf_cnt = 0
for node, dist in enumerate(dijkstra(X)[1:], start=1):
    if dist == K:
        print(node)
    else:
        inf_cnt += 1

if inf_cnt == N:
    print(-1)

시간 복잡도

  • 다익스트라: O((N + M) log N)
  • N ≤ 300,000, M ≤ 1,000,000
  • 충분히 통과 가능

최적화 방법

모든 간선의 가중치가 1이므로, BFS를 사용하면 O(N + M)으로 더 빠르게 해결할 수 있다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
from collections import deque

def bfs(start):
    dist_list = [-1] * (N + 1)
    dist_list[start] = 0
    q = deque([start])
    
    while q:
        cur_node = q.popleft()
        
        for n_node in graph[cur_node]:
            if dist_list[n_node] == -1:
                dist_list[n_node] = dist_list[cur_node] + 1
                q.append(n_node)
    
    return dist_list
This post is licensed under CC BY 4.0 by the author.