diff --git a/pearl/neural_networks/contextual_bandit/neural_linear_regression.py b/pearl/neural_networks/contextual_bandit/neural_linear_regression.py
index 5c4329ca..b81a91ab 100644
--- a/pearl/neural_networks/contextual_bandit/neural_linear_regression.py
+++ b/pearl/neural_networks/contextual_bandit/neural_linear_regression.py
@@ -34,6 +34,7 @@ def __init__(
         last_activation: Optional[str] = None,
         dropout_ratio: float = 0.0,
         use_skip_connections: bool = True,
+        nn_e2e: bool = True,
     ) -> None:
         """
         A model for Neural LinUCB (can also be used for Neural LinTS).
@@ -52,6 +53,9 @@ def __init__(
             last_activation: activation function for the last layer
             dropout_ratio: dropout ratio
             use_skip_connections: whether to use skip connections
+            nn_e2e: If True, we use a Linear NN layer to generate mu instead of getting it from
+                LinUCB. This can improve learning stability. Sigma is still generated from LinUCB.
+
         """
         super(NeuralLinearRegression, self).__init__(feature_dim=feature_dim)
         self._nn_layers = VanillaValueNetwork(
@@ -73,6 +77,10 @@ def __init__(
         self.output_activation: Union[
             LeakyReLU, ReLU, Sigmoid, Softplus, Tanh, nn.Identity
         ] = ACTIVATION_MAP[output_activation_name]()
+        self.linear_layer_e2e = nn.Linear(
+            in_features=hidden_dims[-1], out_features=1, bias=False
+        )  # used only if nn_e2e is True
+        self.nn_e2e = nn_e2e
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # x can be [batch_size, feature_dim] or [batch_size, num_arms, feature_dim]
@@ -84,7 +92,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
         # dim: [batch_size * num_arms, 1]
         x = self._nn_layers(x)  # apply NN layers
-        x = self._linear_regression_layer(x)  # apply linear regression to NN output
+
+        if self.nn_e2e:
+            # get mu from end-to-end NN
+            x = self.linear_layer_e2e(x)  # apply linear layer to NN output
+        else:
+            # get mu from LinUCB
+            x = self._linear_regression_layer(x)  # apply linear regression to NN output
+
         x = self.output_activation(x)  # apply output activation
 
         # dim: [batch_size, num_arms]
@@ -120,7 +135,14 @@ def forward_with_intermediate_values(
         nn_output = self._nn_layers(x)
 
         # dim: [batch_size * num_arms, 1]
-        x = self._linear_regression_layer(nn_output)
+        if self.nn_e2e:
+            # get mu from end-to-end NN
+            x = self.linear_layer_e2e(nn_output)  # apply linear layer to NN output
+        else:
+            # get mu from LinUCB
+            x = self._linear_regression_layer(
+                nn_output
+            )  # apply linear regression to NN output
 
         # dim: [batch_size, num_arms]
         return {
diff --git a/pearl/policy_learners/contextual_bandits/neural_linear_bandit.py b/pearl/policy_learners/contextual_bandits/neural_linear_bandit.py
index b2174550..1052eea9 100644
--- a/pearl/policy_learners/contextual_bandits/neural_linear_bandit.py
+++ b/pearl/policy_learners/contextual_bandits/neural_linear_bandit.py
@@ -73,6 +73,7 @@ def __init__(
         last_activation: Optional[str] = None,
         dropout_ratio: float = 0.0,
         use_skip_connections: bool = False,
+        nn_e2e: bool = True,
     ) -> None:
         assert (
             len(hidden_dims) >= 1
@@ -96,6 +97,7 @@ def __init__(
             last_activation=last_activation,
             dropout_ratio=dropout_ratio,
             use_skip_connections=use_skip_connections,
+            nn_e2e=nn_e2e,
         )
         self._optimizer: torch.optim.Optimizer = optim.AdamW(
             self.model.parameters(), lr=learning_rate, amsgrad=True
diff --git a/test/unit/with_pytorch/test_disjoint_bandits.py b/test/unit/with_pytorch/test_disjoint_bandits.py
index 3fe8b710..dbc32548 100644
--- a/test/unit/with_pytorch/test_disjoint_bandits.py
+++ b/test/unit/with_pytorch/test_disjoint_bandits.py
@@ -191,7 +191,10 @@ def test_ucb_action_vector(self) -> None:
 
 @parameterized_class(
     ("bandit_class", "bandit_kwargs"),
-    [(LinearBandit, {}), (NeuralLinearBandit, {"hidden_dims": [20]})],
+    [
+        (LinearBandit, {}),
+        (NeuralLinearBandit, {"hidden_dims": [20], "learning_rate": 3e-3}),
+    ],
 )
 class TestDisjointBanditContainerBandits(unittest.TestCase):
     def setUp(self) -> None:
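A minimal usage sketch of the new flag (not part of the diff above); it assumes the module path implied by the file touched in this PR and that the remaining NeuralLinearRegression constructor arguments keep their defaults.

import torch

from pearl.neural_networks.contextual_bandit.neural_linear_regression import (
    NeuralLinearRegression,
)

# nn_e2e=True (the new default): mu comes from the end-to-end linear layer;
# pass nn_e2e=False to read mu from the LinUCB regression layer as before.
# Sigma is still produced by the LinUCB layer in both cases.
model = NeuralLinearRegression(feature_dim=10, hidden_dims=[20], nn_e2e=True)

# x can be [batch_size, feature_dim] or [batch_size, num_arms, feature_dim]
x = torch.randn(4, 3, 10)  # batch of 4 contexts, 3 arms, 10 features
mu = model(x)
print(mu.shape)  # torch.Size([4, 3]), i.e. [batch_size, num_arms]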