From b95c9de63449fe38961a905add3979c059d5cd56 Mon Sep 17 00:00:00 2001 From: landoskape Date: Wed, 17 Apr 2024 20:43:12 +0100 Subject: [PATCH] support for TSP (now on torch) --- dominoes/datasets/support.py | 78 ++++++++++++++++++------------------ 1 file changed, 39 insertions(+), 39 deletions(-) diff --git a/dominoes/datasets/support.py b/dominoes/datasets/support.py index 4ba0109..09f0964 100644 --- a/dominoes/datasets/support.py +++ b/dominoes/datasets/support.py @@ -246,7 +246,7 @@ def held_karp(dists): # Backtrack to find full path path = [] - for i in range(n - 1): + for _ in range(n - 1): path.append(parent) new_bits = bits & ~(1 << parent) _, parent = C[(bits, parent)] @@ -258,49 +258,49 @@ def held_karp(dists): return opt, list(reversed(path)) -def make_path(input): - pos, distmat = input - closest_to_origin = np.argmin(np.sum(pos**2, axis=1)) - dd = sp.spatial.distance.squareform(distmat) - _, cpath = held_karp(dd) - shift = {val: idx for idx, val in enumerate(cpath)}[closest_to_origin] - cpath = np.roll(cpath, -shift) - check_points = pos[cpath[[1, -1]]] # second point and last point - check for clockwise travel - angles = np.arctan(check_points[:, 1] / check_points[:, 0]) - if angles[1] > angles[0]: - cpath = np.flip(np.roll(cpath, -1)) - # finally, move it so the origin is the last location - return np.roll(cpath, -1) +def make_path(coordinates, distances, idx_init): + """ + for a set of coordinates, returns the shortest path that starts at the + initial index and ends closest to the origin, and is clockwise + args: + coordinates: (num_cities, 2) tensor of coordinates + distances: (num_cities, num_cities) tensor of distances between coordinates + idx_init: index of the initial city -def get_path(xy, dists): - """ - for batch of (batch, num_cities, 2), returns shortest path using - held-karp algorithm that ends closest to origin and is clockwise + returns: + best_path: (num_cities) tensor of the best path """ - return [make_path(input) for input in zip(xy, dists)] + # use held_karp algorithm to get fastest path through coordinates + best_path = torch.tensor(held_karp(distances)[1], dtype=torch.long) + # shift the path so it starts at the initial index + shift = {val.item(): idx for idx, val in enumerate(best_path)}[idx_init.item()] + best_path = torch.roll(best_path, -shift) -def get_path_pool(xy, dists, threads=8): - with Pool(threads) as p: - path = list(p.map(make_path, zip(xy, dists))) - return path + # make second point in path the second closest to origin + check_points = coordinates[best_path[[1, -1]]] + check_distance = torch.sum(check_points**2, dim=1) + # flip the path such that the second point is the second closest to the origin + if check_distance[1] < check_distance[0]: + best_path = torch.flip(torch.roll(best_path, -1), dims=(0,)) -def tsp_batch(batch_size, num_cities, return_target=True, return_full=False, threads=1): - """parallelized preparation of batch, better to use 1 thread if num_cities~<10 or batch_size<=256""" - xy = np.random.random((batch_size, num_cities, 2)) - dists = np.stack([sp.spatial.distance.pdist(p) for p in xy]) - input = torch.tensor(xy, dtype=torch.float) - if return_target: - if threads > 1: - target = torch.tensor(np.stack(get_path_pool(xy, dists, threads)), dtype=torch.long) - else: - target = torch.tensor(np.stack(get_path(xy, dists)), dtype=torch.long) - else: - target = None - if return_full: - torch_dists = torch.stack([torch.tensor(sp.spatial.distance.squareform(d)) for d in dists]) - return input, target, torch.tensor(xy), torch_dists + # finally, roll it once so the origin is the last location + best_path = torch.roll(best_path, -1) + + return best_path + + +def get_paths(coordinates, distances, idx_init, threads=1): + """ + for batch of (batch, num_cities, 2), returns shortest path using + held-karp algorithm that ends closest to origin and is clockwise + """ + if threads > 1: + with Pool(threads) as p: + path = list(p.starmap(make_path, zip(coordinates, distances, idx_init))) else: - return input, target + path = [make_path(coord, dist, idx) for coord, dist, idx in zip(coordinates, distances, idx_init)] + + return torch.stack(path).long()