Add option to force use of pinv in LinUCB #89

Closed
37 changes: 26 additions & 11 deletions pearl/neural_networks/contextual_bandit/linear_regression.py
@@ -20,7 +20,11 @@

 class LinearRegression(MuSigmaCBModel):
     def __init__(
-        self, feature_dim: int, l2_reg_lambda: float = 1.0, gamma: float = 1.0
+        self,
+        feature_dim: int,
+        l2_reg_lambda: float = 1.0,
+        gamma: float = 1.0,
+        force_pinv: bool = False,
     ) -> None:
         """
         A linear regression model which can estimate both point prediction and uncertainty
@@ -37,10 +41,14 @@ def __init__(
             gamma: discounting multiplier (A and b get multiplied by gamma periodically, the period
                 is controlled by PolicyLearner). We use a simplified implementation of
                 https://arxiv.org/pdf/1909.09146.pdf
+            force_pinv: If True, we will always use pseudo inverse to invert the `A` matrix. If False,
+                we will first try to use regular matrix inversion. If it fails, we will fallback to
+                pseudo inverse.
         """
         super(LinearRegression, self).__init__(feature_dim=feature_dim)
         self.gamma = gamma
         self.l2_reg_lambda = l2_reg_lambda
+        self.force_pinv = force_pinv
         assert (
             gamma > 0 and gamma <= 1
         ), f"gamma should be in (0, 1]. Got gamma={gamma} instead"
@@ -96,27 +104,34 @@ def append_ones(x: torch.Tensor) -> torch.Tensor:
         return result

     @staticmethod
-    def matrix_inv_fallback_pinv(A: torch.Tensor) -> torch.Tensor:
+    def pinv(A: torch.Tensor) -> torch.Tensor:
+        """
+        Compute the pseudo inverse of A using torch.linalg.pinv
+        """
+        # first check if A is Hermitian (symmetric A)
+        A_is_hermitian = torch.allclose(A, A.T, atol=1e-4, rtol=1e-4)
+        # applying hermitian=True saves about 50% computations
+        return torch.linalg.pinv(
+            A,
+            hermitian=A_is_hermitian,
+        ).contiguous()
+
+    def matrix_inv_fallback_pinv(self, A: torch.Tensor) -> torch.Tensor:
         """
         Try to apply regular matrix inv. If it fails, fallback to pseudo inverse
         """
+        if self.force_pinv:
+            return self.pinv(A)
         try:
-            inv_A = torch.linalg.inv(A).contiguous()
+            return torch.linalg.inv(A).contiguous()
         # pyre-ignore[16]: Module `_C` has no attribute `_LinAlgError`.
         except torch._C._LinAlgError as e:
             logger.warning(
                 "Exception raised during A inversion, falling back to pseudo-inverse",
                 e,
             )
-            # switch from `inv` to `pinv`
-            # first check if A is Hermitian (symmetric A)
-            A_is_hermitian = torch.allclose(A, A.T, atol=1e-4, rtol=1e-4)
-            # applying hermitian=True saves about 50% computations
-            inv_A = torch.linalg.pinv(
-                A,
-                hermitian=A_is_hermitian,
-            ).contiguous()
-        return inv_A
+            return self.pinv(A)

     def _validate_train_inputs(
         self, x: torch.Tensor, y: torch.Tensor, weight: Optional[torch.Tensor]
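
The refactor above extracts the Hermitian-aware pseudo-inverse into a reusable pinv() helper and lets force_pinv bypass the torch.linalg.inv attempt entirely. A self-contained sketch of the same pattern in plain PyTorch (a simplified mirror of the PR's logic, without the warning log; the test matrix is illustrative, and the except clause catches RuntimeError, from which torch's linear-algebra error derives):

    import torch

    def invert_with_fallback(A: torch.Tensor, force_pinv: bool = False) -> torch.Tensor:
        # Optionally skip the exact inverse and go straight to the pseudo-inverse.
        if not force_pinv:
            try:
                return torch.linalg.inv(A)
            except RuntimeError:
                pass  # singular A; fall through to the pseudo-inverse
        # hermitian=True is only valid for symmetric A, so check first
        is_symmetric = torch.allclose(A, A.T, atol=1e-4, rtol=1e-4)
        return torch.linalg.pinv(A, hermitian=is_symmetric)

    # With no l2 regularization and a feature dimension that is always zero, a
    # LinUCB-style A = sum(x x^T) is exactly singular, so torch.linalg.inv raises.
    A = torch.tensor([[2.0, 0.0], [0.0, 0.0]])
    print(invert_with_fallback(A))                   # falls back to pinv
    print(invert_with_fallback(A, force_pinv=True))  # never tries inv
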
@@ -27,6 +27,7 @@ def __init__(
         hidden_dims: List[int],  # last one is the input dim for linear regression
         l2_reg_lambda_linear: float = 1.0,
         gamma: float = 1.0,
+        force_pinv: bool = False,
         output_activation_name: str = "linear",
         use_batch_norm: bool = False,
         use_layer_norm: bool = False,
@@ -46,6 +47,9 @@ def __init__(
             hidden_dims: size of hidden layers in the network
             l2_reg_lambda_linear: L2 regularization parameter for the linear regression layer
             gamma: discounting multiplier for the linear regression layer
+            force_pinv: If True, we will always use pseudo inverse to invert the `A` matrix. If
+                False, we will first try to use regular matrix inversion. If it fails, we will
+                fallback to pseudo inverse.
             output_activation_name: output activation function name (see ACTIVATION_MAP)
             use_batch_norm: whether to use batch normalization
             use_layer_norm: whether to use layer normalization
@@ -73,6 +77,7 @@ def __init__(
             feature_dim=hidden_dims[-1],
             l2_reg_lambda=l2_reg_lambda_linear,
             gamma=gamma,
+            force_pinv=force_pinv,
         )
         self.output_activation: nn.Module = ACTIVATION_MAP[output_activation_name]()
         self.linear_layer_e2e = nn.Linear(
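
In this file the flag is only forwarded to the LinearRegression head that sits on top of the MLP (whose input width is hidden_dims[-1]). A construction sketch follows; the class name and import path are assumed from Pearl's layout, since the file header is not visible in this capture, and the dimensions are illustrative:

    # Assumed class/module name; verify against your Pearl version.
    from pearl.neural_networks.contextual_bandit.neural_linear_regression import (
        NeuralLinearRegression,
    )

    model = NeuralLinearRegression(
        feature_dim=100,       # raw input feature dimension
        hidden_dims=[64, 16],  # the 16-dim output feeds the LinearRegression head
        force_pinv=True,       # the head will always invert A with torch.linalg.pinv
    )
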
6 changes: 5 additions & 1 deletion pearl/policy_learners/contextual_bandits/linear_bandit.py
@@ -48,6 +48,7 @@ def __init__(
         gamma: float = 1.0,
         apply_discounting_interval: float = 0.0,  # discounting will be applied after this many
         # observations (weighted) are processed. set to 0 to disable
+        force_pinv: bool = False,  # If True, use pseudo inverse instead of regular inverse for `A`
         training_rounds: int = 100,
         batch_size: int = 128,
         action_representation_module: Optional[ActionRepresentationModule] = None,
@@ -60,7 +61,10 @@
             action_representation_module=action_representation_module,
         )
         self.model = LinearRegression(
-            feature_dim=feature_dim, l2_reg_lambda=l2_reg_lambda, gamma=gamma
+            feature_dim=feature_dim,
+            l2_reg_lambda=l2_reg_lambda,
+            gamma=gamma,
+            force_pinv=force_pinv,
         )
         self.apply_discounting_interval = apply_discounting_interval
         self.last_sum_weight_when_discounted = 0.0
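
At the policy-learner level the flag is now exposed on the LinUCB constructor. A usage sketch follows; the exploration module and its import path are taken from Pearl's public examples rather than this diff, so treat them as assumptions:

    from pearl.policy_learners.contextual_bandits.linear_bandit import LinearBandit
    from pearl.policy_learners.exploration_modules.contextual_bandits.ucb_exploration import (
        UCBExploration,
    )

    policy_learner = LinearBandit(
        feature_dim=10,  # illustrative
        exploration_module=UCBExploration(alpha=1.0),
        force_pinv=True,  # always invert A via the pseudo-inverse
    )
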
@@ -76,6 +76,7 @@ def __init__(
         l2_reg_lambda_linear: float = 1.0,
         gamma: float = 1.0,
         apply_discounting_interval: float = 0.0,  # set to 0 to disable
+        force_pinv: bool = False,  # If True, use pseudo inverse instead of regular inverse for `A`
         state_features_only: bool = True,
         loss_type: str = "mse",  # one of the LOSS_TYPES names: [mse, mae, cross_entropy]
         output_activation_name: str = "linear",
@@ -103,6 +104,7 @@ def __init__(
             hidden_dims=hidden_dims,
             l2_reg_lambda_linear=l2_reg_lambda_linear,
             gamma=gamma,
+            force_pinv=force_pinv,
             output_activation_name=output_activation_name,
             use_batch_norm=use_batch_norm,
             use_layer_norm=use_layer_norm,