Post

[BOJ] 병합 정렬 1 - 24060 (S3)

[BOJ] 병합 정렬 1 - 24060 (S3)
시간 제한메모리 제한
1 초512 MB

문제

오늘도 서준이는 병합 정렬 수업 조교를 하고 있다. 아빠가 수업한 내용을 학생들이 잘 이해했는지 문제를 통해서 확인해보자.

N개의 서로 다른 양의 정수가 저장된 배열 A가 있다. 병합 정렬로 배열 A를 오름차순 정렬할 경우 배열 A에 K 번째 저장되는 수를 구해서 우리 서준이를 도와주자.

크기가 N인 배열에 대한 병합 정렬 의사 코드는 다음과 같다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
merge_sort(A[p..r]) { # A[p..r]을 오름차순 정렬한다.
    if (p < r) then {
        q <- ⌊(p + r) / 2⌋;       # q는 p, r의 중간 지점
        merge_sort(A, p, q);      # 전반부 정렬
        merge_sort(A, q + 1, r);  # 후반부 정렬
        merge(A, p, q, r);        # 병합
    }
}

# A[p..q]와 A[q+1..r]을 병합하여 A[p..r]을 오름차순 정렬된 상태로 만든다.
# A[p..q]와 A[q+1..r]은 이미 오름차순으로 정렬되어 있다.
merge(A[], p, q, r) {
    i <- p; j <- q + 1; t <- 1;
    while (i ≤ q and j ≤ r) {
        if (A[i] ≤ A[j])
        then tmp[t++] <- A[i++]; # tmp[t] <- A[i]; t++; i++;
        else tmp[t++] <- A[j++]; # tmp[t] <- A[j]; t++; j++;
    }
    while (i ≤ q)  # 왼쪽 배열 부분이 남은 경우
        tmp[t++] <- A[i++];
    while (j ≤ r)  # 오른쪽 배열 부분이 남은 경우
        tmp[t++] <- A[j++];
    i <- p; t <- 1;
    while (i ≤ r)  # 결과를 A[p..r]에 저장
        A[i++] <- tmp[t++];
}

입력

첫째 줄에 배열 A의 크기 N(5 ≤ N ≤ 500,000), 저장 횟수 K(1 ≤ K ≤ 108)가 주어진다.

다음 줄에 서로 다른 배열 A의 원소 A1, A2, …, AN이 주어진다. (1 ≤ Ai ≤ 109)

출력

배열 A에 K 번째 저장 되는 수를 출력한다. 저장 횟수가 K 보다 작으면 -1을 출력한다.

풀이

문제에 작성된 의사 코드를 따라 병합 정렬 코드를 완성하여 정렬된 배열 A에 K번째 저장된 수를 출력하는 문제이다.

병합 정렬

병합 정렬은 주어진 배열을 절반으로 나누는 과정과 이를 다시 합치는 병합 과정으로 진행된다.

먼저 부분 배열(나누어진 배열)의 원소가 하나만 남거나 부분 배열이 빌 때까지 배열을 절반으로 나누어 준다. 이는 재귀를 이용하여 다음과 같이 구현할 수 있다.

1
2
3
4
5
6
7
8
9
def merge_sort(a, p, r):
    global result
    if result != -1:
        return
    if p < r:
        q = (p + r) // 2
        merge_sort(a, p, q)
        merge_sort(a, q+1, r)
				...

모두 다 나눴으면 이제 병합하며 정렬된 배열을 완성한다.

병합할 두 배열의 원소들을 확인하며 작은 순으로 배열에 넣어준다. 하지만 두 배열의 길이가 항상 같지는 않으므로 남은 배열의 원소들도 차례대로 넣어준다. 이 때 각 부분 배열들은 이전의 병합에서 이미 정렬된 상태이므로 그냥 순서대로 넣어주면 된다.

위 방식대로 모든 병합을 완료하면 정렬이 완료된다.

하지만 우리가 원하는 것은 특정 번째에 저장된 수를 찾는 것이므로 정렬한 값을 임시로 저장한 딕셔너리 tmp의 값을 배열 a에 옮겨주면서 원하는 번째에 a에 저장된 값을 찾아낸다.

1
2
3
4
5
6
7
8
9
10
11
		...
		while i <= r:
        a[i] = tmp[t]
        cnt += 1
        if cnt == k:
            cnt = -float('inf')
            result = tmp[t]
            return
        t += 1
        i += 1
		...

코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
a, k = map(int, input().split())
nums = list(map(int, input().split()))
cnt = 0
result = -1
def merge_sort(a, p, r):
    global result
    if result != -1:
        return
    if p < r:
        q = (p + r) // 2
        merge_sort(a, p, q)
        merge_sort(a, q+1, r)
        merge(a, p, q, r)

def merge(a, p, q, r):
    global cnt, result
    i, j, t = p, q + 1, 0
    tmp = {}
    while i <= q and j <= r:
        if a[i] <= a[j]:
            tmp[t] = a[i]
            t += 1
            i += 1
        else: 
            tmp[t] = a[j]
            t += 1
            j += 1
    while i <= q:
        tmp[t] = a[i]
        t += 1
        i += 1
    while j <= r:
        tmp[t] = a[j]
        t += 1
        j += 1
    i, t = p, 0
    while i <= r:
        a[i] = tmp[t]
        cnt += 1
        if cnt == k:
            cnt = -float('inf')
            result = tmp[t]
            return
        t += 1
        i += 1

merge_sort(nums, 0, a-1)
print(result)
This post is licensed under CC BY 4.0 by the author.