From 791a5b3837eff557e6b85b409e26eb91ee04ff83 Mon Sep 17 00:00:00 2001
From: Alex Nikulkov
Date: Mon, 5 Feb 2024 18:55:06 -0800
Subject: [PATCH] Add nn_e2e to Pearl Neural LinUCB

Summary: Migrating the end-to-end logic from ReAgent to Pearl.

End-to-end means that the `mu` part of UCB is computed by gradient propagation through an end-to-end neural network. This is achieved by having two separate output layers: a LinUCB regression layer and a vanilla `torch.nn.Linear` layer. The LinUCB layer is trained by the LinUCB update logic, while the `torch.nn.Linear` layer is trained by gradient descent.

At inference time we use the `torch.nn.Linear` layer to generate `mu`, but use the LinUCB layer to generate the `sigma` part of the UCB score.

Reviewed By: BerenLuthien

Differential Revision: D53251658

fbshipit-source-id: b7c0e964dfcc1ebdc55561bf9c5dfc872bad8685
---
 .../neural_linear_regression.py           | 26 +++++++++++++++++--
 .../neural_linear_bandit.py               |  2 ++
 .../with_pytorch/test_disjoint_bandits.py |  5 +++-
 3 files changed, 30 insertions(+), 3 deletions(-)

diff --git a/pearl/neural_networks/contextual_bandit/neural_linear_regression.py b/pearl/neural_networks/contextual_bandit/neural_linear_regression.py
index 5c4329ca..b81a91ab 100644
--- a/pearl/neural_networks/contextual_bandit/neural_linear_regression.py
+++ b/pearl/neural_networks/contextual_bandit/neural_linear_regression.py
@@ -34,6 +34,7 @@ def __init__(
         last_activation: Optional[str] = None,
         dropout_ratio: float = 0.0,
         use_skip_connections: bool = True,
+        nn_e2e: bool = True,
     ) -> None:
         """
         A model for Neural LinUCB (can also be used for Neural LinTS).
@@ -52,6 +53,9 @@ def __init__(
             last_activation: activation function for the last layer
             dropout_ratio: dropout ratio
             use_skip_connections: whether to use skip connections
+            nn_e2e: If True, we use a Linear NN layer to generate mu instead of getting it from
+                LinUCB. This can improve learning stability. Sigma is still generated from LinUCB.
+
         """
         super(NeuralLinearRegression, self).__init__(feature_dim=feature_dim)
         self._nn_layers = VanillaValueNetwork(
@@ -73,6 +77,10 @@ def __init__(
         self.output_activation: Union[
             LeakyReLU, ReLU, Sigmoid, Softplus, Tanh, nn.Identity
         ] = ACTIVATION_MAP[output_activation_name]()
+        self.linear_layer_e2e = nn.Linear(
+            in_features=hidden_dims[-1], out_features=1, bias=False
+        )  # used only if nn_e2e is True
+        self.nn_e2e = nn_e2e
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # x can be [batch_size, feature_dim] or [batch_size, num_arms, feature_dim]
@@ -84,7 +92,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
         # dim: [batch_size * num_arms, 1]
         x = self._nn_layers(x)  # apply NN layers
-        x = self._linear_regression_layer(x)  # apply linear regression to NN output
+
+        if self.nn_e2e:
+            # get mu from end-to-end NN
+            x = self.linear_layer_e2e(x)  # apply linear layer to NN output
+        else:
+            # get mu from LinUCB
+            x = self._linear_regression_layer(x)  # apply linear regression to NN output
+
         x = self.output_activation(x)  # apply output activation
 
         # dim: [batch_size, num_arms]
@@ -120,7 +135,14 @@ def forward_with_intermediate_values(
         nn_output = self._nn_layers(x)
 
         # dim: [batch_size * num_arms, 1]
-        x = self._linear_regression_layer(nn_output)
+        if self.nn_e2e:
+            # get mu from end-to-end NN
+            x = self.linear_layer_e2e(nn_output)  # apply linear layer to NN output
+        else:
+            # get mu from LinUCB
+            x = self._linear_regression_layer(
+                nn_output
+            )  # apply linear regression to NN output
 
         # dim: [batch_size, num_arms]
         return {
diff --git a/pearl/policy_learners/contextual_bandits/neural_linear_bandit.py b/pearl/policy_learners/contextual_bandits/neural_linear_bandit.py
index b2174550..1052eea9 100644
--- a/pearl/policy_learners/contextual_bandits/neural_linear_bandit.py
+++ b/pearl/policy_learners/contextual_bandits/neural_linear_bandit.py
@@ -73,6 +73,7 @@ def __init__(
         last_activation: Optional[str] = None,
         dropout_ratio: float = 0.0,
         use_skip_connections: bool = False,
+        nn_e2e: bool = True,
     ) -> None:
         assert (
             len(hidden_dims) >= 1
@@ -96,6 +97,7 @@ def __init__(
             last_activation=last_activation,
             dropout_ratio=dropout_ratio,
             use_skip_connections=use_skip_connections,
+            nn_e2e=nn_e2e,
        )
         self._optimizer: torch.optim.Optimizer = optim.AdamW(
             self.model.parameters(), lr=learning_rate, amsgrad=True
diff --git a/test/unit/with_pytorch/test_disjoint_bandits.py b/test/unit/with_pytorch/test_disjoint_bandits.py
index 3fe8b710..dbc32548 100644
--- a/test/unit/with_pytorch/test_disjoint_bandits.py
+++ b/test/unit/with_pytorch/test_disjoint_bandits.py
@@ -191,7 +191,10 @@ def test_ucb_action_vector(self) -> None:
 
 @parameterized_class(
     ("bandit_class", "bandit_kwargs"),
-    [(LinearBandit, {}), (NeuralLinearBandit, {"hidden_dims": [20]})],
+    [
+        (LinearBandit, {}),
+        (NeuralLinearBandit, {"hidden_dims": [20], "learning_rate": 3e-3}),
+    ],
 )
 class TestDisjointBanditContainerBandits(unittest.TestCase):
     def setUp(self) -> None: