Light cleanup of Contextual Bandits code
Summary: Just a couple of renames and new comments to improve the Contextual Bandits code.

Reviewed By: yiwan-rl

Differential Revision: D55908833

fbshipit-source-id: f0891396e83d96041ed89452078eaa932cf3670e
rodrigodesalvobraz authored and facebook-github-bot committed Apr 11, 2024
1 parent eb21bb0 · commit b0f419b
Showing 4 changed files with 30 additions and 8 deletions.
@@ -20,7 +20,9 @@
     ExplorationModule,
     ExplorationType,
 )
-from pearl.utils.functional_utils.learning.action_utils import get_model_actions
+from pearl.utils.functional_utils.learning.action_utils import (
+    get_model_action_index_batch,
+)
 from pearl.utils.tensor_like import assert_is_tensor_like
 
 
@@ -66,8 +68,27 @@ def act(
             representation=representation,
         )  # shape: (batch_size, action_count)
         scores = assert_is_tensor_like(scores)
-        selected_action = get_model_actions(scores, action_availability_mask)
-        return selected_action.squeeze(-1)
+        action_index_batch = get_model_action_index_batch(
+            scores, action_availability_mask
+        )
+        return action_index_batch.squeeze(-1)
+        # FIXME: the squeeze(-1) is a hack.
+        # It is used to get rid of the batch dimension if the batch has a
+        # single element. For example, if action_index_batch is
+        # torch.tensor([0]), then the result will be the batch-less index 0.
+        # The rationale is that if the batch has a single element, then
+        # subject_state was batchless and self.get_scores introduced a batch
+        # dimension (for uniformity and convenience of operations, which can
+        # then all assume batch form), so the batch dimension should be removed.
+        # The problem with this approach is that it is heuristic and not
+        # correct in all cases. For example, if subject_state is *not* batchless
+        # but has a single element, then the returned value should be a
+        # single-element batch containing one index, but in this case
+        # squeeze will incorrectly remove the batch dimension.
+        # The correct approach would be for all functions to manipulate tensors
+        # the same way PyTorch modules do, namely accepting input that
+        # may or may not have a batch dimension, and having all subsequent
+        # tensors mirror that.
 
     @abstractmethod
     def get_scores(
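
Note: the FIXME above hinges on how squeeze(-1) treats a size-1 trailing dimension. A minimal standalone sketch of the ambiguity (plain PyTorch, not Pearl code):

import torch

# Case 1: batchless subject_state; get_scores added a batch dimension.
# Here squeeze(-1) correctly strips it back off:
index = torch.tensor([0])         # shape (1,)
print(index.squeeze(-1))          # tensor(0) -- batch dim removed, as intended

# Case 2: a genuine batch that happens to contain one element.
# The result should stay a batch of one index, but squeeze(-1)
# collapses it exactly the same way -- the bug the FIXME describes:
batch_of_one = torch.tensor([5])  # shape (1,)
print(batch_of_one.squeeze(-1))   # tensor(5) -- batch dim wrongly removed

# Case 3: batches with more than one element are unaffected,
# because their trailing dimension is not of size 1:
batch = torch.tensor([0, 2, 1])   # shape (3,)
print(batch.squeeze(-1))          # tensor([0, 2, 1]) -- unchanged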
@@ -132,7 +132,8 @@ class FastCBExploration(SquareCBExploration):
     See: https://arxiv.org/abs/2107.02237 for details.
 
     Assumptions: Reward is bounded. For the update rule to be valid we require bounded rewards.
-    User can modify lower and upper bounds of the reward by setting reward_lb and reward_ub. clamp_values is set to True by default.
+    User can modify lower and upper bounds of the reward by setting reward_lb and reward_ub.
+    Clamp_values is set to True by default.
 
     Args:
         gamma (float): controls the exploration-exploitation tradeoff;
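
Note: for readers unfamiliar with reward clamping, a hypothetical sketch of what clamping to [reward_lb, reward_ub] means (the attribute names come from the docstring; the code is not Pearl's implementation):

import torch

reward_lb, reward_ub = 0.0, 1.0           # assumed bounds
rewards = torch.tensor([-0.5, 0.3, 1.7])  # raw rewards, possibly out of range
clamped = torch.clamp(rewards, min=reward_lb, max=reward_ub)
print(clamped)                            # tensor([0.0000, 0.3000, 1.0000])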
pearl/utils/functional_utils/learning/action_utils.py (2 changes: 1 addition & 1 deletion)
@@ -58,7 +58,7 @@ def argmax_random_tie_breaks(
     return argmax_indices
 
 
-def get_model_actions(
+def get_model_action_index_batch(
     scores: Tensor,
     mask: Optional[Tensor] = None,
     randomize_ties: bool = False,
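
Note: the renamed helper selects one action index per batch row. A sketch of the assumed masked-argmax semantics (inferred from the tests below; not the actual body of get_model_action_index_batch):

import torch

def masked_argmax(scores: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # set unavailable actions to -inf so argmax ignores them
    masked_scores = scores.masked_fill(~mask, float("-inf"))
    return masked_scores.argmax(dim=-1)  # shape: (batch_size,)

scores = torch.tensor([[0.1, 0.9, 0.8],
                       [0.5, 0.2, 0.7]])
mask = torch.tensor([[False, True, True],   # action 0 unavailable in row 0
                     [True,  True, True]])
print(masked_argmax(scores, mask))          # tensor([1, 2])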
test/unit/with_pytorch/test_action_utils.py (6 changes: 3 additions & 3 deletions)
@@ -12,7 +12,7 @@
 import torch
 from pearl.utils.functional_utils.learning.action_utils import (
     argmax_random_tie_breaks,
-    get_model_actions,
+    get_model_action_index_batch,
 )
 
 
@@ -77,7 +77,7 @@ def test_get_model_actions_randomize(self) -> None:
         argmax_values_returned = {0: set(), 1: set(), 2: set(), 3: set()}
         for _ in range(1000):
             # repeat many times since the function is stochastic
-            argmax = get_model_actions(scores, mask, randomize_ties=True)
+            argmax = get_model_action_index_batch(scores, mask, randomize_ties=True)
             # make sure argmax returns one of the max element indices
             argmax_values_returned[0].add(argmax[0].item())
             argmax_values_returned[1].add(argmax[1].item())
@@ -111,7 +111,7 @@ def test_get_model_actions_not_randomize(self) -> None:
         argmax_values_returned = {0: set(), 1: set(), 2: set(), 3: set()}
         for _ in range(1000):
             # repeat many times since the function is stochastic
-            argmax = get_model_actions(scores, mask, randomize_ties=False)
+            argmax = get_model_action_index_batch(scores, mask, randomize_ties=False)
             # make sure argmax returns one of the max element indices
             argmax_values_returned[0].add(argmax[0].item())
             argmax_values_returned[1].add(argmax[1].item())
