반응형

문제

소설가인 김대전은 소설을 여러 장(chapter)으로 나누어 쓰는데, 각 장은 각각 다른 파일에 저장하곤 한다. 소설의 모든 장을 쓰고 나서는 각 장이 쓰여진 파일을 합쳐서 최종적으로 소설의 완성본이 들어있는 한 개의 파일을 만든다. 이 과정에서 두 개의 파일을 합쳐서 하나의 임시파일을 만들고, 이 임시파일이나 원래의 파일을 계속 두 개씩 합쳐서 소설의 여러 장들이 연속이 되도록 파일을 합쳐나가고, 최종적으로는 하나의 파일로 합친다. 두 개의 파일을 합칠 때 필요한 비용(시간 등)이 두 파일 크기의 합이라고 가정할 때, 최종적인 한 개의 파일을 완성하는데 필요한 비용의 총 합을 계산하시오.

예를 들어, C1, C2, C3, C4가 연속적인 네 개의 장을 수록하고 있는 파일이고, 파일 크기가 각각 40, 30, 30, 50 이라고 하자. 이 파일들을 합치는 과정에서, 먼저 C2와 C3를 합쳐서 임시파일 X1을 만든다. 이때 비용 60이 필요하다. 그 다음으로 C1과 X1을 합쳐 임시파일 X2를 만들면 비용 100이 필요하다. 최종적으로 X2와 C4를 합쳐 최종파일을 만들면 비용 150이 필요하다. 따라서, 최종의 한 파일을 만드는데 필요한 비용의 합은 60+100+150=310 이다. 다른 방법으로 파일을 합치면 비용을 줄일 수 있다. 먼저 C1과 C2를 합쳐 임시파일 Y1을 만들고, C3와 C4를 합쳐 임시파일 Y2를 만들고, 최종적으로 Y1과 Y2를 합쳐 최종파일을 만들 수 있다. 이때 필요한 총 비용은 70+80+150=300 이다.

소설의 각 장들이 수록되어 있는 파일의 크기가 주어졌을 때, 이 파일들을 하나의 파일로 합칠 때 필요한 최소비용을 계산하는 프로그램을 작성하시오.

입력

프로그램은 표준 입력에서 입력 데이터를 받는다. 프로그램의 입력은 T개의 테스트 데이터로 이루어져 있는데, T는 입력의 맨 첫 줄에 주어진다.각 테스트 데이터는 두 개의 행으로 주어지는데, 첫 행에는 소설을 구성하는 장의 수를 나타내는 양의 정수 K (3 ≤ K ≤ 500)가 주어진다. 두 번째 행에는 1장부터 K장까지 수록한 파일의 크기를 나타내는 양의 정수 K개가 주어진다. 파일의 크기는 10,000을 초과하지 않는다.

출력

프로그램은 표준 출력에 출력한다. 각 테스트 데이터마다 정확히 한 행에 출력하는데, 모든 장을 합치는데 필요한 최소비용을 출력한다.

예제 입력 1

2
4
40 30 30 50
15
1 21 3 4 5 35 5 4 3 5 98 21 14 17 32

예제 출력 1

300
864

나의 풀이

소설의 페이지를 합치는 과정이다보니 인접한 페이지들만 붙일 수 있다는 것을 캐치해야한다. 그리고나서 psum으로 각각 인접의 페이지들을 더해나가는 배열을 만들어놓고 j 부터 i까지의 합들을 구해나가면 된다.


여기서 k라는 값으로 for문을 더 돌리게 되는데, 동적프로그래밍으로 그래프에서 최소거리를 구하는 문제에서처럼 k를 거쳐 가는 경우들을 모두 확인하여 최솟값을 찾는 것이었다.

따라서 점화식은 dp[j][i] = min(dp[j][i], dp[j][k] + dp[k+1][i])가 된다. 그리고 각 dp[j][i]에서 인접한 값을 더한값인 psum[i] - psum[j-1]값을 더해주면 된다.

예를들어 i가 4일때 j는 3부터 시작해 1까지 for문을 돌고 각 j에서 k가 j부터 i-1까지 돈다. i가 4이고 j가 2인 경우를 보면 2번 파일부터 4번 파일까지의 최솟값을 찾아야되는거고 그걸 찾기위해 k가 2부터 3까지 dp[2][4] = min(dp[2][4], dp[2][2] + dp[3][4])dp[2][4] = min(dp[2][4], dp[2][3] + dp[2][4])와 같이 최솟값을 유지해나간다.


아직 다이나믹 프로그래밍에 대한 공부를 더 많이 해야겠다고 느끼는 문제였다. 지금 이 풀이도 시간복잡도가 O(n^3)인데 더 줄여 O(n^2)까지 줄일 수 있다고 한다. 이것도 공부해봐야할 것 같다.


코드

# 11066번 파일 합치기
import sys

# main
t = int(sys.stdin.readline())

for _ in range(t):
    n = int(sys.stdin.readline())
    parr = [int(x) for x in sys.stdin.readline().split()]

    psum = [0] * (n+1)
    for i in range(1,n+1):
        psum[i] = psum[i-1] + parr[i-1]

    dp = [[0 for _ in range(n+1)] for _ in range(n+1)]

    for i in range(2,n+1):
        j = i - 1
        while j > 0:
            dp[j][i] = 100000000
            k = j
            while k < i:
                dp[j][i] = min(dp[j][i], dp[j][k] + dp[k+1][i])
                k += 1

            dp[j][i] += psum[i] - psum[j-1]
            j -= 1
    print(dp[1][n])
반응형