Commit

finalize reward function handling

landoskape committed Apr 17, 2024
1 parent 957eaf6 commit 8533150
Showing 1 changed file with 17 additions and 5 deletions.
22 changes: 17 additions & 5 deletions dominoes/datasets/base.py
@@ -147,14 +147,21 @@ def _check_task(self, task):

def _check_parameters(self, reference=None, init=False, **task_parameters):
"""
check if the parameters provided at initialization are valid and complete
check if parameters provided in the task_parameters are valid (and complete)
checks two things:
1. If any parameters are provided that are not recognized for the task, an error will be generated
... if init=True, will additionally check:
2. If any parameters are required for the task but not provided, an error will be generated
args:
reference: dict, the reference parameters to check against (if not provided, uses self._required_parameters())
init: bool, whether this is being called by the constructor's __init__ method
in practice, this determines whether any required parameters without defaults are set properly
task_parameters: dict, the parameters provided at initialization
raise ValueError if any parameters are not recognized or required parameters are not provided
"""
if reference is None:
reference = self._required_parameters()
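
The body of _check_parameters is collapsed in this view. A minimal sketch of the validation pattern the docstring describes, assuming reference maps parameter names to defaults and that a None default marks a required parameter (both assumptions, not confirmed by the diff):

# illustrative sketch only -- the real body is collapsed in the diff above
def _check_parameters(self, reference=None, init=False, **task_parameters):
    if reference is None:
        reference = self._required_parameters()
    # 1. reject any parameter that is not recognized for this task
    for param in task_parameters:
        if param not in reference:
            raise ValueError(f"parameter '{param}' is not recognized for task '{self.task}'")
    # 2. at construction time, require every parameter that has no default
    if init:
        for param, default in reference.items():
            if default is None and param not in task_parameters:
                raise ValueError(f"required parameter '{param}' was not provided")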
@@ -317,11 +324,17 @@ def reward_function(self, choices, batch, **kwargs):
args:
choices: torch.Tensor, index to the choices made by the network
batch: tuple, the batch of data generated for this training step
kwargs: optional kwargs for each task-specific reward function
returns:
torch.Tensor, the rewards for the network
"""
pass
if self.task == "sequencer":
return self._measurereward_sequencer(choices, batch, **kwargs)
elif self.task == "sorting":
return self._measurereward_sorter(choices, batch, **kwargs)
else:
raise ValueError(f"task {self.task} not recognized")
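
With this dispatch in place, a training loop can call a single entry point regardless of task. A usage sketch only; the dataset, net, and batch names below are hypothetical and not part of this commit:

# hypothetical usage sketch -- names are illustrative, not from this repository
batch = dataset.generate_batch()                    # task-specific batch for this step
choices = net(batch)                                # indices chosen by the network
rewards = dataset.reward_function(choices, batch)   # dispatches on dataset.task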

@torch.no_grad()
def _gettarget_sequencer(self, dominoes, selection, available, **prms):
@@ -391,7 +404,7 @@ def _gettarget_sorting(self, dominoes, selection, **prms):
return dict(target=target, value=value)

@torch.no_grad()
def _measurereward_sequencer(self, choices, batch, return_direction=False):
def _measurereward_sequencer(self, choices, batch, return_direction=False, verbose=None):
"""
reward function for sequencer
@@ -432,7 +445,6 @@ def _measurereward_sequencer(self, choices, batch, return_direction=False):
null_index = copy(num_in_hand)

# check verbose
verbose = batch.get("verbose", None)
if verbose is not None:
debug = True
assert 0 <= verbose < batch_size, "verbose should be an index corresponding to one of the batch elements"
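
With verbose now a keyword argument of _measurereward_sequencer rather than an entry in the batch dict, it can be forwarded through reward_function's **kwargs. A sketch, again with hypothetical dataset, choices, and batch names:

# hypothetical usage sketch: debug-print the reward computation for batch element 0
# (verbose is forwarded via **kwargs rather than read from batch["verbose"])
rewards = dataset.reward_function(choices, batch, verbose=0)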
@@ -554,7 +566,7 @@ def _measurereward_sequencer(self, choices, batch, return_direction=False):
return rewards

@torch.no_grad()
def _measurereward_sorter(self, choices, batch):
def _measurereward_sorter(self, choices, batch, **kwargs):
"""
measure the reward for the sorting task
