From 8533150897358ab68fda33772febe6a06407dab7 Mon Sep 17 00:00:00 2001
From: landoskape
Date: Wed, 17 Apr 2024 08:36:06 +0100
Subject: [PATCH] finalize reward function handling

---
 dominoes/datasets/base.py | 22 +++++++++++++++++-----
 1 file changed, 17 insertions(+), 5 deletions(-)

diff --git a/dominoes/datasets/base.py b/dominoes/datasets/base.py
index 5846b50..73e8265 100644
--- a/dominoes/datasets/base.py
+++ b/dominoes/datasets/base.py
@@ -147,14 +147,21 @@ def _check_task(self, task):
 
     def _check_parameters(self, reference=None, init=False, **task_parameters):
         """
-        check if the parameters provided at initialization are valid and complete
+        check if the parameters provided in task_parameters are valid (and complete)
 
         checks two things:
         1. If any parameters are provided that are not recognized for the task, an error will be generated
+
+        ... if init=True, will additionally check:
         2. If any parameters are required for the task but not provided, an error will be generated
 
         args:
+            reference: dict, the reference parameters to check against (if not provided, uses self._required_parameters())
+            init: bool, whether this is being called by the constructor's __init__ method
+                in practice, this determines whether any required parameters without defaults are set properly
             task_parameters: dict, the parameters provided at initialization
+
+        raises ValueError if any parameters are not recognized or required parameters are not provided
         """
         if reference is None:
             reference = self._required_parameters()
@@ -317,11 +324,17 @@ def reward_function(self, choices, batch, **kwargs):
         args:
             choice: torch.Tensor, index to the choices made by the network
             batch: tuple, the batch of data generated for this training step
+            kwargs: optional kwargs for each task-specific reward function
 
         returns:
             torch.Tensor, the rewards for the network
         """
-        pass
+        if self.task == "sequencer":
+            return self._measurereward_sequencer(choices, batch, **kwargs)
+        elif self.task == "sorting":
+            return self._measurereward_sorter(choices, batch, **kwargs)
+        else:
+            raise ValueError(f"task {self.task} not recognized")
 
     @torch.no_grad()
     def _gettarget_sequencer(self, dominoes, selection, available, **prms):
@@ -391,7 +404,7 @@ def _gettarget_sorting(self, dominoes, selection, **prms):
         return dict(target=target, value=value)
 
     @torch.no_grad()
-    def _measurereward_sequencer(self, choices, batch, return_direction=False):
+    def _measurereward_sequencer(self, choices, batch, return_direction=False, verbose=None):
         """
         reward function for sequencer
 
@@ -432,7 +445,6 @@ def _measurereward_sequencer(self, choices, batch, return_direction=False):
         null_index = copy(num_in_hand)
 
         # check verbose
-        verbose = batch.get("verbose", None)
         if verbose is not None:
             debug = True
             assert 0 <= verbose < batch_size, "verbose should be an index corresponding to one of the batch elements"
@@ -554,7 +566,7 @@ def _measurereward_sequencer(self, choices, batch, return_direction=False):
         return rewards
 
     @torch.no_grad()
-    def _measurereward_sorter(self, choices, batch):
+    def _measurereward_sorter(self, choices, batch, **kwargs):
         """
         measure the reward for the sorting task
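
A minimal usage sketch of the behavior this patch finalizes: reward_function now dispatches on self.task, and the sequencer reward's verbosity is passed as a keyword argument rather than read from the batch dict. The names dataset, choices, and batch below are illustrative placeholders (not identifiers confirmed by the patch); only reward_function and its verbose kwarg come from the diff itself.

    # dispatches to _measurereward_sequencer when dataset.task == "sequencer";
    # verbose=0 requests debug output for the first element of the batch
    rewards = dataset.reward_function(choices, batch, verbose=0)

    # dispatches to _measurereward_sorter when dataset.task == "sorting";
    # any extra keyword arguments are forwarded via **kwargs
    rewards = dataset.reward_function(choices, batch)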