From a8d96bbd236730b67c39f86b8f31bbb213854f79 Mon Sep 17 00:00:00 2001 From: Alex Nikulkov Date: Thu, 1 Feb 2024 13:38:57 -0800 Subject: [PATCH] Add discounting to LinUCB-based CBs in Pearl Summary: Adding a discounting multiplier to LinUCB and Neural LinUCB. This will allow the model to forget old data (in Neural LinUCB just the LinUCB layer forgets, not the NN part) by downweighting it. 2 parameters control discounting: 1. `gamma` - this is the multiplier which is applied to `A` and `b` in LinUCB every time discounting is applied 2. `apply_discounting_interval` - this parameter controls how often discounting is applied (every `apply_discounting_interval` data points). If data are weighted, the weights are taken into account (e.g. a data point with weight 10 counts as 10 data points). Support for discounting was added to: 1. `LinearBandit` and `NeuralLinearBandit` classes in Pearl 2. `ap_container` and `ads_creative_ranking` projects in APS Reviewed By: Yonathae Differential Revision: D52890963 fbshipit-source-id: 224c678a0be2362b871ca9538a0cee719a66f5d5 --- .../contextual_bandit/linear_regression.py | 41 ++++++++++++++--- .../neural_linear_regression.py | 3 ++ .../contextual_bandits/linear_bandit.py | 24 +++++++++- .../neural_linear_bandit.py | 25 +++++++++++ test/unit/with_pytorch/test_linear_bandits.py | 33 +++++++++++++- .../test_neural_linear_bandits.py | 44 +++++++++++++++++++ 6 files changed, 162 insertions(+), 8 deletions(-) diff --git a/pearl/neural_networks/contextual_bandit/linear_regression.py b/pearl/neural_networks/contextual_bandit/linear_regression.py index 349c3489..a7cb0eb5 100644 --- a/pearl/neural_networks/contextual_bandit/linear_regression.py +++ b/pearl/neural_networks/contextual_bandit/linear_regression.py @@ -17,7 +17,9 @@ class LinearRegression(MuSigmaCBModel): - def __init__(self, feature_dim: int, l2_reg_lambda: float = 1.0) -> None: + def __init__( + self, feature_dim: int, l2_reg_lambda: float = 1.0, gamma: float = 1.0 + ) -> None: """ A linear regression model which can estimate both point prediction and uncertainty (standard delivation). @@ -30,12 +32,20 @@ def __init__(self, feature_dim: int, l2_reg_lambda: float = 1.0) -> None: feature_dim: number of features l2_reg_lambda: L2 regularization parameter + gamma: discounting multiplier (A and b get multiplied by gamma periodically, the period + is controlled by PolicyLearner). We use a simplified implementation of + https://arxiv.org/pdf/1909.09146.pdf """ super(LinearRegression, self).__init__(feature_dim=feature_dim) + self.gamma = gamma + self.l2_reg_lambda = l2_reg_lambda + assert ( + gamma > 0 and gamma <= 1 + ), f"gamma should be in (0, 1]. Got gamma={gamma} instead" self.register_buffer( "_A", - l2_reg_lambda * torch.eye(feature_dim + 1), # +1 for intercept - ) + torch.zeros(feature_dim + 1, feature_dim + 1), # +1 for intercept + ) # initializing as zeros. L2 regularization will be applied separately. self.register_buffer("_b", torch.zeros(feature_dim + 1)) self.register_buffer("_sum_weight", torch.zeros(1)) self.register_buffer( @@ -47,7 +57,10 @@ def __init__(self, feature_dim: int, l2_reg_lambda: float = 1.0) -> None: @property def A(self) -> torch.Tensor: - return self._A + # return A with L2 regularization applied + return self._A + self.l2_reg_lambda * torch.eye( + self._feature_dim + 1, device=self._A.device + ) @property def coefs(self) -> torch.Tensor: @@ -150,6 +163,22 @@ def learn_batch( self.calculate_coefs() # update coefs after updating A and b + def apply_discounting(self) -> None: + """ + Apply gamma (discountting multiplier) to A and b. + Gamma is <=1, so it reduces the effect of old data points and enabled the model to + "forget" old data and adjust to new data distribution in non-stationary environment + + A <- A * gamma + b <- b * gamma + """ + logger.info(f"Applying discounting at sum_weight={self._sum_weight}") + self._A *= self.gamma + self._b *= self.gamma + # don't dicount sum_weight because it's used to determine when to apply discounting + + self.calculate_coefs() # update coefs using new A and b + def forward(self, x: torch.Tensor) -> torch.Tensor: # x can be [batch_size, feature_dim] or [batch_size, num_arms, feature_dim] batch_size = x.shape[0] @@ -167,7 +196,7 @@ def calculate_coefs(self) -> None: Calculate coefficients based on current A and b. Save inverted A and coefficients in buffers. """ - self._inv_A = self.matrix_inv_fallback_pinv(self._A) + self._inv_A = self.matrix_inv_fallback_pinv(self.A) self._coefs = torch.matmul(self._inv_A, self._b) def calculate_sigma(self, x: torch.Tensor) -> torch.Tensor: @@ -182,4 +211,4 @@ def calculate_sigma(self, x: torch.Tensor) -> torch.Tensor: return sigma.reshape(batch_size, -1) def __str__(self) -> str: - return f"LinearRegression(A:\n{self._A}\nb:\n{self._b})" + return f"LinearRegression(A:\n{self.A}\nb:\n{self._b})" diff --git a/pearl/neural_networks/contextual_bandit/neural_linear_regression.py b/pearl/neural_networks/contextual_bandit/neural_linear_regression.py index 8e64068d..5c4329ca 100644 --- a/pearl/neural_networks/contextual_bandit/neural_linear_regression.py +++ b/pearl/neural_networks/contextual_bandit/neural_linear_regression.py @@ -26,6 +26,7 @@ def __init__( feature_dim: int, hidden_dims: List[int], # last one is the input dim for linear regression l2_reg_lambda_linear: float = 1.0, + gamma: float = 1.0, output_activation_name: str = "linear", use_batch_norm: bool = False, use_layer_norm: bool = False, @@ -43,6 +44,7 @@ def __init__( feature_dim: number of features hidden_dims: size of hidden layers in the network l2_reg_lambda_linear: L2 regularization parameter for the linear regression layer + gamma: discounting multiplier for the linear regression layer output_activation_name: output activation function name (see ACTIVATION_MAP) use_batch_norm: whether to use batch normalization use_layer_norm: whether to use layer normalization @@ -66,6 +68,7 @@ def __init__( self._linear_regression_layer = LinearRegression( feature_dim=hidden_dims[-1], l2_reg_lambda=l2_reg_lambda_linear, + gamma=gamma, ) self.output_activation: Union[ LeakyReLU, ReLU, Sigmoid, Softplus, Tanh, nn.Identity diff --git a/pearl/policy_learners/contextual_bandits/linear_bandit.py b/pearl/policy_learners/contextual_bandits/linear_bandit.py index 71b66c7c..8dc53256 100644 --- a/pearl/policy_learners/contextual_bandits/linear_bandit.py +++ b/pearl/policy_learners/contextual_bandits/linear_bandit.py @@ -43,6 +43,9 @@ def __init__( feature_dim: int, exploration_module: Optional[ExplorationModule] = None, l2_reg_lambda: float = 1.0, + gamma: float = 1.0, + apply_discounting_interval: float = 0.0, # discounting will be applied after this many + # observations (weighted) are processed. set to 0 to disable training_rounds: int = 100, batch_size: int = 128, action_representation_module: Optional[ActionRepresentationModule] = None, @@ -55,8 +58,26 @@ def __init__( action_representation_module=action_representation_module, ) self.model = LinearRegression( - feature_dim=feature_dim, l2_reg_lambda=l2_reg_lambda + feature_dim=feature_dim, l2_reg_lambda=l2_reg_lambda, gamma=gamma ) + self.apply_discounting_interval = apply_discounting_interval + self.last_sum_weight_when_discounted = 0.0 + + def _maybe_apply_discounting(self) -> None: + """ + Check if it's time to apply discounting and do so if it's time. + Discounting is applied after every N data points (weighted) are processed. + + `self.last_sum_weight_when_discounted` stores the data point counter when discounting was + last applied. + `self.model._sum_weight.item()` is the current data point counter + """ + if (self.apply_discounting_interval > 0) and ( + self.model._sum_weight.item() - self.last_sum_weight_when_discounted + >= self.apply_discounting_interval + ): + self.model.apply_discounting() + self.last_sum_weight_when_discounted = self.model._sum_weight.item() def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]: """ @@ -75,6 +96,7 @@ def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]: y=batch.reward, weight=batch.weight, ) + self._maybe_apply_discounting() predicted_values = self.model(x) return { "label": expected_values, diff --git a/pearl/policy_learners/contextual_bandits/neural_linear_bandit.py b/pearl/policy_learners/contextual_bandits/neural_linear_bandit.py index b47e941d..b2174550 100644 --- a/pearl/policy_learners/contextual_bandits/neural_linear_bandit.py +++ b/pearl/policy_learners/contextual_bandits/neural_linear_bandit.py @@ -62,6 +62,8 @@ def __init__( batch_size: int = 128, learning_rate: float = 0.0003, l2_reg_lambda_linear: float = 1.0, + gamma: float = 1.0, + apply_discounting_interval: float = 0.0, # set to 0 to disable state_features_only: bool = False, loss_type: str = "mse", # one of the LOSS_TYPES names: [mse, mae, cross_entropy] output_activation_name: str = "linear", @@ -86,6 +88,7 @@ def __init__( feature_dim=feature_dim, hidden_dims=hidden_dims, l2_reg_lambda_linear=l2_reg_lambda_linear, + gamma=gamma, output_activation_name=output_activation_name, use_batch_norm=use_batch_norm, use_layer_norm=use_layer_norm, @@ -99,6 +102,27 @@ def __init__( ) self._state_features_only = state_features_only self.loss_type = loss_type + self.apply_discounting_interval = apply_discounting_interval + self.last_sum_weight_when_discounted = 0.0 + + def _maybe_apply_discounting(self) -> None: + """ + Check if it's time to apply discounting and do so if it's time. + Discounting is applied after every N data points (weighted) are processed. + + `self.last_sum_weight_when_discounted` stores the data point counter when discounting was + last applied. + `self.model._linear_regression_layer._sum_weight.item()` is the current data point counter + """ + if (self.apply_discounting_interval > 0) and ( + self.model._linear_regression_layer._sum_weight.item() + - self.last_sum_weight_when_discounted + >= self.apply_discounting_interval + ): + self.model._linear_regression_layer.apply_discounting() + self.last_sum_weight_when_discounted = ( + self.model._linear_regression_layer._sum_weight.item() + ) @property def optimizer(self) -> torch.optim.Optimizer: @@ -147,6 +171,7 @@ def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]: expected_values, batch_weight, ) + self._maybe_apply_discounting() return { "label": expected_values, "prediction": predicted_values, diff --git a/test/unit/with_pytorch/test_linear_bandits.py b/test/unit/with_pytorch/test_linear_bandits.py index 8100e175..c1c63a39 100644 --- a/test/unit/with_pytorch/test_linear_bandits.py +++ b/test/unit/with_pytorch/test_linear_bandits.py @@ -156,7 +156,7 @@ def test_linear_ucb_sigma(self) -> None: # test sigma of policy_learner (LinUCB) features = torch.cat([batch.state, batch.action], dim=1) - A = policy_learner.model._A + A = policy_learner.model.A A_inv = torch.linalg.inv(A) features_with_ones = LinearRegression.append_ones(features) sigma = torch.sqrt( @@ -216,3 +216,34 @@ def test_linear_efficient_thompson_sampling_act(self) -> None: self.assertTrue( all(a in range(0, action_space.n) for a in selected_actions.tolist()) ) + + def test_discounting(self) -> None: + """ + Test discounting + """ + policy_learner = LinearBandit( + feature_dim=4, + exploration_module=UCBExploration(alpha=0), + l2_reg_lambda=1e-8, + gamma=0.95, + apply_discounting_interval=100.0, + ) + + num_reps = 100 + for _ in range(num_reps): + policy_learner.learn_batch(self.batch) + + self.assertLess( + policy_learner.model.A[0, 0].item(), + # pyre-fixme[58]: `*` is not supported for operand types `int` and + # `Union[bool, float, int]`. + # pyre-fixme[6]: For 1st argument expected `Tensor` but got + # `Optional[Tensor]`. + num_reps * torch.sum(self.batch.weight).item(), + ) + self.assertLess( + policy_learner.model._b[0].item(), + # pyre-fixme[58]: `*` is not supported for operand types `int` and + # `Union[bool, float, int]`. + num_reps * torch.sum(self.batch.reward * self.batch.weight).item(), + ) diff --git a/test/unit/with_pytorch/test_neural_linear_bandits.py b/test/unit/with_pytorch/test_neural_linear_bandits.py index e6930b25..57765d13 100644 --- a/test/unit/with_pytorch/test_neural_linear_bandits.py +++ b/test/unit/with_pytorch/test_neural_linear_bandits.py @@ -182,3 +182,47 @@ def neural_linucb( subjective_state=state, available_action_space=action_space ) self.assertEqual(action.shape, (batch_size,)) + + def test_discounting(self) -> None: + """ + Test discounting + """ + feature_dim = 10 + batch_size = 100 + policy_learner = NeuralLinearBandit( + feature_dim=feature_dim, + hidden_dims=[16, 16], + learning_rate=0.01, + exploration_module=UCBExploration(alpha=0.1), + use_skip_connections=True, + gamma=0.95, + apply_discounting_interval=100.0, + ) + state = torch.randn(batch_size, 3) + action = torch.randn(batch_size, feature_dim - 3) + batch = TransitionBatch( + state=state, + action=action, + # y = sum of state + sum of action + reward=state.sum(-1, keepdim=True) + action.sum(-1, keepdim=True), + weight=torch.ones(batch_size, 1), + ) + + num_reps = 100 + for _ in range(num_reps): + policy_learner.learn_batch(batch) + + self.assertLess( + policy_learner.model._linear_regression_layer.A[0, 0].item(), + # pyre-fixme[58]: `*` is not supported for operand types `int` and + # `Union[bool, float, int]`. + # pyre-fixme[6]: For 1st argument expected `Tensor` but got + # `Optional[Tensor]`. + num_reps * torch.sum(batch.weight).item(), + ) + self.assertLess( + policy_learner.model._linear_regression_layer._b[0].item(), + # pyre-fixme[58]: `*` is not supported for operand types `int` and + # `Union[bool, float, int]`. + num_reps * torch.sum(batch.reward * batch.weight).item(), + )