반응형

문제

모두 알다시피 동호는 커피숍의 마담이다. (마담이 무엇인지는 본인에게 물어보도록 하자.)

어느 날 커피숍의 손님 A씨가 동호에게 게임을 하자고 했다.

그 게임은 다음과 같은 규칙을 갖는다.

N개의 정수가 있으면, 동호는 다음과 같이 말한다. “37번째 수의 합은 무엇이죠?” 그러면 상대방은 “그 답은 000입니다. 그리고 8번째 수를 2로 고치도록 하죠” 그러면 동호는 “네 알겠습니다.”라고 한 뒤에 다시 상대방이 동호가 했던 것처럼 “89번째 수의 합은 무엇이죠?”라고 묻게된다. 이 것을 번갈아 가면서 반복하는 게임이다.

당신은 이 게임의 심판 역을 맡았다. 요컨대, 질문에 대한 답들을 미리 알아야 한다는 것이다.

당신의 머리가 출중하다면 10만개 가량 되는 정수와 10만턴 정도 되는 게임을 기억할 수 있을 것이다. 몇판 이 게임을 즐기던 동호는 많은 사람들이 이 게임을 하기를 바라게 되었고, 당신에게 심판 프로그램을 구현해달라고 요청했다.

입력

첫째 줄에 수의 개수 N과 턴의 개수 Q가 주어진다.(1 ≤ N, Q ≤ 100,000) 둘째 줄에는 처음 배열에 들어가 있는 정수 N개가 주어진다. 세 번째 줄에서 Q+2번째 줄까지는 x y a b의 형식으로 x~y까지의 합을 구하여라, a번째 수를 b로 바꾸어라 라는 뜻의 데이터가 주어진다.

입력되는 모든 수는 32비트 부호있는 정수이다.

출력

한 턴마다 구한 합을 한 줄마다 한 개씩 출력한다.

예제 입력 1

5 2
1 2 3 4 5
2 3 3 1
3 5 4 1

예제 출력 1

5
10

노트

x~y는 당연히 x번째 부터 y번째가 맞다. 하지만, 이 문제에서는 x > y인 경우 y번째 부터 x번째이다.


나의 풀이

이 문제는 세그먼트 트리로 구간합을 구하여 반환해주고 a번째 수를 b로 업데이트를 해주면된다. 주의할 점은 맨아래 노트에 있는 사항이다. 필자는 처음에 저 노트 부분을 보지않고 풀어서 틀렸었다. 그리고 update를 하는 과정에서 전역에 선언되어있는 arr에 접근이 안되는지 diff로 넘겨주지 않으니 세그먼트트리에서 값의 변화가 이루어지지 않았었다. 이 문제때문에 많이 헤맸다...


세그먼트 트리를 이용해 구간합을 구하는 과정과 업데이트를 하는 과정은 이미 7578번 공장문제에서 다루었다. 이 문제도 이와 같이 구간 합을 구하면 된다. 다른 점은 이 문제에서는 arr을 한번에 받기위해서 init을 만들고 구간합을 다 구하는 과정을 삽입했다. 코드를 보면 쉽게 이해할 수 있을 것이다. 그러고 update과정을 설명하면 우선 바꾸려는 idx가 start와 end사이에 있지 않으면 return, 아닌 경우에 tree[node]에 diff = val - arr[idx] 즉, 원래 값을 빼고 새로운 값을 넣을때의 차이값을 더해준다. 그러고 start와 end가 같지 않다면 재귀로 계속 돌려준다.


코드

# 1275번 커피숍2
from math import *
import sys

# init 
def init(node, start, end):
    if start == end:
        tree[node] = arr[start]
        return tree[node]

    mid = (start + end) // 2
    tree[node] = init(node*2,start,mid) + init(node*2+1,mid+1,end)
    return tree[node]

# update
def update(node, start, end, idx, diff):
    if idx < start or end < idx:
        return

    tree[node] += diff

    if not start == end:
        mid = (start + end) // 2
        update(node*2,start,mid,idx,diff)
        update(node*2+1,mid+1,end,idx,diff)


# query
def query(node, start, end, left, right):
    if right < start or end < left:
        return 0

    if left <= start and end <= right:
        return tree[node]

    mid = (start + end) // 2
    return query(node*2,start,mid,left,right) + query(node*2+1,mid+1,end,left,right)

# main
n, q = [int(x) for x in sys.stdin.readline().split()]
arr = [int(x) for x in sys.stdin.readline().split()]

# segment_tree
h = int(ceil(log2(n)))
t_size = 1 << (h+1)

tree = [0] * t_size
init(1,0,n-1)

for _ in range(q):
    x, y, a, b = [int(num) for num in sys.stdin.readline().split()]

    if x <= y:
        print(query(1,0,n-1,x-1,y-1))
    else:
        print(query(1,0,n-1,y-1,x-1))

    diff = b - arr[a-1] 
    arr[a-1] = b
    update(1,0,n-1,a-1,diff)
반응형

BELATED ARTICLES

more