From 670856d06ebbf6485636e61de21a3df29cd4bf6f Mon Sep 17 00:00:00 2001 From: Alex Nikulkov Date: Thu, 16 May 2024 16:42:19 -0700 Subject: [PATCH] Add option to force use of pinv in LinUCB Summary: Adding an options to force pseudo-inverse to be used by default in LinUCB. This is useful for cases where the features are strongly correlated, so the design matrix is ill-conditioned. The regular inverse is faster, it should be used when matrix inversion is the bottleneck and the design matrix is well-conditioned. Reviewed By: zxpmirror1994 Differential Revision: D57462061 --- .../contextual_bandit/linear_regression.py | 34 +++++++++++++------ .../neural_linear_regression.py | 2 ++ .../contextual_bandits/linear_bandit.py | 6 +++- .../neural_linear_bandit.py | 2 ++ 4 files changed, 32 insertions(+), 12 deletions(-) diff --git a/pearl/neural_networks/contextual_bandit/linear_regression.py b/pearl/neural_networks/contextual_bandit/linear_regression.py index 318180f9..29fb8f03 100644 --- a/pearl/neural_networks/contextual_bandit/linear_regression.py +++ b/pearl/neural_networks/contextual_bandit/linear_regression.py @@ -20,7 +20,11 @@ class LinearRegression(MuSigmaCBModel): def __init__( - self, feature_dim: int, l2_reg_lambda: float = 1.0, gamma: float = 1.0 + self, + feature_dim: int, + l2_reg_lambda: float = 1.0, + gamma: float = 1.0, + force_pinv: bool = False, ) -> None: """ A linear regression model which can estimate both point prediction and uncertainty @@ -41,6 +45,7 @@ def __init__( super(LinearRegression, self).__init__(feature_dim=feature_dim) self.gamma = gamma self.l2_reg_lambda = l2_reg_lambda + self.force_pinv = force_pinv assert ( gamma > 0 and gamma <= 1 ), f"gamma should be in (0, 1]. Got gamma={gamma} instead" @@ -96,12 +101,26 @@ def append_ones(x: torch.Tensor) -> torch.Tensor: return result @staticmethod - def matrix_inv_fallback_pinv(A: torch.Tensor) -> torch.Tensor: + def pinv(A: torch.Tensor) -> torch.Tensor: + """ + Compute the pseudo inverse of A using torch.linalg.pinv + """ + # first check if A is Hermitian (symmetric A) + A_is_hermitian = torch.allclose(A, A.T, atol=1e-4, rtol=1e-4) + # applying hermitian=True saves about 50% computations + return torch.linalg.pinv( + A, + hermitian=A_is_hermitian, + ).contiguous() + + def matrix_inv_fallback_pinv(self, A: torch.Tensor) -> torch.Tensor: """ Try to apply regular matrix inv. If it fails, fallback to pseudo inverse """ + if self.force_pinv: + return self.pinv(A) try: - inv_A = torch.linalg.inv(A).contiguous() + return torch.linalg.inv(A).contiguous() # pyre-ignore[16]: Module `_C` has no attribute `_LinAlgError`. except torch._C._LinAlgError as e: logger.warning( @@ -109,14 +128,7 @@ def matrix_inv_fallback_pinv(A: torch.Tensor) -> torch.Tensor: e, ) # switch from `inv` to `pinv` - # first check if A is Hermitian (symmetric A) - A_is_hermitian = torch.allclose(A, A.T, atol=1e-4, rtol=1e-4) - # applying hermitian=True saves about 50% computations - inv_A = torch.linalg.pinv( - A, - hermitian=A_is_hermitian, - ).contiguous() - return inv_A + return self.pinv(A) def _validate_train_inputs( self, x: torch.Tensor, y: torch.Tensor, weight: Optional[torch.Tensor] diff --git a/pearl/neural_networks/contextual_bandit/neural_linear_regression.py b/pearl/neural_networks/contextual_bandit/neural_linear_regression.py index 634c76cc..8d3749c7 100644 --- a/pearl/neural_networks/contextual_bandit/neural_linear_regression.py +++ b/pearl/neural_networks/contextual_bandit/neural_linear_regression.py @@ -27,6 +27,7 @@ def __init__( hidden_dims: List[int], # last one is the input dim for linear regression l2_reg_lambda_linear: float = 1.0, gamma: float = 1.0, + force_pinv: bool = False, output_activation_name: str = "linear", use_batch_norm: bool = False, use_layer_norm: bool = False, @@ -73,6 +74,7 @@ def __init__( feature_dim=hidden_dims[-1], l2_reg_lambda=l2_reg_lambda_linear, gamma=gamma, + force_pinv=force_pinv, ) self.output_activation: nn.Module = ACTIVATION_MAP[output_activation_name]() self.linear_layer_e2e = nn.Linear( diff --git a/pearl/policy_learners/contextual_bandits/linear_bandit.py b/pearl/policy_learners/contextual_bandits/linear_bandit.py index ec57e0be..a22902fb 100644 --- a/pearl/policy_learners/contextual_bandits/linear_bandit.py +++ b/pearl/policy_learners/contextual_bandits/linear_bandit.py @@ -48,6 +48,7 @@ def __init__( gamma: float = 1.0, apply_discounting_interval: float = 0.0, # discounting will be applied after this many # observations (weighted) are processed. set to 0 to disable + force_pinv: bool = False, training_rounds: int = 100, batch_size: int = 128, action_representation_module: Optional[ActionRepresentationModule] = None, @@ -60,7 +61,10 @@ def __init__( action_representation_module=action_representation_module, ) self.model = LinearRegression( - feature_dim=feature_dim, l2_reg_lambda=l2_reg_lambda, gamma=gamma + feature_dim=feature_dim, + l2_reg_lambda=l2_reg_lambda, + gamma=gamma, + force_pinv=force_pinv, ) self.apply_discounting_interval = apply_discounting_interval self.last_sum_weight_when_discounted = 0.0 diff --git a/pearl/policy_learners/contextual_bandits/neural_linear_bandit.py b/pearl/policy_learners/contextual_bandits/neural_linear_bandit.py index 0bd734d3..61b63388 100644 --- a/pearl/policy_learners/contextual_bandits/neural_linear_bandit.py +++ b/pearl/policy_learners/contextual_bandits/neural_linear_bandit.py @@ -76,6 +76,7 @@ def __init__( l2_reg_lambda_linear: float = 1.0, gamma: float = 1.0, apply_discounting_interval: float = 0.0, # set to 0 to disable + force_pinv: bool = False, state_features_only: bool = True, loss_type: str = "mse", # one of the LOSS_TYPES names: [mse, mae, cross_entropy] output_activation_name: str = "linear", @@ -103,6 +104,7 @@ def __init__( hidden_dims=hidden_dims, l2_reg_lambda_linear=l2_reg_lambda_linear, gamma=gamma, + force_pinv=force_pinv, output_activation_name=output_activation_name, use_batch_norm=use_batch_norm, use_layer_norm=use_layer_norm,