diff --git a/pearl/neural_networks/contextual_bandit/linear_regression.py b/pearl/neural_networks/contextual_bandit/linear_regression.py
index 318180f9..f43d4d53 100644
--- a/pearl/neural_networks/contextual_bandit/linear_regression.py
+++ b/pearl/neural_networks/contextual_bandit/linear_regression.py
@@ -20,7 +20,11 @@ class LinearRegression(MuSigmaCBModel):
     def __init__(
-        self, feature_dim: int, l2_reg_lambda: float = 1.0, gamma: float = 1.0
+        self,
+        feature_dim: int,
+        l2_reg_lambda: float = 1.0,
+        gamma: float = 1.0,
+        force_pinv: bool = False,
     ) -> None:
         """
         A linear regression model which can estimate both point prediction and uncertainty
@@ -37,10 +41,14 @@ def __init__(
             gamma: discounting multiplier (A and b get multiplied by gamma periodically,
                 the period is controlled by PolicyLearner). We use a simplified implementation of
                 https://arxiv.org/pdf/1909.09146.pdf
+            force_pinv: If True, we will always use the pseudo inverse to invert the `A`
+                matrix. If False, we will first try regular matrix inversion and fall back
+                to the pseudo inverse if it fails.
         """
         super(LinearRegression, self).__init__(feature_dim=feature_dim)
         self.gamma = gamma
         self.l2_reg_lambda = l2_reg_lambda
+        self.force_pinv = force_pinv
         assert (
             gamma > 0 and gamma <= 1
         ), f"gamma should be in (0, 1]. Got gamma={gamma} instead"
@@ -96,12 +104,26 @@ def append_ones(x: torch.Tensor) -> torch.Tensor:
         return result

     @staticmethod
-    def matrix_inv_fallback_pinv(A: torch.Tensor) -> torch.Tensor:
+    def pinv(A: torch.Tensor) -> torch.Tensor:
+        """
+        Compute the pseudo inverse of A using torch.linalg.pinv
+        """
+        # first check if A is Hermitian (symmetric A)
+        A_is_hermitian = torch.allclose(A, A.T, atol=1e-4, rtol=1e-4)
+        # applying hermitian=True saves about 50% computations
+        return torch.linalg.pinv(
+            A,
+            hermitian=A_is_hermitian,
+        ).contiguous()
+
+    def matrix_inv_fallback_pinv(self, A: torch.Tensor) -> torch.Tensor:
         """
         Try to apply regular matrix inv. If it fails, fallback to pseudo inverse
         """
+        if self.force_pinv:
+            return self.pinv(A)
         try:
-            inv_A = torch.linalg.inv(A).contiguous()
+            return torch.linalg.inv(A).contiguous()
         # pyre-ignore[16]: Module `_C` has no attribute `_LinAlgError`.
         except torch._C._LinAlgError as e:
             logger.warning(
@@ -109,14 +131,7 @@ def matrix_inv_fallback_pinv(A: torch.Tensor) -> torch.Tensor:
                 e,
             )
             # switch from `inv` to `pinv`
-            # first check if A is Hermitian (symmetric A)
-            A_is_hermitian = torch.allclose(A, A.T, atol=1e-4, rtol=1e-4)
-            # applying hermitian=True saves about 50% computations
-            inv_A = torch.linalg.pinv(
-                A,
-                hermitian=A_is_hermitian,
-            ).contiguous()
-            return inv_A
+            return self.pinv(A)

     def _validate_train_inputs(
         self, x: torch.Tensor, y: torch.Tensor, weight: Optional[torch.Tensor]
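The new `pinv` helper and the `force_pinv` short-circuit above preserve the existing inv-then-pinv fallback; both the flag and the exception handler now route through one shared code path. A minimal standalone sketch of that fallback behavior (plain PyTorch, independent of Pearl; the rank-deficient `A` is contrived to trigger the failure):

```python
import torch

# A rank-deficient "A": one observation in three dimensions with no l2
# regularization, so A is singular and torch.linalg.inv raises a LinAlgError.
x = torch.tensor([[1.0, 2.0, 3.0]])
A = x.T @ x  # rank 1

try:
    inv_A = torch.linalg.inv(A).contiguous()
except RuntimeError:  # torch._C._LinAlgError subclasses RuntimeError
    # Same check the pinv helper performs: hermitian=True is cheaper and is
    # safe here because the Gram matrix A is symmetric by construction.
    A_is_hermitian = torch.allclose(A, A.T, atol=1e-4, rtol=1e-4)
    inv_A = torch.linalg.pinv(A, hermitian=A_is_hermitian).contiguous()
```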
diff --git a/pearl/neural_networks/contextual_bandit/neural_linear_regression.py b/pearl/neural_networks/contextual_bandit/neural_linear_regression.py
index 634c76cc..a6b5fee1 100644
--- a/pearl/neural_networks/contextual_bandit/neural_linear_regression.py
+++ b/pearl/neural_networks/contextual_bandit/neural_linear_regression.py
@@ -27,6 +27,7 @@ def __init__(
         hidden_dims: List[int],  # last one is the input dim for linear regression
         l2_reg_lambda_linear: float = 1.0,
         gamma: float = 1.0,
+        force_pinv: bool = False,
         output_activation_name: str = "linear",
         use_batch_norm: bool = False,
         use_layer_norm: bool = False,
@@ -46,6 +47,9 @@ def __init__(
             hidden_dims: size of hidden layers in the network
             l2_reg_lambda_linear: L2 regularization parameter for the linear regression layer
             gamma: discounting multiplier for the linear regression layer
+            force_pinv: If True, we will always use the pseudo inverse to invert the `A`
+                matrix. If False, we will first try regular matrix inversion and fall back
+                to the pseudo inverse if it fails.
             output_activation_name: output activation function name (see ACTIVATION_MAP)
             use_batch_norm: whether to use batch normalization
             use_layer_norm: whether to use layer normalization
@@ -73,6 +77,7 @@ def __init__(
             feature_dim=hidden_dims[-1],
             l2_reg_lambda=l2_reg_lambda_linear,
             gamma=gamma,
+            force_pinv=force_pinv,
         )
         self.output_activation: nn.Module = ACTIVATION_MAP[output_activation_name]()
         self.linear_layer_e2e = nn.Linear(
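`NeuralLinearRegression` simply forwards `force_pinv` to its inner `LinearRegression` layer, so enabling it is one constructor argument in either class. A usage sketch for the plain model (argument names and the import path come from this diff; the values are made up):

```python
from pearl.neural_networks.contextual_bandit.linear_regression import (
    LinearRegression,
)

# Skip torch.linalg.inv entirely and always invert A via torch.linalg.pinv.
model = LinearRegression(
    feature_dim=10,
    l2_reg_lambda=1.0,
    gamma=0.9999,
    force_pinv=True,
)
```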
set to 0 to disable + force_pinv: bool = False, # If True, use pseudo inverse instead of regular inverse for `A` training_rounds: int = 100, batch_size: int = 128, action_representation_module: Optional[ActionRepresentationModule] = None, @@ -60,7 +61,10 @@ def __init__( action_representation_module=action_representation_module, ) self.model = LinearRegression( - feature_dim=feature_dim, l2_reg_lambda=l2_reg_lambda, gamma=gamma + feature_dim=feature_dim, + l2_reg_lambda=l2_reg_lambda, + gamma=gamma, + force_pinv=force_pinv, ) self.apply_discounting_interval = apply_discounting_interval self.last_sum_weight_when_discounted = 0.0 diff --git a/pearl/policy_learners/contextual_bandits/neural_linear_bandit.py b/pearl/policy_learners/contextual_bandits/neural_linear_bandit.py index 0bd734d3..27141acd 100644 --- a/pearl/policy_learners/contextual_bandits/neural_linear_bandit.py +++ b/pearl/policy_learners/contextual_bandits/neural_linear_bandit.py @@ -76,6 +76,7 @@ def __init__( l2_reg_lambda_linear: float = 1.0, gamma: float = 1.0, apply_discounting_interval: float = 0.0, # set to 0 to disable + force_pinv: bool = False, # If True, use pseudo inverse instead of regular inverse for `A` state_features_only: bool = True, loss_type: str = "mse", # one of the LOSS_TYPES names: [mse, mae, cross_entropy] output_activation_name: str = "linear", @@ -103,6 +104,7 @@ def __init__( hidden_dims=hidden_dims, l2_reg_lambda_linear=l2_reg_lambda_linear, gamma=gamma, + force_pinv=force_pinv, output_activation_name=output_activation_name, use_batch_norm=use_batch_norm, use_layer_norm=use_layer_norm,