Skip to content

Commit

Permalink
batch made, starting reward
Browse files Browse the repository at this point in the history
  • Loading branch information
landoskape committed Apr 17, 2024
1 parent b47f6e0 commit ec97f1c
Showing 1 changed file with 80 additions and 81 deletions.
161 changes: 80 additions & 81 deletions dominoes/datasets/tsp_dataset.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
from copy import copy
from multiprocessing import Pool, cpu_count
import torch


from .base import DatasetRL
from .support import get_paths


class TSPDataset(DatasetRL):
"""A dataset for generating traveling salesman problem environments for training and evaluation"""

def __init__(self, **parameters):
def __init__(self, device="cpu", **parameters):
"""constructor method"""
self.set_device(device)

# check parameters
self._check_parameters(init=True, **parameters)
Expand All @@ -18,36 +20,6 @@ def __init__(self, **parameters):
self.prms = self._required_parameters()
self.prms = self.parameters(**parameters)

def _check_parameters(self, reference=None, init=False, **task_parameters):
"""
check if parameters provided in the task_parameters are valid (and complete)
checks two things:
1. If any parameters are provided that are not recognized for the task, an error will be generated
... if init=True, will additionally check:
2. If any parameters are required for the task but not provided, an error will be generated
args:
reference: dict, the reference parameters to check against (if not provided, uses self._required_parameters())
init: bool, whether this is being called by the constructor's __init__ method
in practive, this determines whether any required parameters without defaults are set properly
task_parameters: dict, the parameters provided at initialization
raise ValueError if any parameters are not recognized or required parameters are not provided
"""
if reference is None:
reference = self._required_parameters()
for param in task_parameters:
if param not in reference:
raise ValueError(f"parameter {param} not recognized for task {self.task}")
# if init==True, then this is being called by the constructor's __init__ method and
# we need to check if any required parameters without defaults are set properly
if init:
for param in reference:
if param not in task_parameters and reference[param] is None:
raise ValueError(f"parameter {param} not provided for task {self.task}")

def _required_parameters(self):
"""
return the required parameters for the task. This is hard-coded here and only here,
Expand All @@ -64,84 +36,111 @@ def _required_parameters(self):
"""
# base parameters for all tasks
params = dict(
hand_size=None, # this parameter is required to be set at initialization
num_cities=None, # this parameter is required to be set at initialization
coord_dims=2,
batch_size=1,
return_target=False,
ignore_index=-1,
threads=1,
)
return params

@torch.no_grad()
def generate_batch(self, device=None, **kwargs):
    """
    Generate a batch of traveling-salesman-problem environments.

    Default parameters registered at initialization are copied and updated by
    any batch-specific kwargs (see ``_required_parameters`` for valid keys).
    Target preparation is parallelized; better to use 1 thread if
    num_cities ~< 10 or batch_size <= 256.

    args:
        device: optional, overrides the dataset's default device for the batch
        kwargs: batch-specific parameter updates (e.g. batch_size, num_cities)

    returns:
        dict with keys:
            input: (batch_size, num_cities, coord_dims) random city coordinates,
                moved to the requested device
            dists: (batch_size, num_cities, num_cities) pairwise euclidean
                distances between cities
            init_idx: (batch_size,) index of the city closest to the origin
                (arbitrary but standard choice of starting city)
            init_input: (batch_size, 1, coord_dims) coordinates of the initial
                city (will be fed to the decoder)
            target: optimal path from ``get_paths`` (only when return_target=True)
        plus every parameter value used for this batch
    """
    # get device (argument overrides the dataset default)
    device = self.get_device(device)

    # get parameters with potential updates
    prms = self.parameters(**kwargs)

    # random city coordinates in the unit hypercube
    # (renamed from `input`, which shadowed the builtin)
    coords = torch.rand((prms["batch_size"], prms["num_cities"], prms["coord_dims"]), dtype=torch.float)

    # pairwise euclidean distances between cities
    # fix: torch.cdist requires two operands; torch.cdist(coords) raises a TypeError
    dists = torch.cdist(coords, coords)

    # define initial position as closest to origin (arbitrary but standard choice)
    init_idx = torch.argmin(torch.sum(coords**2, dim=2), dim=1)

    # get representation of initial position (will be fed to decoder)
    init_input = torch.gather(coords, 1, init_idx.view(-1, 1, 1).expand(-1, -1, prms["coord_dims"]))

    # construct batch dictionary
    # NOTE(review): only "input" (and "target" below) are moved to device,
    # matching the original behavior — confirm whether dists/init_* should move too
    batch = dict(input=coords.to(device), dists=dists, init_idx=init_idx, init_input=init_input)

    # add task specific parameters to the batch dictionary
    batch.update(prms)

    # if target is requested (e.g. for SL tasks) then compute the optimal path
    if prms["return_target"]:
        batch["target"] = get_paths(coords, dists, init_idx, prms["threads"]).to(device)

    return batch

