반응형

문제

어떤 N개의 수가 주어져 있다. 그런데 중간에 수의 변경이 빈번히 일어나고 그 중간에 어떤 부분의 곱을 구하려 한다. 만약에 1, 2, 3, 4, 5 라는 수가 있고, 3번째 수를 6으로 바꾸고 2번째부터 5번째까지 곱을 구하라고 한다면 240을 출력하면 되는 것이다. 그리고 그 상태에서 다섯 번째 수를 2로 바꾸고 3번째부터 5번째까지 곱을 구하라고 한다면 48이 될 것이다.

입력

첫째 줄에 수의 개수 N(1 ≤ N ≤ 1,000,000)과 M(1 ≤ M ≤ 10,000), K(1 ≤ K ≤ 10,000) 가 주어진다. M은 수의 변경이 일어나는 회수이고, K는 구간의 곱을 구하는 회수이다. 그리고 둘째 줄부터 N+1번째 줄까지 N개의 수가 주어진다. 그리고 N+2번째 줄부터 N+M+K+2번째 줄까지 세 개의 정수 a,b,c가 주어지는데, a가 1인 경우 b번째 수를 c (0 ≤ c ≤ 1,000,000)로 바꾸고 a가 2인 경우에는 b부터 c까지의 곱을 구하여 출력하면 된다.

출력

첫째 줄부터 K줄에 걸쳐 구한 구간의 곱을 1,000,000,007로 나눈 나머지를 출력한다.

예제 입력 1

5 2 2
1
2
3
4
5
1 3 6
2 2 5
1 5 2
2 3 5

예제 출력 1

240
48

예제 입력 2

5 2 2
1
2
3
4
5
1 3 0
2 2 5
1 3 6
2 2 5

예제 출력 2

0
240

나의 풀이

이 문제는 2042번 구간 합 구하기와 아주 비슷한 문제로 쉽게 풀 수 있었다. 우선 이 문제를 풀 때는 따로 init을 하지 않고 update를 가지고 배열에 있는 수들을 하나씩 세그먼트 트리에 추가해나갔다. 또 구간 합 구하기와 다르게 주의할 점이 있었다. 처음에 구간 합 처럼 diff값을 두어 바꾸려는 값 // 원래의 값으로 계산을 하려했지만 이러면 원래의 값이 더 큰 경우 0이라는 값이 나오게되어 문제를 풀 수가 없었다.


그래서 update의 구조를 살짝 바꾸어 문제를 해결하였는데, 기존 구간합에서는 idx가 범위내에만 있다면 diff를 모두 더해주었지만 이 문제에서는 새로운 값을 인자로 받아 start == end일 경우에 tree[node]에 새로운 값을 넣어주었고, 그렇지 않다면 update를 왼쪽, 오른쪽 재귀를 각각 돌려주어 값들을 최신화 시킨 뒤 tree[node] = tree[node*2] * tree[node*2+1]을 해주었다. 그리고 이 과정에서 1000000007로 나눈 나머지 값을 저장하여 계산하는데에 오버플로우 발생 및 연산속도를 빠르게 해주었다.


query도 구간 합과 비슷하지만 right와 left가 범위 안에 없을 때 구간 합 문제와 다르게 곱셈의 항등원인 1을 반환해주었고 , 반환값을 줄때 +대신 *을 사용했다. 그리고 여기서도 곱셈을 하면 1000000007이 넘을수도 있으므로 모듈라 연산을 통해 나머지값을 반환해주었다.


코드

# 11505번 구간 곱 구하기
import sys
from math import *

# update
def update(node, start, end, idx, val):
    if idx < start or end < idx:
        return

    if start == end:
        tree[node] = val
        return

    mid = (start + end) // 2
    update(node*2,start,mid,idx,val)
    update(node*2+1,mid+1,end,idx,val)
    tree[node] = (tree[node*2] * tree[node*2+1]) % 1000000007

# query
def query(node, start, end, left, right):
    if right < start or end < left:
        return 1

    if left <= start and end <= right:
        return tree[node]

    mid = (start + end) // 2
    tmp = query(node*2,start,mid,left,right) * query(node*2+1,mid+1,end,left,right)
    return tmp % 1000000007

# main
n, m, k = [int(x) for x in sys.stdin.readline().split()]

h = int(ceil(log2(n)))
t_size = 1 << (h+1)

arr = []
tree = [0] * t_size

for i in range(n):
    num = int(sys.stdin.readline())
    arr.append(num)
    update(1,0,n-1,i,num)

for _ in range(m+k):
    a, b, c = [int(x) for x in sys.stdin.readline().split()]

    if a == 1:
        arr[b-1] = c
        update(1,0,n-1,b-1,c)
    elif a == 2:
        print(query(1,0,n-1,b-1,c-1))
반응형

BELATED ARTICLES

more