Don't skip empty batches in disjoint CB training
Summary:
Due to asymmetrical distributed data loading, some batches might not have all arms present. This breaks distributed training because some of the all_reduce calls end up unmatched. This diff removes the skipping of empty batches, so that for each batch and each arm we make an all_reduce call (zero values are used for actions that aren't present in the batch).
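For illustration only, not part of this diff: a minimal, hypothetical sketch of the pattern this fix enforces, in which every rank issues the same number of per-arm all_reduce calls and contributes zeros for arms it has no observations for. The two-process gloo setup, the helper names, and the tensor values are assumptions made for the example, not Pearl code.

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

N_ARMS = 2


def run(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Per-rank data: rank 1 has no observations for arm 0.
    per_arm_reward_sum = {
        0: {0: torch.tensor([3.0]), 1: torch.tensor([1.0])},
        1: {1: torch.tensor([2.0])},
    }[rank]

    for arm in range(N_ARMS):
        # When the arm is absent, contribute zeros instead of skipping:
        # the reduced values are unchanged, but every rank still makes
        # the same sequence of collective calls.
        contribution = per_arm_reward_sum.get(arm, torch.tensor([0.0]))
        dist.all_reduce(contribution, op=dist.ReduceOp.SUM)
        if rank == 0:
            print(f"arm {arm}: global reward sum = {contribution.item()}")

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(run, args=(2,), nprocs=2)
```

If rank 1 skipped arm 0 instead, rank 0's arm-0 call would be paired with rank 1's arm-1 call and rank 0's arm-1 call would have no partner at all; that unmatched-collective situation is the failure mode described in the summary.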

Changes in this diff:
1. Remove the skipping of empty batches in disjoint CB. When an arm has no observations in a batch, replace its empty per-arm batch with a null batch. The null batch has a single element with zero weight and dummy state/action/reward values; see the sketch below.
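Below is a small standalone sketch of why a single zero-weight element is harmless. The linear model and weighted loss here are assumptions for the example, not Pearl's per-arm bandit update; the point is only that a zero weight zeroes out the element's contribution while the arm still runs its update path and any collectives inside it.

```python
import torch

# "Null batch": one dummy element with zero weight, mirroring the shapes above.
feature_dim = 2
model = torch.nn.Linear(feature_dim, 1)
state = torch.zeros(1, feature_dim)
reward = torch.zeros(1, 1)
weight = torch.zeros(1, 1)

pred = model(state)
loss = (weight * (pred - reward) ** 2).sum()  # weighted loss, weight == 0
loss.backward()

print(loss.item())                          # 0.0
print(model.weight.grad, model.bias.grad)   # all zeros: a no-op update
```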

Differential Revision: D53028375

fbshipit-source-id: ea282267a01d2e581b4779d8f49383e0ac7da03d
Alex Nikulkov authored and facebook-github-bot committed Jan 31, 2024
1 parent e80bfb1 commit 98f8435
Showing 2 changed files with 79 additions and 32 deletions.
86 changes: 54 additions & 32 deletions pearl/policy_learners/contextual_bandits/disjoint_bandit.py
@@ -59,6 +59,7 @@ def __init__(
         self._arm_bandits: torch.nn.ModuleList = torch.nn.ModuleList(arm_bandits)
         self._n_arms: int = len(arm_bandits)
         self._state_features_only = state_features_only
+        self._null_batch: Optional[TransitionBatch] = None

     @property
     def n_arms(self) -> int:
@@ -82,37 +83,61 @@ def _partition_batch_by_arm(self, batch: TransitionBatch) -> List[TransitionBatch]:
             # mask of observations for this arm
             # assume action indices
             mask = batch.action[:, 0] == arm
-            if batch.state.ndim == 2:
-                # shape: (batch_size, feature_size)
-                # same features for all arms
-                state = batch.state
-            elif batch.state.ndim == 3:
-                # shape: (batch_size, num_arms, feature_size)
-                # different features for each arm
-                assert (
-                    batch.state.shape[1] == self.n_arms
-                ), "For 3D state, 2nd dimension must be equal to number of arms"
-                state = batch.state[:, arm, :]
-            batches.append(
-                TransitionBatch(
-                    state=state[mask],
-                    reward=batch.reward[mask],
-                    weight=batch.weight[mask]
-                    if batch.weight is not None
-                    else torch.ones_like(mask, dtype=torch.float),
-                    # empty action features since disjoint model used
-                    # action as index of per-arm model
-                    # if arms need different features, use 3D `state` instead
-                    action=torch.empty(
-                        int(mask.sum().item()),
-                        0,
-                        dtype=torch.float,
-                        device=batch.device,
-                    ),
-                ).to(batch.device)
-            )
+            if mask.sum().item() == 0:
+                # no observations for this arm, use null batch
+                batches.append(self._get_null_batch(batch))
+            else:
+                if batch.state.ndim == 2:
+                    # shape: (batch_size, feature_size)
+                    # same features for all arms
+                    state = batch.state
+                elif batch.state.ndim == 3:
+                    # shape: (batch_size, num_arms, feature_size)
+                    # different features for each arm
+                    assert (
+                        batch.state.shape[1] == self.n_arms
+                    ), "For 3D state, 2nd dimension must be equal to number of arms"
+                    state = batch.state[:, arm, :]
+                batches.append(
+                    TransitionBatch(
+                        state=state[mask],
+                        reward=batch.reward[mask],
+                        weight=batch.weight[mask]
+                        if batch.weight is not None
+                        else torch.ones_like(mask, dtype=torch.float),
+                        # empty action features since disjoint model used
+                        # action as index of per-arm model
+                        # if arms need different features, use 3D `state` instead
+                        action=torch.empty(
+                            int(mask.sum().item()),
+                            0,
+                            dtype=torch.float,
+                            device=batch.device,
+                        ),
+                    ).to(batch.device)
+                )
         return batches

+    def _get_null_batch(self, batch: TransitionBatch) -> TransitionBatch:
+        # null batch has 1 element, but 0 weight
+        if self._null_batch is None:
+            self._null_batch = TransitionBatch(
+                state=torch.zeros(
+                    1, batch.state.shape[-1], dtype=torch.float, device=batch.device
+                ),
+                reward=torch.zeros(1, 1, dtype=torch.float, device=batch.device),
+                weight=torch.zeros(1, 1, dtype=torch.float, device=batch.device),
+                action=torch.empty(
+                    1,
+                    0,
+                    dtype=torch.float,
+                    device=batch.device,
+                ),
+            ).to(batch.device)
+        null_batch = self._null_batch
+        assert null_batch is not None
+        return null_batch

     def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
         """
         action_idx determines which of the models the observation will be routed to.
@@ -124,9 +149,6 @@ def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
         for i, (arm_bandit, arm_batch) in enumerate(
             zip(self._arm_bandits, arm_batches)
         ):
-            if len(arm_batch) == 0:
-                # skip updates if batch has no observations for this arm
-                continue
             returns.update(
                 {
                     f"arm_{i}_{k}": v
25 changes: 25 additions & 0 deletions test/unit/with_pytorch/test_disjoint_bandits.py
@@ -400,3 +400,28 @@ def test_get_scores(self) -> None:
             expected_scores.append(mus + alpha * sigmas)
         expected_scores = torch.cat(expected_scores, dim=1)
         self.assertTrue(torch.allclose(scores, expected_scores, atol=1e-1))

+    def test_learn_batch_arm_subset(self) -> None:
+        # test that learn_batch still works when the batch has a subset of arms
+
+        policy_learner = copy.deepcopy(self.policy_learner)
+
+        # action 0 is missing from the batch
+        batch = TransitionBatch(
+            state=torch.tensor(
+                [
+                    [2.0, 3.0],
+                    [1.0, 5.0],
+                    [0.5, 3.0],
+                    [1.8, 5.1],
+                ]
+            ),
+            action=torch.tensor(
+                [[1], [1], [2], [2]],
+            ),
+            reward=torch.tensor([7.0, 7.0, 7.0, 13.8]).unsqueeze(-1),
+            weight=torch.tensor([1.0, 1.0, 1.0, 1.0]).unsqueeze(-1),
+        )
+
+        # learn batch, make sure this doesn't throw an error
+        policy_learner.learn_batch(batch)
