반응형

문제

어떤 N개의 수가 주어져 있다. 그런데 중간에 수의 변경이 빈번히 일어나고 그 중간에 어떤 부분의 합을 구하려 한다. 만약에 1,2,3,4,5 라는 수가 있고, 3번째 수를 6으로 바꾸고 2번째부터 5번째까지 합을 구하라고 한다면 17을 출력하면 되는 것이다. 그리고 그 상태에서 다섯 번째 수를 2로 바꾸고 3번째부터 5번째까지 합을 구하라고 한다면 12가 될 것이다.

입력

첫째 줄에 수의 개수 N(1 ≤ N ≤ 1,000,000)과 M(1 ≤ M ≤ 10,000), K(1 ≤ K ≤ 10,000) 가 주어진다. M은 수의 변경이 일어나는 회수이고, K는 구간의 합을 구하는 회수이다. 그리고 둘째 줄부터 N+1번째 줄까지 N개의 수가 주어진다. 그리고 N+2번째 줄부터 N+M+K+1번째 줄까지 세 개의 정수 a, b, c가 주어지는데, a가 1인 경우 b번째 수를 c로 바꾸고 a가 2인 경우에는 b번째 수부터 c번째 수까지의 합을 구하여 출력하면 된다.

a가 1인 경우 c는 long long 범위를 넘지 않는다.

출력

첫째 줄부터 K줄에 걸쳐 구한 구간의 합을 출력한다. 단, 정답은 long long 범위를 넘지 않는다.

예제 입력 1

5 2 2
1
2
3
4
5
1 3 6
2 2 5
1 5 2
2 3 5

예제 출력 1

17
12

나의 풀이

세그먼트 트리를 가지고 푸는 문제로 n을 입력받고 n에 대한 세그먼트 트리의 최대 크기(포화 완전 이진 트리)를 구해준다. 우선 크기를 구하기 위해서는 0~n범위의 구간값들을 담기 위한 트리의 높이는 h = int(ceil(log2(n)))이다. 그리고 이 높이 h에 대한 포화 완전 이진 트리의 모든 노드의 개수는 2^(h+1) - 1으로 파이썬 비트연산자를 이용해 1 << (h+1) - 1로 나타낼 수 있다. 이렇게 세그먼트 트리를 만들고 입력받은 배열에 대한 구간합들을 init()함수를 만들어 모두 구해준다.


  • init(node, start, end)

    이 함수는 앞서 말했듯이, 입력받은 배열의 나눠진 구간들의 합을 구해 트리를 초기화 시켜놓는 함수이다. 처음에는 start와 end를 각각 0과 n-1로 입력을 받고 mid = (start+end) // 2의 값을 가지고 2분할 해나가야한다. init(node*2,start,mid) 구간을 반으로 나누고 왼쪽, 따라서 트리에서 왼쪽 자식(node*2)으로 움직이고, init(node*2+1,mid+1,end) 구간을 반으로 나누고 오른쪽, 즉 트리에서 오른쪽 자식(node*2+1)로 움직여 계산을 한다. 재귀를 멈추는 조건은 start == end이며 이게 같다는 말은 즉, 구간이 없는 그냥 하나의 값으로 arr[start]를 넣어준다. 그리고 return으로 이 값을 반환해주고 재귀를 돌리며 왼쪽과 오른쪽의 합을 부모 노드에 저장을 해주면 각각 구간의 합들을 구할 수 있다.

    def init(node, start, end):
    
        if start == end:
            tree[node] = arr[start]
            return tree[node]
    
        mid = (start + end) // 2
    
        tree[node] = init(node*2, start, mid) + init(node*2+1, mid+1, end)
        return tree[node]

