Skip to content

Commit

Permalink
support for TSP (now on torch)
Browse files Browse the repository at this point in the history
  • Loading branch information
landoskape committed Apr 17, 2024
1 parent ec97f1c commit b95c9de
Showing 1 changed file with 39 additions and 39 deletions.
78 changes: 39 additions & 39 deletions dominoes/datasets/support.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def held_karp(dists):

# Backtrack to find full path
path = []
for i in range(n - 1):
for _ in range(n - 1):
path.append(parent)
new_bits = bits & ~(1 << parent)
_, parent = C[(bits, parent)]
Expand All @@ -258,49 +258,49 @@ def held_karp(dists):
return opt, list(reversed(path))


def make_path(input):
pos, distmat = input
closest_to_origin = np.argmin(np.sum(pos**2, axis=1))
dd = sp.spatial.distance.squareform(distmat)
_, cpath = held_karp(dd)
shift = {val: idx for idx, val in enumerate(cpath)}[closest_to_origin]
cpath = np.roll(cpath, -shift)
check_points = pos[cpath[[1, -1]]] # second point and last point - check for clockwise travel
angles = np.arctan(check_points[:, 1] / check_points[:, 0])
if angles[1] > angles[0]:
cpath = np.flip(np.roll(cpath, -1))
# finally, move it so the origin is the last location
return np.roll(cpath, -1)
def make_path(coordinates, distances, idx_init):
"""
for a set of coordinates, returns the shortest path that starts at the
initial index and ends closest to the origin, and is clockwise
args:
coordinates: (num_cities, 2) tensor of coordinates
distances: (num_cities, num_cities) tensor of distances between coordinates
idx_init: index of the initial city
def get_path(xy, dists):
"""
for batch of (batch, num_cities, 2), returns shortest path using
held-karp algorithm that ends closest to origin and is clockwise
returns:
best_path: (num_cities) tensor of the best path
"""
return [make_path(input) for input in zip(xy, dists)]
# use held_karp algorithm to get fastest path through coordinates
best_path = torch.tensor(held_karp(distances)[1], dtype=torch.long)

# shift the path so it starts at the initial index
shift = {val.item(): idx for idx, val in enumerate(best_path)}[idx_init.item()]
best_path = torch.roll(best_path, -shift)

def get_path_pool(xy, dists, threads=8):
with Pool(threads) as p:
path = list(p.map(make_path, zip(xy, dists)))
return path
# make second point in path the second closest to origin
check_points = coordinates[best_path[[1, -1]]]
check_distance = torch.sum(check_points**2, dim=1)

# flip the path such that the second point is the second closest to the origin
if check_distance[1] < check_distance[0]:
best_path = torch.flip(torch.roll(best_path, -1), dims=(0,))

def tsp_batch(batch_size, num_cities, return_target=True, return_full=False, threads=1):
"""parallelized preparation of batch, better to use 1 thread if num_cities~<10 or batch_size<=256"""
xy = np.random.random((batch_size, num_cities, 2))
dists = np.stack([sp.spatial.distance.pdist(p) for p in xy])
input = torch.tensor(xy, dtype=torch.float)
if return_target:
if threads > 1:
target = torch.tensor(np.stack(get_path_pool(xy, dists, threads)), dtype=torch.long)
else:
target = torch.tensor(np.stack(get_path(xy, dists)), dtype=torch.long)
else:
target = None
if return_full:
torch_dists = torch.stack([torch.tensor(sp.spatial.distance.squareform(d)) for d in dists])
return input, target, torch.tensor(xy), torch_dists
# finally, roll it once so the origin is the last location
best_path = torch.roll(best_path, -1)

return best_path


def get_paths(coordinates, distances, idx_init, threads=1):
"""
for batch of (batch, num_cities, 2), returns shortest path using
held-karp algorithm that ends closest to origin and is clockwise
"""
if threads > 1:
with Pool(threads) as p:
path = list(p.starmap(make_path, zip(coordinates, distances, idx_init)))
else:
return input, target
path = [make_path(coord, dist, idx) for coord, dist, idx in zip(coordinates, distances, idx_init)]

return torch.stack(path).long()

0 comments on commit b95c9de

Please sign in to comment.