Skip to content

Commit

Permalink
CB bug fix
Browse files Browse the repository at this point in the history
Summary: `get_scores` should be called with `batch.state` (rather than `input_features`).

Reviewed By: alexnikulkov

Differential Revision: D55813029

fbshipit-source-id: ef912c4f7c789944eb26476b8e24cc6a214bd57d
  • Loading branch information
Yonathan Efroni authored and facebook-github-bot committed Apr 9, 2024
1 parent 80651b0 commit 9da5168
Showing 1 changed file with 4 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,10 @@ def optimizer(self) -> torch.optim.Optimizer:
return self._optimizer

def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:

# get scores for logging purpose
ucb_scores = self.get_scores(batch.state).mean()

if self._state_features_only:
input_features = batch.state
else:
Expand All @@ -151,9 +155,6 @@ def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
else torch.ones_like(expected_values)
)

# get scores for logging
ucb_scores = self.get_scores(input_features).mean()

# criterion = mae, mse, Xentropy
# Xentropy loss applies Sigmoid; MSE or MAE apply Identity
criterion = LOSS_TYPES[self.loss_type]
Expand Down

0 comments on commit 9da5168

Please sign in to comment.