다음으로는 이렇게 세그먼트 트리를 초기화 하였다면, 특정 구간의 합을 구해 반환해주는 함수가 필요할텐데 그것이 바로 query()함수이다.

  • query(node, start, end, left, right)

    left와 right로 구하고자하는 구간의 시작과 끝을 입력받는다. query를 해줄때 중요한 점은 각 구간만의 값이 반환된다는게 아니라는 점이다. 다시말해 위에서 세그먼트 트리에서 반으로 쪼개어진 특정한 구간들만 주어지는게 아니라는 것이다. 가령, 범위가 04인데 반으로 쪼개면 02, 34인 구간으로 쪼개어진다. 그런데 만약 14의 범위를 query로 날린다면? 02의 구간에서 또 나눈뒤 1의 값과 34의 값을 합친값을 반환해주어야 한다는 것이다. 따라서 우선, left가 end보다 크거나, right가 start보다 작다면 left와 right가 이 구간과 일치지 않은 것이니 바로 return을 해주고, 만약, start와 end가 left와 right사이에 있다면, (예를들면 start,end가 각각 05이고 left,right가 03일 경우처럼) 바로 그 노드의 값을 반환해준다. 그렇지않다면 mid값을 구해 재귀로 왼쪽과 오른쪽 모두에게 날려 반환되는 값을 더해준다.

    def query(node, start, end, left, right):
        if left > end or start > right:
            return 0
    
        if left <= start and end <= right:
            return tree[node]
    
        mid = (start + end) // 2
        return query(node*2, start, mid, left, right) + query(node*2+1, mid+1, end, left, right)

그리고 이 문제에서 가장 중요하다고 생각했던 update()함수이다. 이 함수는 잘 생각해보니 힙 정렬과 원리가 같았다. 가령, 구간이 04이고 2번 인덱스의 값을 바꾼다면 04구간의 노드, 0~2구간의 노드, 2구간의 노드 이 세 가지 노드만 바꾸어주면 된다는 점이다. 힙에서 보면 어떠한 값이 현재 노드의 값보다 작으면 왼쪽, 크면 오른쪽으로 가듯이 말이다.

  • update(node, start, end, idx, diff)

    여기서도 재귀를 돌려야하기 때문에 node, start, end값을 받는다 그리고 업데이트를 해줄 인덱스 값인 idx를 받아주고 바꾸기 전의 원래 값을 빼고 새로 바꾸려는 값을 더해준 diff = val - arr[idx]값을 인자로 받는다. 우선 기본로직은 재귀를 돌려 힙정렬처 가야하므로 idx가 start보다 작거나, end보다 크면 return을 한다. 그렇지않다면 tree[node]의 값에 diff값을 더해준다. 여기서 이렇게 해주면 영향을 받는 값들에 모두 원래의 값만큼 빼주고 새로운 값만큼 더해주게 되므로 모든 경우가 업데이트 되게된다. 그리고나서 start와 end가 같으면 끝나면 되지만, 다르다면 위와 같이 왼쪽과 오른쪽을 재귀를 돌려주면 된다.

    def update(node, i, diff, start, end):
        if not (start <= i and i <= end):
            return
    
        tree[node] += diff
    
        if start == end:
            return
        elif start != end:
            mid = (start + end) // 2
            update(node*2, i, diff, start, mid)
            update(node*2+1, i, diff, mid+1, end)

이렇게 문제를 해결할 수 있다. 그런데 필자가 했던 실수 중 하나는 이거다. 입력으로 b번째와 같이 주어진다. 그런데 b번째라는 말은 arr에서는 b-1인덱스를 뜻하므로 query의 인자를 넘겨줄때 b-1의 인덱스 값을 넘겨주어야 한다는 점을 생각해야한다.


코드

# 2042번 구간 합 구하기
from math import *

# seg_tree init
def init(node, start, end):

    if start == end:
        tree[node] = arr[start]
        return tree[node]

    mid = (start + end) // 2

    tree[node] = init(node*2, start, mid) + init(node*2+1, mid+1, end)
    return tree[node]

# 수의 변경
def update(node, i, diff, start, end):
    if not (start <= i and i <= end):
        return

    tree[node] += diff

    if start == end:
        return
    elif start != end:
        mid = (start + end) // 2
        update(node*2, i, diff, start, mid)
        update(node*2+1, i, diff, mid+1, end)

# 구간의 합 쿼리
def query(node, start, end, left, right):
    if left > end or start > right:
        return 0

    if left <= start and end <= right:
        return tree[node]

    mid = (start + end) // 2
    return query(node*2, start, mid, left, right) + query(node*2+1, mid+1, end, left, right)

# main
n, m, k = map(int,input().split())

h = int(ceil(log2(n)))
t_size = 1 << (h+1)

arr = []
tree = [0] * (t_size)

for _ in range(n):
    arr.append(int(input()))

init(1,0,n-1)

for _ in range(m+k):
    a, b, c = map(int,input().split())

    if a == 1:
        # update
        diff = c - arr[b-1]
        arr[b-1] = c
        update(1, b-1, diff, 0, n-1)

    elif a == 2:
        # query
        print(query(1, 0, n-1, b-1, c-1))
반응형

BELATED ARTICLES

more