def set_target(self, **prms):
    """
    Build the target dictionary for a batch (unimplemented placeholder).

    NOTE(review): this method has no body and currently returns None.
    Presumably it should return a dict of target tensors keyed for the batch
    dictionary (it is invoked when ``return_target`` is set) — confirm the
    intended contract against the callers before relying on it.
    """

@torch.no_grad()
def reward_function(self, choices, batch, **kwargs):
    """
    Measure the rewards acquired by the choices made by a set of networks for
    the current TSP batch.

    args:
        choices: torch.Tensor, (batch_size, num_choices) city indices chosen by
            the network; the first choice is treated as the starting city
        batch: dict, the batch generated for this training step; must contain
            "num_cities" and "dists" (pairwise city distances)
        kwargs: not used, here for consistency with other dataset types

    returns:
        distance: torch.Tensor, (batch_size, num_choices) distance traveled at
            each step; the return leg back to the starting city is added to the
            final step (so distance.sum(1) is the full tour length)
        new_city: torch.Tensor, (batch_size, num_choices) +1.0 where the chosen
            city had not been visited yet, -1.0 where it was a revisit (the
            first column is always +1.0)
    """
    assert choices.ndim == 2, "choices should be a 2-d tensor of the sequence of choices for each batch element"
    num_cities = batch["num_cities"]
    # fix: original wrote `batch_size = num_choices = choices.shape`, which
    # bound the shape tuple to both names instead of unpacking it
    batch_size, num_choices = choices.shape
    device = choices.device
    # fix: original referenced an undefined `dists` local
    dists = batch["dists"]

    distance = torch.zeros((batch_size, num_choices)).to(device)
    new_city = torch.ones((batch_size, num_choices)).to(device)

    # the first choice is the starting city (fix: original had a dead,
    # conflicting initialization from batch["init_idx"] immediately overwritten)
    last_location = choices[:, 0].clone()
    src = torch.ones((batch_size, 1), dtype=torch.bool).to(device)
    # fix: visited must span the cities (scatter indices are city ids), not the
    # number of choices
    visited = torch.zeros((batch_size, num_cities), dtype=torch.bool).to(device)
    visited.scatter_(1, last_location.view(batch_size, 1), src)  # put first city in to the "visited" tensor
    for nc in range(1, num_choices):
        next_location = choices[:, nc]
        # distances from the current city to every possible next city
        c_dist_possible = torch.gather(dists, 1, last_location.view(batch_size, 1, 1).expand(-1, -1, num_cities)).squeeze(1)
        distance[:, nc] = torch.gather(c_dist_possible, 1, next_location.view(batch_size, 1)).squeeze(1)
        c_visited = torch.gather(visited, 1, next_location.view(batch_size, 1)).squeeze(1)
        visited.scatter_(1, next_location.view(batch_size, 1), src)
        new_city[c_visited, nc] = -1.0
        new_city[~c_visited, nc] = 1.0
        last_location = next_location.clone()  # update last location

    # add return step (to initial city) to the final choice
    c_dist_possible = torch.gather(dists, 1, choices[:, 0].view(batch_size, 1, 1).expand(-1, -1, num_cities)).squeeze(1)
    distance[:, -1] += torch.gather(c_dist_possible, 1, choices[:, -1].view(batch_size, 1)).squeeze(1)

    return distance, new_city


# @torch.no_grad()
Expand Down

0 comments on commit ec97f1c

Please sign in to comment.