Skip to content

Commit

Permalink
Improve LinearBandit and LinearRegression documentation
Browse files Browse the repository at this point in the history
Summary: Improve `LinearBandit` and `LinearRegression` documentation.

Reviewed By: yiwan-rl

Differential Revision: D66345056

fbshipit-source-id: b3316a1b9cac501eb7c518199013c1e2dbe86c6b
  • Loading branch information
rodrigodesalvobraz authored and facebook-github-bot committed Nov 26, 2024
1 parent 5cb0d76 commit 3f9ac03
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 22 deletions.
42 changes: 34 additions & 8 deletions pearl/neural_networks/contextual_bandit/linear_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,38 @@ def __init__(
) -> None:
"""
A linear regression model which can estimate both point prediction and uncertainty
(standard delivation).
(standard deviation).
Based on the LinUCB paper: https://arxiv.org/pdf/1003.0146.pdf
Note that instead of being trained by a PyTorch optimizer, we explicitly
update attributes A and b (according to the LinUCB formulas implemented in
learn_batch() method)
Note that instead of being trained by a PyTorch optimizer,
we use the analytical Weight Least Square solution to update the model parameters,
where the regression coefficients are updated in closed form:
coefs = (X^T * X)^-1 * X^T * W * y
where W is an optional weight tensor (e.g. for weighted least squares).
To compute coefficients, we maintain matrix A = X^T * X and vector b = X^T * W * y,
which are updated as new data comes in.
An extra column of ones is appended to the input data for the intercept where necessary.
A user should not append a column of ones to the input data.
A user should not append a column of ones to the input data.
It furthermore allows for _discounting_. This provides the model with the ability
to "forget" old data and adjust to a new data distribution in a non-stationary
environment. The discounting is applied periodically and consists of multiplying
the underlying linear system matrices A and b (the model's weights) by gamma
(the discounting multiplier). The discounting period is controlled by
apply_discounting_interval, which consists of the number of inputs to be
processed between different rounds of discounting. Note that, because inputs
are weighted, apply_discounting_interval is more precisely described as
the sum of weights of inputs that need to be processed before
discounting takes place again. This is expressed in pseudo-code as
```
if apply_discounting_interval > 0 and (
sum_weights - sum_weights_when_last_discounted
>= apply_discounting_interval:
A *= discount factor
b *= discount factor
```
To disable discounting, simply set gamma to 1.
feature_dim: number of features
l2_reg_lambda: L2 regularization parameter
Expand Down Expand Up @@ -192,9 +217,10 @@ def apply_discounting(self) -> None:
A <- A * gamma
b <- b * gamma
"""
logger.info(f"Applying discounting at sum_weight={self._sum_weight}")
self._A *= self.gamma
self._b *= self.gamma
if self.gamma < 1:
logger.info(f"Applying discounting at sum_weight={self._sum_weight}")
self._A *= self.gamma
self._b *= self.gamma
# don't dicount sum_weight because it's used to determine when to apply discounting

self.calculate_coefs() # update coefs using new A and b
Expand Down
53 changes: 39 additions & 14 deletions pearl/policy_learners/contextual_bandits/linear_bandit.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,26 +41,51 @@ class LinearBandit(ContextualBanditBase):
Policy Learner for Contextual Bandit with Linear Policy.
This class implements a policy learner for a contextual bandit problem where the policy is
linear. It supports learning through linear regression and can apply discounting to observations
based on the number of weighted data points processed. The learner also supports exploration
modules for acting based on learned policies.
linear and learned through linear regression.
See the documentation of LinearRegression for more details on the underlying model.
It furthermore allows for _discounting_. This provides the model with the ability
to "forget" old data and adjust to a new data distribution in a non-stationary
environment. The discounting is applied periodically and consists of multiplying
the underlying linear system matrices A and b (the model's weights) by gamma
(the discounting multiplier). The discounting period is controlled by
apply_discounting_interval, which consists of the number of inputs to be
processed between different rounds of discounting. Note that, because inputs
are weighted, apply_discounting_interval is more precisely described as
the sum of weights of inputs that need to be processed before
discounting takes place again. This is expressed in pseudo-code as
```
if apply_discounting_interval > 0 and (
sum_weights - sum_weights_when_last_discounted
>= apply_discounting_interval:
A *= discount factor
b *= discount factor
```
To disable discounting, simply set gamma to 1.
The learner also supports exploration modules for acting based on learned policies.
Attributes:
model (LinearRegression): Linear regression model used for learning.
apply_discounting_interval (float): Interval for applying discounting to the data points.
last_sum_weight_when_discounted (float): The counter for the last data point when discounting was applied.
last_sum_weight_when_discounted (float): The counter for the last data point
when discounting was applied.
Args:
feature_dim (int): Dimension of the feature space.
exploration_module (Optional[ExplorationModule]): Module for exploring actions.
l2_reg_lambda (float): L2 regularization parameter for the linear regression model.
gamma (float): Discount factor for discounting observations.
apply_discounting_interval (float): number of (weighted observations) for applying discounting to the data points.
Set to 0.0 to disable.
force_pinv (bool): If True, use pseudo-inverse for matrix inversion in the linear model.
training_rounds (int): Number of training rounds.
batch_size (int): Size of the batches used during training.
action_representation_module (Optional[ActionRepresentationModule]): Module for representing actions.
exploration_module (Optional[ExplorationModule]): module for exploring actions.
l2_reg_lambda (float, default 1.0): L2 regularization parameter for the linear
regression model.
gamma (float, default 1.0): the discounting factor.
apply_discounting_interval (float, default 0): number of (weighted) observations for
applying discounting to the data points. Set to 0.0 to disable.
force_pinv (float, default False): if True, we will always use pseudo-inversion to invert
the A matrix. If False, we will first try to use regular
matrix inversion.
If it fails, we will fallback to pseudo-inverse.
training_rounds (int): number of training rounds.
batch_size (int, default 128): size of the batches used during training.
action_representation_module (Optional[ActionRepresentationModule], default identity):
module for representing actions.
"""

def __init__(
Expand Down

0 comments on commit 3f9ac03

Please sign in to comment.