Skip to content

Commit

Permalink
Add option to force use of pinv in LinUCB (#89)
Browse files Browse the repository at this point in the history
Summary:

Adding an options to force pseudo-inverse to be used by default in LinUCB. This is useful for cases where the features are strongly correlated, so the design matrix is ill-conditioned.
The regular inverse is faster, it should be used when matrix inversion is the bottleneck and the design matrix is well-conditioned.

Differential Revision: D57462061
  • Loading branch information
Alex Nikulkov authored and facebook-github-bot committed May 17, 2024
1 parent 7fa88fc commit e4b9e6f
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 12 deletions.
37 changes: 26 additions & 11 deletions pearl/neural_networks/contextual_bandit/linear_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@

class LinearRegression(MuSigmaCBModel):
def __init__(
self, feature_dim: int, l2_reg_lambda: float = 1.0, gamma: float = 1.0
self,
feature_dim: int,
l2_reg_lambda: float = 1.0,
gamma: float = 1.0,
force_pinv: bool = False,
) -> None:
"""
A linear regression model which can estimate both point prediction and uncertainty
Expand All @@ -37,10 +41,14 @@ def __init__(
gamma: discounting multiplier (A and b get multiplied by gamma periodically, the period
is controlled by PolicyLearner). We use a simplified implementation of
https://arxiv.org/pdf/1909.09146.pdf
force_pinv: If True, we will always use pseudo inverse to invert the `A` matrix. If False,
we will first try to use regular matrix inversion. If it fails, we will fallback to
pseudo inverse.
"""
super(LinearRegression, self).__init__(feature_dim=feature_dim)
self.gamma = gamma
self.l2_reg_lambda = l2_reg_lambda
self.force_pinv = force_pinv
assert (
gamma > 0 and gamma <= 1
), f"gamma should be in (0, 1]. Got gamma={gamma} instead"
Expand Down Expand Up @@ -96,27 +104,34 @@ def append_ones(x: torch.Tensor) -> torch.Tensor:
return result

@staticmethod
def matrix_inv_fallback_pinv(A: torch.Tensor) -> torch.Tensor:
def pinv(A: torch.Tensor) -> torch.Tensor:
"""
Compute the pseudo inverse of A using torch.linalg.pinv
"""
# first check if A is Hermitian (symmetric A)
A_is_hermitian = torch.allclose(A, A.T, atol=1e-4, rtol=1e-4)
# applying hermitian=True saves about 50% computations
return torch.linalg.pinv(
A,
hermitian=A_is_hermitian,
).contiguous()

def matrix_inv_fallback_pinv(self, A: torch.Tensor) -> torch.Tensor:
"""
Try to apply regular matrix inv. If it fails, fallback to pseudo inverse
"""
if self.force_pinv:
return self.pinv(A)
try:
inv_A = torch.linalg.inv(A).contiguous()
return torch.linalg.inv(A).contiguous()
# pyre-ignore[16]: Module `_C` has no attribute `_LinAlgError`.
except torch._C._LinAlgError as e:
logger.warning(
"Exception raised during A inversion, falling back to pseudo-inverse",
e,
)
# switch from `inv` to `pinv`
# first check if A is Hermitian (symmetric A)
A_is_hermitian = torch.allclose(A, A.T, atol=1e-4, rtol=1e-4)
# applying hermitian=True saves about 50% computations
inv_A = torch.linalg.pinv(
A,
hermitian=A_is_hermitian,
).contiguous()
return inv_A
return self.pinv(A)

def _validate_train_inputs(
self, x: torch.Tensor, y: torch.Tensor, weight: Optional[torch.Tensor]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __init__(
hidden_dims: List[int], # last one is the input dim for linear regression
l2_reg_lambda_linear: float = 1.0,
gamma: float = 1.0,
force_pinv: bool = False,
output_activation_name: str = "linear",
use_batch_norm: bool = False,
use_layer_norm: bool = False,
Expand All @@ -46,6 +47,9 @@ def __init__(
hidden_dims: size of hidden layers in the network
l2_reg_lambda_linear: L2 regularization parameter for the linear regression layer
gamma: discounting multiplier for the linear regression layer
force_pinv: If True, we will always use pseudo inverse to invert the `A` matrix. If
False, we will first try to use regular matrix inversion. If it fails, we will
fallback to pseudo inverse.
output_activation_name: output activation function name (see ACTIVATION_MAP)
use_batch_norm: whether to use batch normalization
use_layer_norm: whether to use layer normalization
Expand Down Expand Up @@ -73,6 +77,7 @@ def __init__(
feature_dim=hidden_dims[-1],
l2_reg_lambda=l2_reg_lambda_linear,
gamma=gamma,
force_pinv=force_pinv,
)
self.output_activation: nn.Module = ACTIVATION_MAP[output_activation_name]()
self.linear_layer_e2e = nn.Linear(
Expand Down
6 changes: 5 additions & 1 deletion pearl/policy_learners/contextual_bandits/linear_bandit.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def __init__(
gamma: float = 1.0,
apply_discounting_interval: float = 0.0, # discounting will be applied after this many
# observations (weighted) are processed. set to 0 to disable
force_pinv: bool = False, # If True, use pseudo inverse instead of regular inverse for `A`
training_rounds: int = 100,
batch_size: int = 128,
action_representation_module: Optional[ActionRepresentationModule] = None,
Expand All @@ -60,7 +61,10 @@ def __init__(
action_representation_module=action_representation_module,
)
self.model = LinearRegression(
feature_dim=feature_dim, l2_reg_lambda=l2_reg_lambda, gamma=gamma
feature_dim=feature_dim,
l2_reg_lambda=l2_reg_lambda,
gamma=gamma,
force_pinv=force_pinv,
)
self.apply_discounting_interval = apply_discounting_interval
self.last_sum_weight_when_discounted = 0.0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def __init__(
l2_reg_lambda_linear: float = 1.0,
gamma: float = 1.0,
apply_discounting_interval: float = 0.0, # set to 0 to disable
force_pinv: bool = False, # If True, use pseudo inverse instead of regular inverse for `A`
state_features_only: bool = True,
loss_type: str = "mse", # one of the LOSS_TYPES names: [mse, mae, cross_entropy]
output_activation_name: str = "linear",
Expand Down Expand Up @@ -103,6 +104,7 @@ def __init__(
hidden_dims=hidden_dims,
l2_reg_lambda_linear=l2_reg_lambda_linear,
gamma=gamma,
force_pinv=force_pinv,
output_activation_name=output_activation_name,
use_batch_norm=use_batch_norm,
use_layer_norm=use_layer_norm,
Expand Down

0 comments on commit e4b9e6f

Please sign in to comment.