diff --git a/pearl/neural_networks/contextual_bandit/linear_regression.py b/pearl/neural_networks/contextual_bandit/linear_regression.py index 349c3489..a7cb0eb5 100644 --- a/pearl/neural_networks/contextual_bandit/linear_regression.py +++ b/pearl/neural_networks/contextual_bandit/linear_regression.py @@ -17,7 +17,9 @@ class LinearRegression(MuSigmaCBModel): - def __init__(self, feature_dim: int, l2_reg_lambda: float = 1.0) -> None: + def __init__( + self, feature_dim: int, l2_reg_lambda: float = 1.0, gamma: float = 1.0 + ) -> None: """ A linear regression model which can estimate both point prediction and uncertainty (standard delivation). @@ -30,12 +32,20 @@ def __init__(self, feature_dim: int, l2_reg_lambda: float = 1.0) -> None: feature_dim: number of features l2_reg_lambda: L2 regularization parameter + gamma: discounting multiplier (A and b get multiplied by gamma periodically, the period + is controlled by PolicyLearner). We use a simplified implementation of + https://arxiv.org/pdf/1909.09146.pdf """ super(LinearRegression, self).__init__(feature_dim=feature_dim) + self.gamma = gamma + self.l2_reg_lambda = l2_reg_lambda + assert ( + gamma > 0 and gamma <= 1 + ), f"gamma should be in (0, 1]. Got gamma={gamma} instead" self.register_buffer( "_A", - l2_reg_lambda * torch.eye(feature_dim + 1), # +1 for intercept - ) + torch.zeros(feature_dim + 1, feature_dim + 1), # +1 for intercept + ) # initializing as zeros. L2 regularization will be applied separately. self.register_buffer("_b", torch.zeros(feature_dim + 1)) self.register_buffer("_sum_weight", torch.zeros(1)) self.register_buffer( @@ -47,7 +57,10 @@ def __init__(self, feature_dim: int, l2_reg_lambda: float = 1.0) -> None: @property def A(self) -> torch.Tensor: - return self._A + # return A with L2 regularization applied + return self._A + self.l2_reg_lambda * torch.eye( + self._feature_dim + 1, device=self._A.device + ) @property def coefs(self) -> torch.Tensor: @@ -150,6 +163,22 @@ def learn_batch( self.calculate_coefs() # update coefs after updating A and b + def apply_discounting(self) -> None: + """ + Apply gamma (discountting multiplier) to A and b. + Gamma is <=1, so it reduces the effect of old data points and enabled the model to + "forget" old data and adjust to new data distribution in non-stationary environment + + A <- A * gamma + b <- b * gamma + """ + logger.info(f"Applying discounting at sum_weight={self._sum_weight}") + self._A *= self.gamma + self._b *= self.gamma + # don't dicount sum_weight because it's used to determine when to apply discounting + + self.calculate_coefs() # update coefs using new A and b + def forward(self, x: torch.Tensor) -> torch.Tensor: # x can be [batch_size, feature_dim] or [batch_size, num_arms, feature_dim] batch_size = x.shape[0] @@ -167,7 +196,7 @@ def calculate_coefs(self) -> None: Calculate coefficients based on current A and b. Save inverted A and coefficients in buffers. """ - self._inv_A = self.matrix_inv_fallback_pinv(self._A) + self._inv_A = self.matrix_inv_fallback_pinv(self.A) self._coefs = torch.matmul(self._inv_A, self._b) def calculate_sigma(self, x: torch.Tensor) -> torch.Tensor: @@ -182,4 +211,4 @@ def calculate_sigma(self, x: torch.Tensor) -> torch.Tensor: return sigma.reshape(batch_size, -1) def __str__(self) -> str: - return f"LinearRegression(A:\n{self._A}\nb:\n{self._b})" + return f"LinearRegression(A:\n{self.A}\nb:\n{self._b})" diff --git a/pearl/neural_networks/contextual_bandit/neural_linear_regression.py b/pearl/neural_networks/contextual_bandit/neural_linear_regression.py index 8e64068d..5c4329ca 100644 --- a/pearl/neural_networks/contextual_bandit/neural_linear_regression.py +++ b/pearl/neural_networks/contextual_bandit/neural_linear_regression.py @@ -26,6 +26,7 @@ def __init__( feature_dim: int, hidden_dims: List[int], # last one is the input dim for linear regression l2_reg_lambda_linear: float = 1.0, + gamma: float = 1.0, output_activation_name: str = "linear", use_batch_norm: bool = False, use_layer_norm: bool = False, @@ -43,6 +44,7 @@ def __init__( feature_dim: number of features hidden_dims: size of hidden layers in the network l2_reg_lambda_linear: L2 regularization parameter for the linear regression layer + gamma: discounting multiplier for the linear regression layer output_activation_name: output activation function name (see ACTIVATION_MAP) use_batch_norm: whether to use batch normalization use_layer_norm: whether to use layer normalization @@ -66,6 +68,7 @@ def __init__( self._linear_regression_layer = LinearRegression( feature_dim=hidden_dims[-1], l2_reg_lambda=l2_reg_lambda_linear, + gamma=gamma, ) self.output_activation: Union[ LeakyReLU, ReLU, Sigmoid, Softplus, Tanh, nn.Identity diff --git a/pearl/policy_learners/contextual_bandits/linear_bandit.py b/pearl/policy_learners/contextual_bandits/linear_bandit.py index 71b66c7c..8dc53256 100644 --- a/pearl/policy_learners/contextual_bandits/linear_bandit.py +++ b/pearl/policy_learners/contextual_bandits/linear_bandit.py @@ -43,6 +43,9 @@ def __init__( feature_dim: int, exploration_module: Optional[ExplorationModule] = None, l2_reg_lambda: float = 1.0, + gamma: float = 1.0, + apply_discounting_interval: float = 0.0, # discounting will be applied after this many + # observations (weighted) are processed. set to 0 to disable training_rounds: int = 100, batch_size: int = 128, action_representation_module: Optional[ActionRepresentationModule] = None, @@ -55,8 +58,26 @@ def __init__( action_representation_module=action_representation_module, ) self.model = LinearRegression( - feature_dim=feature_dim, l2_reg_lambda=l2_reg_lambda + feature_dim=feature_dim, l2_reg_lambda=l2_reg_lambda, gamma=gamma ) + self.apply_discounting_interval = apply_discounting_interval + self.last_sum_weight_when_discounted = 0.0 + + def _maybe_apply_discounting(self) -> None: + """ + Check if it's time to apply discounting and do so if it's time. + Discounting is applied after every N data points (weighted) are processed. + + `self.last_sum_weight_when_discounted` stores the data point counter when discounting was + last applied. + `self.model._sum_weight.item()` is the current data point counter + """ + if (self.apply_discounting_interval > 0) and ( + self.model._sum_weight.item() - self.last_sum_weight_when_discounted + >= self.apply_discounting_interval + ): + self.model.apply_discounting() + self.last_sum_weight_when_discounted = self.model._sum_weight.item() def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]: """ @@ -75,6 +96,7 @@ def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]: y=batch.reward, weight=batch.weight, ) + self._maybe_apply_discounting() predicted_values = self.model(x) return { "label": expected_values, diff --git a/pearl/policy_learners/contextual_bandits/neural_linear_bandit.py b/pearl/policy_learners/contextual_bandits/neural_linear_bandit.py index b47e941d..b2174550 100644 --- a/pearl/policy_learners/contextual_bandits/neural_linear_bandit.py +++ b/pearl/policy_learners/contextual_bandits/neural_linear_bandit.py @@ -62,6 +62,8 @@ def __init__( batch_size: int = 128, learning_rate: float = 0.0003, l2_reg_lambda_linear: float = 1.0, + gamma: float = 1.0, + apply_discounting_interval: float = 0.0, # set to 0 to disable state_features_only: bool = False, loss_type: str = "mse", # one of the LOSS_TYPES names: [mse, mae, cross_entropy] output_activation_name: str = "linear", @@ -86,6 +88,7 @@ def __init__( feature_dim=feature_dim, hidden_dims=hidden_dims, l2_reg_lambda_linear=l2_reg_lambda_linear, + gamma=gamma, output_activation_name=output_activation_name, use_batch_norm=use_batch_norm, use_layer_norm=use_layer_norm, @@ -99,6 +102,27 @@ def __init__( ) self._state_features_only = state_features_only self.loss_type = loss_type + self.apply_discounting_interval = apply_discounting_interval + self.last_sum_weight_when_discounted = 0.0 + + def _maybe_apply_discounting(self) -> None: + """ + Check if it's time to apply discounting and do so if it's time. + Discounting is applied after every N data points (weighted) are processed. + + `self.last_sum_weight_when_discounted` stores the data point counter when discounting was + last applied. + `self.model._linear_regression_layer._sum_weight.item()` is the current data point counter + """ + if (self.apply_discounting_interval > 0) and ( + self.model._linear_regression_layer._sum_weight.item() + - self.last_sum_weight_when_discounted + >= self.apply_discounting_interval + ): + self.model._linear_regression_layer.apply_discounting() + self.last_sum_weight_when_discounted = ( + self.model._linear_regression_layer._sum_weight.item() + ) @property def optimizer(self) -> torch.optim.Optimizer: @@ -147,6 +171,7 @@ def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]: expected_values, batch_weight, ) + self._maybe_apply_discounting() return { "label": expected_values, "prediction": predicted_values, diff --git a/test/unit/with_pytorch/test_linear_bandits.py b/test/unit/with_pytorch/test_linear_bandits.py index 8100e175..c1c63a39 100644 --- a/test/unit/with_pytorch/test_linear_bandits.py +++ b/test/unit/with_pytorch/test_linear_bandits.py @@ -156,7 +156,7 @@ def test_linear_ucb_sigma(self) -> None: # test sigma of policy_learner (LinUCB) features = torch.cat([batch.state, batch.action], dim=1) - A = policy_learner.model._A + A = policy_learner.model.A A_inv = torch.linalg.inv(A) features_with_ones = LinearRegression.append_ones(features) sigma = torch.sqrt( @@ -216,3 +216,34 @@ def test_linear_efficient_thompson_sampling_act(self) -> None: self.assertTrue( all(a in range(0, action_space.n) for a in selected_actions.tolist()) ) + + def test_discounting(self) -> None: + """ + Test discounting + """ + policy_learner = LinearBandit( + feature_dim=4, + exploration_module=UCBExploration(alpha=0), + l2_reg_lambda=1e-8, + gamma=0.95, + apply_discounting_interval=100.0, + ) + + num_reps = 100 + for _ in range(num_reps): + policy_learner.learn_batch(self.batch) + + self.assertLess( + policy_learner.model.A[0, 0].item(), + # pyre-fixme[58]: `*` is not supported for operand types `int` and + # `Union[bool, float, int]`. + # pyre-fixme[6]: For 1st argument expected `Tensor` but got + # `Optional[Tensor]`. + num_reps * torch.sum(self.batch.weight).item(), + ) + self.assertLess( + policy_learner.model._b[0].item(), + # pyre-fixme[58]: `*` is not supported for operand types `int` and + # `Union[bool, float, int]`. + num_reps * torch.sum(self.batch.reward * self.batch.weight).item(), + ) diff --git a/test/unit/with_pytorch/test_neural_linear_bandits.py b/test/unit/with_pytorch/test_neural_linear_bandits.py index e6930b25..57765d13 100644 --- a/test/unit/with_pytorch/test_neural_linear_bandits.py +++ b/test/unit/with_pytorch/test_neural_linear_bandits.py @@ -182,3 +182,47 @@ def neural_linucb( subjective_state=state, available_action_space=action_space ) self.assertEqual(action.shape, (batch_size,)) + + def test_discounting(self) -> None: + """ + Test discounting + """ + feature_dim = 10 + batch_size = 100 + policy_learner = NeuralLinearBandit( + feature_dim=feature_dim, + hidden_dims=[16, 16], + learning_rate=0.01, + exploration_module=UCBExploration(alpha=0.1), + use_skip_connections=True, + gamma=0.95, + apply_discounting_interval=100.0, + ) + state = torch.randn(batch_size, 3) + action = torch.randn(batch_size, feature_dim - 3) + batch = TransitionBatch( + state=state, + action=action, + # y = sum of state + sum of action + reward=state.sum(-1, keepdim=True) + action.sum(-1, keepdim=True), + weight=torch.ones(batch_size, 1), + ) + + num_reps = 100 + for _ in range(num_reps): + policy_learner.learn_batch(batch) + + self.assertLess( + policy_learner.model._linear_regression_layer.A[0, 0].item(), + # pyre-fixme[58]: `*` is not supported for operand types `int` and + # `Union[bool, float, int]`. + # pyre-fixme[6]: For 1st argument expected `Tensor` but got + # `Optional[Tensor]`. + num_reps * torch.sum(batch.weight).item(), + ) + self.assertLess( + policy_learner.model._linear_regression_layer._b[0].item(), + # pyre-fixme[58]: `*` is not supported for operand types `int` and + # `Union[bool, float, int]`. + num_reps * torch.sum(batch.reward * batch.weight).item(), + )