Avoid NeuralLinearBandit failure if zero-weight batch
Summary:
This diff addresses an issue with the NeuralLinearBandit policy learner in Pearl. The issue is that the policy learner can fail if the batch weight sums to zero. (For the specific div0 see [link](https://www.internalfb.com/code/fbsource/[266f3e6083d66d91a742991b0bfceadc148082a6]/fbcode/pearl/policy_learners/contextual_bandits/neural_linear_bandit.py?lines=184)) To circumvent this issue, I've modified the NeuralLinearBandit policy learner to avoid the optimization path altogether in this case.

While in principle one might hope that zero-weight batches are never provided to the bandit, they are relatively common with small batches + DisjointBanditContainer, at least for the current production model (https://www.internalfb.com/mlhub/pipelines/runs/fblearner/556370776). The result is a training flow that regularly fails during hyperparameter tuning.

The resulting NaNs eventually raise an error in the `pinv` (SVD) call within the LinearUCB component of `NeuralLinearBandit` ("the input matrix contained non-finite values"). It would be helpful if this error were thrown immediately; because it is not, the issue has taken more time to debug.
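
To show why the error surfaces far from its cause, here is a hedged sketch of how a NaN can silently accumulate into LinUCB-style sufficient statistics before the pseudo-inverse is taken; the names `A`, `b`, and `x` are illustrative and are not Pearl's internals.

```python
import torch

feature_dim = 4
A = torch.eye(feature_dim)     # running, regularized sum of x xᵀ
b = torch.zeros(feature_dim)   # running sum of reward-weighted features

# A NaN produced by the zero-weight division taints the accumulated
# statistics without raising anything at this point.
x = torch.full((feature_dim,), float("nan"))
A = A + torch.outer(x, x)
b = b + x

# The failure only shows up later, when the coefficients are re-solved via a
# pseudo-inverse (SVD) -- this is where the "non-finite values" error quoted
# above was raised in the production flow.
coefs = torch.linalg.pinv(A) @ b
```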

Reviewed By: alexnikulkov

Differential Revision: D65428925

fbshipit-source-id: 61e77345293508b4732e34d33ab6989097d00538
Alex Bird authored and facebook-github-bot committed Nov 5, 2024
1 parent ececba7 commit f01b97e
Showing 2 changed files with 58 additions and 27 deletions.
65 changes: 39 additions & 26 deletions pearl/policy_learners/contextual_bandits/neural_linear_bandit.py
@@ -166,34 +166,47 @@ def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
             else torch.ones_like(expected_values)
         )
 
-        # criterion = mae, mse, Xentropy
-        # Xentropy loss apply Sigmoid, MSE or MAE apply Identiy
-        criterion = LOSS_TYPES[self.loss_type]
-        if self.loss_type == "cross_entropy":
-            assert torch.all(expected_values >= 0) and torch.all(expected_values <= 1)
-            assert torch.all(predicted_values >= 0) and torch.all(predicted_values <= 1)
-            assert isinstance(self.model.output_activation, torch.nn.Sigmoid)
+        if batch_weight.sum().item() == 0:
+            # if all weights are zero, then there's nothing to learn, but also a
+            # division by zero. So, short circuit, and avoid the optimizer.
+            loss = torch.tensor(0.0)
+        else:
+            # criterion = mae, mse, Xentropy
+            # Xentropy loss apply Sigmoid, MSE or MAE apply Identiy
+            criterion = LOSS_TYPES[self.loss_type]
+            if self.loss_type == "cross_entropy":
+                assert torch.all(expected_values >= 0) and torch.all(
+                    expected_values <= 1
+                )
+                assert torch.all(predicted_values >= 0) and torch.all(
+                    predicted_values <= 1
+                )
+                assert isinstance(self.model.output_activation, torch.nn.Sigmoid)
 
-        # don't reduce the loss, so that we can calculate weighted loss
-        loss = criterion(
-            predicted_values.view(expected_values.shape),
-            expected_values,
-            reduction="none",
-        )
-        assert loss.shape == batch_weight.shape
-        loss = (loss * batch_weight).sum() / batch_weight.sum()  # weighted average loss
+            # don't reduce the loss, so that we can calculate weighted loss
+            loss = criterion(
+                predicted_values.view(expected_values.shape),
+                expected_values,
+                reduction="none",
+            )
+            assert loss.shape == batch_weight.shape
+            loss = (
+                loss * batch_weight
+            ).sum() / batch_weight.sum()  # weighted average loss
+
+            # Optimize the NN via backpropagation
+            self._optimizer.zero_grad()
+            loss.backward()
+            self._optimizer.step()
+
+            # Optimize linear regression
+            self.model._linear_regression_layer.learn_batch(
+                model_ret["nn_output"].detach(),
+                expected_values,
+                batch_weight,
+            )
+            self._maybe_apply_discounting()
 
-        # Optimize the NN via backpropagation
-        self._optimizer.zero_grad()
-        loss.backward()
-        self._optimizer.step()
-        # Optimize linear regression
-        self.model._linear_regression_layer.learn_batch(
-            model_ret["nn_output"].detach(),
-            expected_values,
-            batch_weight,
-        )
-        self._maybe_apply_discounting()
         predicted_values = predicted_values.detach()  # detach for logging
         return {
             "label": expected_values,
20 changes: 19 additions & 1 deletion test/unit/with_pytorch/test_neural_linear_bandits.py
@@ -92,6 +92,14 @@ def test_state_dict(self):
         ):
             tt.assert_close(p1.to(p2.device), p2, rtol=0.0, atol=0.0)
 
+    def _get_null_batch(self, feature_dim: int) -> TransitionBatch:
+        return TransitionBatch(  # cf. DisjointBanditContainer._get_null_batch
+            state=torch.zeros(1, feature_dim, dtype=torch.float),
+            action=torch.empty(1, 0, dtype=torch.float),
+            reward=torch.zeros(1, 1, dtype=torch.float),
+            weight=torch.zeros(1, 1, dtype=torch.float),
+        )
+
     # currently test support mse, mae, cross_entropy
     # separate loss_types into inddividual test cases to make it easier to debug.
     def test_neural_linucb_mse_loss(self) -> None:
@@ -150,9 +158,19 @@ def neural_linucb(
             weight=torch.ones(batch_size, 1),
         )
         losses = []
-        for _ in range(epochs):
+        for i in range(epochs):
+            if i == 1:
+                # simulate empty batch on early iter: can happen from DisjointBandit.
+                losses.append(
+                    policy_learner.learn_batch(self._get_null_batch(feature_dim))[
+                        "loss"
+                    ]
+                )
+                continue
             losses.append(policy_learner.learn_batch(batch)["loss"])
 
+        if epochs >= NUM_EPOCHS:
+            self.assertTrue(all(not torch.isnan(x) for x in losses))
         if loss_type == "mse":
             self.assertGreater(1e-1, losses[-1])
         elif loss_type == "mae":
