Add discounting to LinUCB-based CBs in Pearl
Summary:
This diff adds a discounting multiplier to LinUCB and Neural LinUCB, allowing the model to forget old data by downweighting it (in Neural LinUCB only the LinUCB layer forgets, not the NN part). Two parameters control discounting (see the usage sketch at the end of this summary):
1. `gamma` - the multiplier applied to `A` and `b` in LinUCB each time discounting is applied
2. `apply_discounting_interval` - controls how often discounting is applied (every `apply_discounting_interval` data points). If the data are weighted, the weights are taken into account (e.g. a data point with weight 10 counts as 10 data points).

Support for discounting was added to:
1. `LinearBandit` and `NeuralLinearBandit` classes in Pearl
2. `ap_container` and `ads_creative_ranking` projects in APS
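
A minimal usage sketch, mirroring the unit test added in this diff (the `UCBExploration` import path and the parameter values are illustrative assumptions and may differ across Pearl versions):

from pearl.policy_learners.contextual_bandits.linear_bandit import LinearBandit
from pearl.policy_learners.exploration_modules.contextual_bandits.ucb_exploration import (
    UCBExploration,
)

# gamma=0.95 shrinks A and b by 5% at every discounting step;
# apply_discounting_interval=100 triggers a step once 100 weighted observations
# have been processed since the last step (0 disables discounting).
policy_learner = LinearBandit(
    feature_dim=4,
    exploration_module=UCBExploration(alpha=0.1),
    gamma=0.95,
    apply_discounting_interval=100.0,
)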

Reviewed By: Yonathae

Differential Revision: D52890963

fbshipit-source-id: 224c678a0be2362b871ca9538a0cee719a66f5d5
Alex Nikulkov authored and facebook-github-bot committed Feb 1, 2024
1 parent 322992e commit a8d96bb
Showing 6 changed files with 162 additions and 8 deletions.
41 changes: 35 additions & 6 deletions pearl/neural_networks/contextual_bandit/linear_regression.py
@@ -17,7 +17,9 @@


class LinearRegression(MuSigmaCBModel):
def __init__(self, feature_dim: int, l2_reg_lambda: float = 1.0) -> None:
def __init__(
self, feature_dim: int, l2_reg_lambda: float = 1.0, gamma: float = 1.0
) -> None:
"""
A linear regression model which can estimate both point prediction and uncertainty
(standard deviation).
@@ -30,12 +32,20 @@ def __init__(self, feature_dim: int, l2_reg_lambda: float = 1.0) -> None:
feature_dim: number of features
l2_reg_lambda: L2 regularization parameter
gamma: discounting multiplier (A and b get multiplied by gamma periodically, the period
is controlled by PolicyLearner). We use a simplified implementation of
https://arxiv.org/pdf/1909.09146.pdf
"""
super(LinearRegression, self).__init__(feature_dim=feature_dim)
self.gamma = gamma
self.l2_reg_lambda = l2_reg_lambda
assert (
gamma > 0 and gamma <= 1
), f"gamma should be in (0, 1]. Got gamma={gamma} instead"
self.register_buffer(
"_A",
l2_reg_lambda * torch.eye(feature_dim + 1), # +1 for intercept
)
torch.zeros(feature_dim + 1, feature_dim + 1), # +1 for intercept
) # initializing as zeros. L2 regularization will be applied separately.
self.register_buffer("_b", torch.zeros(feature_dim + 1))
self.register_buffer("_sum_weight", torch.zeros(1))
self.register_buffer(
@@ -47,7 +57,10 @@ def __init__(self, feature_dim: int, l2_reg_lambda: float = 1.0) -> None:

@property
def A(self) -> torch.Tensor:
return self._A
# return A with L2 regularization applied
return self._A + self.l2_reg_lambda * torch.eye(
self._feature_dim + 1, device=self._A.device
)

@property
def coefs(self) -> torch.Tensor:
@@ -150,6 +163,22 @@ def learn_batch(

self.calculate_coefs() # update coefs after updating A and b

def apply_discounting(self) -> None:
"""
Apply gamma (discounting multiplier) to A and b.
Gamma is <= 1, so it reduces the effect of old data points and enables the model to
"forget" old data and adjust to a new data distribution in a non-stationary environment.
A <- A * gamma
b <- b * gamma
"""
logger.info(f"Applying discounting at sum_weight={self._sum_weight}")
self._A *= self.gamma
self._b *= self.gamma
# don't discount sum_weight because it's used to determine when to apply discounting

self.calculate_coefs() # update coefs using new A and b

def forward(self, x: torch.Tensor) -> torch.Tensor:
# x can be [batch_size, feature_dim] or [batch_size, num_arms, feature_dim]
batch_size = x.shape[0]
@@ -167,7 +196,7 @@ def calculate_coefs(self) -> None:
Calculate coefficients based on current A and b.
Save inverted A and coefficients in buffers.
"""
self._inv_A = self.matrix_inv_fallback_pinv(self._A)
self._inv_A = self.matrix_inv_fallback_pinv(self.A)
self._coefs = torch.matmul(self._inv_A, self._b)

def calculate_sigma(self, x: torch.Tensor) -> torch.Tensor:
@@ -182,4 +211,4 @@ def calculate_sigma(self, x: torch.Tensor) -> torch.Tensor:
return sigma.reshape(batch_size, -1)

def __str__(self) -> str:
return f"LinearRegression(A:\n{self._A}\nb:\n{self._b})"
return f"LinearRegression(A:\n{self.A}\nb:\n{self._b})"
@@ -26,6 +26,7 @@ def __init__(
feature_dim: int,
hidden_dims: List[int], # last one is the input dim for linear regression
l2_reg_lambda_linear: float = 1.0,
gamma: float = 1.0,
output_activation_name: str = "linear",
use_batch_norm: bool = False,
use_layer_norm: bool = False,
@@ -43,6 +44,7 @@ def __init__(
feature_dim: number of features
hidden_dims: size of hidden layers in the network
l2_reg_lambda_linear: L2 regularization parameter for the linear regression layer
gamma: discounting multiplier for the linear regression layer
output_activation_name: output activation function name (see ACTIVATION_MAP)
use_batch_norm: whether to use batch normalization
use_layer_norm: whether to use layer normalization
@@ -66,6 +68,7 @@ def __init__(
self._linear_regression_layer = LinearRegression(
feature_dim=hidden_dims[-1],
l2_reg_lambda=l2_reg_lambda_linear,
gamma=gamma,
)
self.output_activation: Union[
LeakyReLU, ReLU, Sigmoid, Softplus, Tanh, nn.Identity
24 changes: 23 additions & 1 deletion pearl/policy_learners/contextual_bandits/linear_bandit.py
@@ -43,6 +43,9 @@ def __init__(
feature_dim: int,
exploration_module: Optional[ExplorationModule] = None,
l2_reg_lambda: float = 1.0,
gamma: float = 1.0,
apply_discounting_interval: float = 0.0, # discounting will be applied after this many
# observations (weighted) are processed. set to 0 to disable
training_rounds: int = 100,
batch_size: int = 128,
action_representation_module: Optional[ActionRepresentationModule] = None,
@@ -55,8 +58,26 @@ def __init__(
action_representation_module=action_representation_module,
)
self.model = LinearRegression(
feature_dim=feature_dim, l2_reg_lambda=l2_reg_lambda
feature_dim=feature_dim, l2_reg_lambda=l2_reg_lambda, gamma=gamma
)
self.apply_discounting_interval = apply_discounting_interval
self.last_sum_weight_when_discounted = 0.0

def _maybe_apply_discounting(self) -> None:
"""
Check whether it is time to apply discounting and, if so, apply it.
Discounting is applied after every N data points (weighted) are processed.
`self.last_sum_weight_when_discounted` stores the data point counter when discounting was
last applied.
`self.model._sum_weight.item()` is the current data point counter
"""
if (self.apply_discounting_interval > 0) and (
self.model._sum_weight.item() - self.last_sum_weight_when_discounted
>= self.apply_discounting_interval
):
self.model.apply_discounting()
self.last_sum_weight_when_discounted = self.model._sum_weight.item()

def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
"""
@@ -75,6 +96,7 @@ def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
y=batch.reward,
weight=batch.weight,
)
self._maybe_apply_discounting()
predicted_values = self.model(x)
return {
"label": expected_values,
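
A back-of-the-envelope sketch of how the two knobs interact (values assumed; this is not Pearl code): each discounting step multiplies `A` and `b` by `gamma`, so data seen k intervals ago carries weight `gamma**k`.

import math

gamma = 0.95
interval = 100.0  # weighted observations between discounting steps

# Geometric series: total weight retained from the infinite past is interval / (1 - gamma).
effective_memory = interval / (1.0 - gamma)              # = 2000 weighted observations
# Number of observations after which an old data point's weight has halved.
half_life = interval * math.log(0.5) / math.log(gamma)   # ~ 1351 weighted observations

So with the values used in the new tests (gamma=0.95, interval=100), the linear layer effectively remembers roughly its last 2,000 weighted observations.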
25 changes: 25 additions & 0 deletions pearl/policy_learners/contextual_bandits/neural_linear_bandit.py
@@ -62,6 +62,8 @@ def __init__(
batch_size: int = 128,
learning_rate: float = 0.0003,
l2_reg_lambda_linear: float = 1.0,
gamma: float = 1.0,
apply_discounting_interval: float = 0.0, # set to 0 to disable
state_features_only: bool = False,
loss_type: str = "mse", # one of the LOSS_TYPES names: [mse, mae, cross_entropy]
output_activation_name: str = "linear",
@@ -86,6 +88,7 @@ def __init__(
feature_dim=feature_dim,
hidden_dims=hidden_dims,
l2_reg_lambda_linear=l2_reg_lambda_linear,
gamma=gamma,
output_activation_name=output_activation_name,
use_batch_norm=use_batch_norm,
use_layer_norm=use_layer_norm,
@@ -99,6 +102,27 @@
)
self._state_features_only = state_features_only
self.loss_type = loss_type
self.apply_discounting_interval = apply_discounting_interval
self.last_sum_weight_when_discounted = 0.0

def _maybe_apply_discounting(self) -> None:
"""
Check whether it is time to apply discounting and, if so, apply it.
Discounting is applied after every N data points (weighted) are processed.
`self.last_sum_weight_when_discounted` stores the data point counter when discounting was
last applied.
`self.model._linear_regression_layer._sum_weight.item()` is the current data point counter
"""
if (self.apply_discounting_interval > 0) and (
self.model._linear_regression_layer._sum_weight.item()
- self.last_sum_weight_when_discounted
>= self.apply_discounting_interval
):
self.model._linear_regression_layer.apply_discounting()
self.last_sum_weight_when_discounted = (
self.model._linear_regression_layer._sum_weight.item()
)

@property
def optimizer(self) -> torch.optim.Optimizer:
@@ -147,6 +171,7 @@ def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
expected_values,
batch_weight,
)
self._maybe_apply_discounting()
return {
"label": expected_values,
"prediction": predicted_values,
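
A conceptual sketch of the split described in the summary (the class below is illustrative, not Pearl's `NeuralLinearBandit` API): the neural feature extractor keeps everything it has learned, while discounting shrinks only the LinUCB head's `A`/`b` statistics.

import torch
import torch.nn as nn

class NeuralLinUCBHeadSketch(nn.Module):
    def __init__(self, feature_dim: int, hidden_dim: int, gamma: float = 0.95) -> None:
        super().__init__()
        self.gamma = gamma
        self.body = nn.Sequential(nn.Linear(feature_dim, hidden_dim), nn.ReLU())
        self.register_buffer("A", torch.zeros(hidden_dim, hidden_dim))
        self.register_buffer("b", torch.zeros(hidden_dim))

    def learn_batch(self, x: torch.Tensor, y: torch.Tensor, w: torch.Tensor) -> None:
        z = self.body(x).detach()          # NN features; the NN itself is trained by SGD elsewhere
        self.A += (w[:, None] * z).T @ z   # only the head accumulates LinUCB statistics
        self.b += (w * y) @ z

    def apply_discounting(self) -> None:
        self.A *= self.gamma               # the head forgets old data...
        self.b *= self.gamma               # ...the network keeps its learned representation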
33 changes: 32 additions & 1 deletion test/unit/with_pytorch/test_linear_bandits.py
@@ -156,7 +156,7 @@ def test_linear_ucb_sigma(self) -> None:

# test sigma of policy_learner (LinUCB)
features = torch.cat([batch.state, batch.action], dim=1)
A = policy_learner.model._A
A = policy_learner.model.A
A_inv = torch.linalg.inv(A)
features_with_ones = LinearRegression.append_ones(features)
sigma = torch.sqrt(
@@ -216,3 +216,34 @@ def test_linear_efficient_thompson_sampling_act(self) -> None:
self.assertTrue(
all(a in range(0, action_space.n) for a in selected_actions.tolist())
)

def test_discounting(self) -> None:
"""
Test discounting
"""
policy_learner = LinearBandit(
feature_dim=4,
exploration_module=UCBExploration(alpha=0),
l2_reg_lambda=1e-8,
gamma=0.95,
apply_discounting_interval=100.0,
)

num_reps = 100
for _ in range(num_reps):
policy_learner.learn_batch(self.batch)

self.assertLess(
policy_learner.model.A[0, 0].item(),
# pyre-fixme[58]: `*` is not supported for operand types `int` and
# `Union[bool, float, int]`.
# pyre-fixme[6]: For 1st argument expected `Tensor` but got
# `Optional[Tensor]`.
num_reps * torch.sum(self.batch.weight).item(),
)
self.assertLess(
policy_learner.model._b[0].item(),
# pyre-fixme[58]: `*` is not supported for operand types `int` and
# `Union[bool, float, int]`.
num_reps * torch.sum(self.batch.reward * self.batch.weight).item(),
)
44 changes: 44 additions & 0 deletions test/unit/with_pytorch/test_neural_linear_bandits.py
@@ -182,3 +182,47 @@ def neural_linucb(
subjective_state=state, available_action_space=action_space
)
self.assertEqual(action.shape, (batch_size,))

def test_discounting(self) -> None:
"""
Test discounting
"""
feature_dim = 10
batch_size = 100
policy_learner = NeuralLinearBandit(
feature_dim=feature_dim,
hidden_dims=[16, 16],
learning_rate=0.01,
exploration_module=UCBExploration(alpha=0.1),
use_skip_connections=True,
gamma=0.95,
apply_discounting_interval=100.0,
)
state = torch.randn(batch_size, 3)
action = torch.randn(batch_size, feature_dim - 3)
batch = TransitionBatch(
state=state,
action=action,
# y = sum of state + sum of action
reward=state.sum(-1, keepdim=True) + action.sum(-1, keepdim=True),
weight=torch.ones(batch_size, 1),
)

num_reps = 100
for _ in range(num_reps):
policy_learner.learn_batch(batch)

self.assertLess(
policy_learner.model._linear_regression_layer.A[0, 0].item(),
# pyre-fixme[58]: `*` is not supported for operand types `int` and
# `Union[bool, float, int]`.
# pyre-fixme[6]: For 1st argument expected `Tensor` but got
# `Optional[Tensor]`.
num_reps * torch.sum(batch.weight).item(),
)
self.assertLess(
policy_learner.model._linear_regression_layer._b[0].item(),
# pyre-fixme[58]: `*` is not supported for operand types `int` and
# `Union[bool, float, int]`.
num_reps * torch.sum(batch.reward * batch.weight).item(),
)
