Support weighted loss in Pearl CB
Summary: Support weights in NN CB loss computation

Reviewed By: BerenLuthien, zxpmirror1994

Differential Revision: D53206841

fbshipit-source-id: 61ac050942b3546c9c94972dabc9997fc73c5d90
Alex Nikulkov authored and facebook-github-bot committed Jan 31, 2024
1 parent a55a3c8 commit 10a83b1
Showing 3 changed files with 17 additions and 4 deletions.
1 change: 0 additions & 1 deletion pearl/policy_learners/contextual_bandits/linear_bandit.py
@@ -70,7 +70,6 @@ def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
             else torch.ones_like(expected_values)
         )
         x = torch.cat([batch.state, batch.action], dim=1)
-        assert batch.weight is not None
         self.model.learn_batch(
             x=x,
             y=batch.reward,
10 changes: 9 additions & 1 deletion pearl/policy_learners/contextual_bandits/neural_bandit.py
@@ -103,7 +103,15 @@ def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
         predicted_values = self.model(input_features)
 
         criterion = LOSS_TYPES[self.loss_type]
-        loss = criterion(predicted_values.view(expected_values.shape), expected_values)
+
+        # don't reduce the loss, so that we can calculate weighted loss
+        loss = criterion(
+            predicted_values.view(expected_values.shape),
+            expected_values,
+            reduction="none",
+        )
+        assert loss.shape == batch_weight.shape
+        loss = (loss * batch_weight).sum() / batch_weight.sum()  # weighted average loss
 
         # Backward pass + optimizer step
         self.optimizer.zero_grad()
10 changes: 8 additions & 2 deletions pearl/policy_learners/contextual_bandits/neural_linear_bandit.py
@@ -128,8 +128,14 @@ def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
             assert torch.all(predicted_values >= 0) and torch.all(predicted_values <= 1)
             assert isinstance(self.model.output_activation, torch.nn.Sigmoid)
 
-        # TODO: handle weight in NN training by computing weighted loss
-        loss = criterion(predicted_values.view(expected_values.shape), expected_values)
+        # don't reduce the loss, so that we can calculate weighted loss
+        loss = criterion(
+            predicted_values.view(expected_values.shape),
+            expected_values,
+            reduction="none",
+        )
+        assert loss.shape == batch_weight.shape
+        loss = (loss * batch_weight).sum() / batch_weight.sum()  # weighted average loss
 
         # Optimize the NN via backpropagation
         self._optimizer.zero_grad()
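For illustration, here is a minimal standalone sketch of the weighted-average loss pattern this commit introduces in both neural bandit files (assuming PyTorch; the tensor values and the choice of mse_loss are hypothetical stand-ins for the model outputs and the configured criterion):

import torch
import torch.nn.functional as F

# Hypothetical per-sample predictions, targets, and sample weights.
predicted_values = torch.tensor([0.2, 0.7, 0.5])
expected_values = torch.tensor([0.0, 1.0, 1.0])
batch_weight = torch.tensor([1.0, 2.0, 0.5])

# reduction="none" keeps one loss value per sample instead of a scalar mean.
loss = F.mse_loss(predicted_values, expected_values, reduction="none")
assert loss.shape == batch_weight.shape

# Weighted average: higher-weight samples contribute proportionally more.
loss = (loss * batch_weight).sum() / batch_weight.sum()

Normalizing by batch_weight.sum() rather than the batch size keeps the loss on the same scale as an unweighted mean, so it reduces to the ordinary average when all weights equal 1.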
