adjusting process reward method
landoskape committed Apr 17, 2024
1 parent 4e9ed50 commit b05536b
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions dominoes/datasets/base.py
@@ -61,29 +61,29 @@ def create_gamma_transform(self, max_output, gamma, device=None):
         # return the gamma transform matrix
         return (gamma**exponent).to(device)
 
-    def measure_reward(self, scores, choices, batch, gamma_transform):
+    def process_reward(self, rewards, scores, choices, gamma_transform):
         """
-        measure the reward for a batch of choices
+        process the reward for performing policy gradient
 
         args:
+            rewards: list of torch.Tensor, the rewards for each network (precomputed using `reward_function`)
            scores: list of torch.Tensor, the log scores for the choices for each network
            choices: list of torch.Tensor, index to the choices made by each network
-            batch: tuple, the batch of data generated for this training step
            gamma_transform: torch.Tensor, the gamma transform matrix for the reward
 
        returns:
            list of torch.Tensor, the rewards for each network
        """
-        # measure reward for each network
-        rewards = [self.reward_function(choice, batch) for choice in choices]
-
        # measure cumulative discounted rewards for each network
        G = [torch.matmul(reward, gamma_transform) for reward in rewards]
 
+        # measure choice score for each network (the log-probability for each choice)
+        choice_scores = [torch.gather(score, 2, choice.unsqueeze(2)).squeeze(2) for score, choice in zip(scores, choices)]
+
        # measure J for each network
-        J = [-torch.sum(score * g) for score, g in zip(scores, G)]
+        J = [-torch.sum(cs * g) for cs, g in zip(choice_scores, G)]
 
-        return rewards, G, J
+        return G, J
 
    @abstractmethod
    def reward_function(self, choices, batch):
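Note: the hunk above leans on the gamma transform matrix built by `create_gamma_transform` (its closing lines are visible as context at the top of the hunk). Below is a minimal sketch of what that matrix plausibly looks like and why `torch.matmul(reward, gamma_transform)` yields cumulative discounted returns; `make_gamma_transform` is an illustrative stand-in, not the repository's exact implementation:

    # Sketch (assumed construction, not the repo's code): entry [k, t] holds
    # gamma**(k - t) for k >= t and 0 otherwise, so that a (batch, T) reward
    # tensor matmul'd with it gives G[b, t] = sum over k >= t of gamma**(k - t) * r[b, k]
    import torch

    def make_gamma_transform(max_output: int, gamma: float) -> torch.Tensor:
        t = torch.arange(max_output)
        exponent = t.view(-1, 1) - t.view(1, -1)       # exponent[k, t] = k - t
        transform = gamma ** exponent.clamp(min=0)     # gamma**(k - t) on/below the diagonal
        return transform * (exponent >= 0)             # zero out the upper triangle

    rewards = torch.tensor([[1.0, 0.0, 2.0]])          # one episode, T = 3
    G = rewards @ make_gamma_transform(3, gamma=0.9)
    # G[0] = 1 + 0*0.9 + 2*0.81 = 2.62; G[1] = 2*0.9 = 1.8; G[2] = 2.0

With this construction, G[b, t] = sum_{k >= t} gamma**(k - t) * r[b, k], the return that the policy-gradient objective weights each choice's log-probability by.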
@@ -334,7 +334,7 @@ def reward_function(self, choices, batch, **kwargs):
        elif self.task == "sorting":
            return self._measurereward_sorter(choices, batch, **kwargs)
        else:
-            raise ValueError(f"task {self.task} not recognized")
+            raise ValueError(f"task {self.task} not recognized!")
 
    @torch.no_grad()
    def _gettarget_sequencer(self, dominoes, selection, available, **prms):
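With this commit, `process_reward` no longer calls `reward_function` itself; callers compute rewards first and pass them in. The new `choice_scores` line gathers the log-probability of each action actually taken, so `J` is the standard REINFORCE loss, -sum_t log pi(a_t) * G_t, per network. A hedged usage sketch follows; names like `task`, `batch`, and `optimizers` are illustrative, not taken from the repository:

    # illustrative training step under the new API; `task`, `scores`, `choices`,
    # `batch`, and `optimizers` are assumed names, not from this repository
    rewards = [task.reward_function(choice, batch) for choice in choices]  # caller precomputes rewards now
    G, J = task.process_reward(rewards, scores, choices, gamma_transform)
    for j, opt in zip(J, optimizers):
        opt.zero_grad()
        j.backward()   # gradient of -sum(log pi * G) is the policy gradient
        opt.step()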
