반응형

문제

외판원 순회 문제는 영어로 Traveling Salesman problem (TSP) 라고 불리는 문제로 computer science 분야에서 가장 중요하게 취급되는 문제 중 하나이다. 여러 가지 변종 문제가 있으나, 여기서는 가장 일반적인 형태의 문제를 살펴보자.

1번부터 N번까지 번호가 매겨져 있는 도시들이 있고, 도시들 사이에는 길이 있다. (길이 없을 수도 있다) 이제 한 외판원이 어느 한 도시에서 출발해 N개의 도시를 모두 거쳐 다시 원래의 도시로 돌아오는 순회 여행 경로를 계획하려고 한다. 단, 한 번 갔던 도시로는 다시 갈 수 없다. (맨 마지막에 여행을 출발했던 도시로 돌아오는 것은 예외) 이런 여행 경로는 여러 가지가 있을 수 있는데, 가장 적은 비용을 들이는 여행 계획을 세우고자 한다.

각 도시간에 이동하는데 드는 비용은 행렬 W[i][j]형태로 주어진다. W[i][j]는 도시 i에서 도시 j로 가기 위한 비용을 나타낸다. 비용은 대칭적이지 않다. 즉, W[i][j] 는 W[j][i]와 다를 수 있다. 모든 도시간의 비용은 양의 정수이다. W[i][i]는 항상 0이다. 경우에 따라서 도시 i에서 도시 j로 갈 수 없는 경우도 있으며 이럴 경우 W[i][j]=0이라고 하자.

N과 비용 행렬이 주어졌을 때, 가장 적은 비용을 들이는 외판원의 순회 여행 경로를 구하는 프로그램을 작성하시오.

입력

첫째 줄에 도시의 수 N이 주어진다. (2 ≤ N ≤ 16) 다음 N개의 줄에는 비용 행렬이 주어진다. 각 행렬의 성분은 1,000,000 이하의 양의 정수이며, 갈 수 없는 경우는 0이 주어진다. W[i][j]는 도시 i에서 j로 가기 위한 비용을 나타낸다.

항상 순회할 수 있는 경우만 입력으로 주어진다.

출력

첫째 줄에 외판원의 순회에 필요한 최소 비용을 출력한다.

예제 입력 1

4
0 10 15 20
5 0 9 10
6 13 0 12
8 8 9 0

예제 출력 1

35

나의 풀이

이 문제는 DP와 비트마스트를 같이 사용하여 해결하는 문제였다. 비트마스크를 사용해 해당 노드에 방문을 했었는지의 여부를 판단하며 DP배열에는 각 최소비용을 메모제이션해주어야한다. 그리고 어디에서 시작하던지 순회를 하는 것이기때문에 최소비용은 모두 같다. 따라서 0번에서 시작을 해도된다.

문제를 풀기위해서 재귀를 이용하면되며 모든 노드를 방문(visited == (1<<n)-1)이라면 해당 노드에서 0번으로 갈 수 있는지를 확인하고 갈 수없다면 INF를 반환해준다. 또 각 DP배열에서 최솟값을 메모제이션하기때문에 만약 해당 DP의 원소가 이미 최솟값이 정해졌다면 그 값을 반환해주면 된다. 이 메모제이션 부분이 시간을 많이 줄여줄 수 있다.

해당 최소비용을 구할때는 현재 노드에서 접근 가능한 모든 노드를 돌면서 방문하지 않았고 0(연결 되지 않은 곳)이 아닌지를 확인하고 아니라면 현재 노드에서 접근 가능한 노드까지의 비용 + 재귀로 접근 가능한 노드부터의 최솟값의 최솟값을 구해주고 이를 DP배열에 메모제이션한다.


코드

# 외판원 순회
import sys

def tsp(cur, visited):
    if visited == (1 << n) - 1:
        if not graph[cur][0] == 0:
            return graph[cur][0]
        else:
            return sys.maxsize

    if not mask[cur][visited] == -1:
        return mask[cur][visited]

    cost = sys.maxsize
    for i in range(n):
        if not visited & (1 << i) == 0:
            continue
        if graph[cur][i] == 0:
            continue
        cost = min(cost, tsp(i, visited | (1 << i)) + graph[cur][i])
    mask[cur][visited] = cost
    return cost

n = int(sys.stdin.readline())
graph = [list(map(int, input().split())) for _ in range(n)]
mask = [[-1] * (1 << n) for _ in range(n)]

print(tsp(0,1))
반응형

BELATED ARTICLES

more