From b05536bf57808d9e93cc4c9de1192f3fb6e29de2 Mon Sep 17 00:00:00 2001
From: landoskape
Date: Wed, 17 Apr 2024 08:46:15 +0100
Subject: [PATCH] adjusting process reward method

---
 dominoes/datasets/base.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/dominoes/datasets/base.py b/dominoes/datasets/base.py
index 4a4c8f9..f15ad6c 100644
--- a/dominoes/datasets/base.py
+++ b/dominoes/datasets/base.py
@@ -61,29 +61,29 @@ def create_gamma_transform(self, max_output, gamma, device=None):
         # return the gamma transform matrix
         return (gamma**exponent).to(device)
 
-    def measure_reward(self, scores, choices, batch, gamma_transform):
+    def process_reward(self, rewards, scores, choices, gamma_transform):
         """
-        measure the reward for a batch of choices
+        process the reward for performing policy gradient
 
         args:
+            rewards: list of torch.Tensor, the rewards for each network (precomputed using `reward_function`)
             scores: list of torch.Tensor, the log scores for the choices for each network
             choices: list of torch.Tensor, index to the choices made by each network
-            batch: tuple, the batch of data generated for this training step
             gamma_transform: torch.Tensor, the gamma transform matrix for the reward
 
         returns:
             list of torch.Tensor, the rewards for each network
         """
-        # measure reward for each network
-        rewards = [self.reward_function(choice, batch) for choice in choices]
-
         # measure cumulative discounted rewards for each network
         G = [torch.matmul(reward, gamma_transform) for reward in rewards]
 
+        # measure choice score for each network (the log-probability for each choice)
+        choice_scores = [torch.gather(score, 2, choice.unsqueeze(2)).squeeze(2) for score, choice in zip(scores, choices)]
+
         # measure J for each network
-        J = [-torch.sum(score * g) for score, g in zip(scores, G)]
+        J = [-torch.sum(cs * g) for cs, g in zip(choice_scores, G)]
 
-        return rewards, G, J
+        return G, J
 
     @abstractmethod
     def reward_function(self, choices, batch):
@@ -334,7 +334,7 @@ def reward_function(self, choices, batch, **kwargs):
         elif self.task == "sorting":
             return self._measurereward_sorter(choices, batch, **kwargs)
         else:
-            raise ValueError(f"task {self.task} not recognized")
+            raise ValueError(f"task {self.task} not recognized!")
 
     @torch.no_grad()
     def _gettarget_sequencer(self, dominoes, selection, available, **prms):
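
For context, the sketch below reproduces the computation the revised `process_reward` performs end to end: discounting per-step rewards with a gamma-transform matrix, gathering the log-probability of each chosen action (the new `torch.gather` step), and forming the policy-gradient objective J. Only those three list comprehensions mirror the patched method; the tensor shapes, dummy data, and the way the discount matrix is built here are illustrative assumptions, not taken from the repository.

import torch

# Illustrative sizes: 2 networks, batch of 4, 5 output steps, 5 candidate choices per step.
num_nets, batch_size, max_output = 2, 4, 5
gamma = 0.9

# Stand-in for the matrix returned by `create_gamma_transform`: a triangular discount
# matrix such that rewards @ gamma_transform gives the cumulative discounted return.
steps = torch.arange(max_output)
exponent = steps.view(-1, 1) - steps.view(1, -1)          # exponent[k, t] = k - t
gamma_transform = (gamma ** exponent) * (exponent >= 0)   # only rewards at or after step t contribute to G[t]

# Dummy per-step rewards, log-probability scores, and chosen indices for each network.
rewards = [torch.rand(batch_size, max_output) for _ in range(num_nets)]
scores = [torch.log_softmax(torch.randn(batch_size, max_output, max_output, requires_grad=True), dim=2)
          for _ in range(num_nets)]
choices = [torch.randint(max_output, (batch_size, max_output)) for _ in range(num_nets)]

# Cumulative discounted return at every step, for each network.
G = [torch.matmul(reward, gamma_transform) for reward in rewards]

# Log-probability of each chosen action (the torch.gather step added by the patch).
choice_scores = [torch.gather(score, 2, choice.unsqueeze(2)).squeeze(2) for score, choice in zip(scores, choices)]

# Policy-gradient objective: J = -sum(log pi(chosen action) * G), one scalar per network.
J = [-torch.sum(cs * g) for cs, g in zip(choice_scores, G)]

for j in J:
    j.backward()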