diff --git a/pearl/policy_learners/contextual_bandits/contextual_bandit_base.py b/pearl/policy_learners/contextual_bandits/contextual_bandit_base.py
index 0ddf932..fadb89d 100644
--- a/pearl/policy_learners/contextual_bandits/contextual_bandit_base.py
+++ b/pearl/policy_learners/contextual_bandits/contextual_bandit_base.py
@@ -20,6 +20,7 @@
 from pearl.api.action_space import ActionSpace
 from pearl.api.reward import Value
 from pearl.history_summarization_modules.history_summarization_module import (
+    HistorySummarizationModule,
     SubjectiveState,
 )
 from pearl.policy_learners.exploration_modules.exploration_module import (
diff --git a/pearl/policy_learners/contextual_bandits/disjoint_bandit.py b/pearl/policy_learners/contextual_bandits/disjoint_bandit.py
index 49f32a3..78c8382 100644
--- a/pearl/policy_learners/contextual_bandits/disjoint_bandit.py
+++ b/pearl/policy_learners/contextual_bandits/disjoint_bandit.py
@@ -14,6 +14,7 @@
 from pearl.api.action import Action
 from pearl.api.action_space import ActionSpace
 from pearl.history_summarization_modules.history_summarization_module import (
+    HistorySummarizationModule,
     SubjectiveState,
 )
 from pearl.neural_networks.common.utils import ensemble_forward
@@ -228,3 +229,11 @@ def get_scores(
     @property
     def optimizer(self) -> torch.optim.Optimizer:
         return self._optimizer
+
+    def set_history_summarization_module(
+        self, value: HistorySummarizationModule
+    ) -> None:
+        # usually, this method would also add the parameters of the history summarization module
+        # to the optimizer of the bandit, but disjoint bandits do not use a pytorch optimizer.
+        # Instead, the optimization uses Pearl's own linear regression module.
+        self._history_summarization_module = value
diff --git a/pearl/policy_learners/contextual_bandits/disjoint_linear_bandit.py b/pearl/policy_learners/contextual_bandits/disjoint_linear_bandit.py
index a7e5236..0e712ab 100644
--- a/pearl/policy_learners/contextual_bandits/disjoint_linear_bandit.py
+++ b/pearl/policy_learners/contextual_bandits/disjoint_linear_bandit.py
@@ -13,6 +13,7 @@
 from pearl.api.action import Action
 from pearl.history_summarization_modules.history_summarization_module import (
+    HistorySummarizationModule,
     SubjectiveState,
 )
 from pearl.neural_networks.common.utils import ensemble_forward
@@ -145,3 +146,11 @@ def get_scores(
         subjective_state: SubjectiveState,
     ) -> torch.Tensor:
         raise NotImplementedError("Implement when necessary")
+
+    def set_history_summarization_module(
+        self, value: HistorySummarizationModule
+    ) -> None:
+        # usually, this method would also add the parameters of the history summarization module
+        # to the optimizer of the bandit, but disjoint bandits do not use a pytorch optimizer.
+        # Instead, the optimization uses Pearl's own linear regression module.
+        self._history_summarization_module = value
diff --git a/pearl/policy_learners/contextual_bandits/linear_bandit.py b/pearl/policy_learners/contextual_bandits/linear_bandit.py
index 75d23ce..2e72379 100644
--- a/pearl/policy_learners/contextual_bandits/linear_bandit.py
+++ b/pearl/policy_learners/contextual_bandits/linear_bandit.py
@@ -15,6 +15,7 @@
 )
 from pearl.api.action import Action
 from pearl.history_summarization_modules.history_summarization_module import (
+    HistorySummarizationModule,
     SubjectiveState,
 )
 from pearl.neural_networks.contextual_bandit.linear_regression import LinearRegression
@@ -190,3 +191,10 @@ def get_scores(
             action_space=action_space,
             representation=self.model,
         ).squeeze(-1)
+
+    def set_history_summarization_module(
+        self, value: HistorySummarizationModule
+    ) -> None:
+        # currently linear bandit algorithm does not update
+        # parameters of the history summarization module
+        self._history_summarization_module = value
diff --git a/pearl/policy_learners/contextual_bandits/neural_bandit.py b/pearl/policy_learners/contextual_bandits/neural_bandit.py
index 6999e19..b8a8398 100644
--- a/pearl/policy_learners/contextual_bandits/neural_bandit.py
+++ b/pearl/policy_learners/contextual_bandits/neural_bandit.py
@@ -18,6 +18,7 @@
 from pearl.api.action_space import ActionSpace
 from pearl.history_summarization_modules.history_summarization_module import (
+    HistorySummarizationModule,
     SubjectiveState,
 )
 from pearl.neural_networks.common.utils import LOSS_TYPES
@@ -111,6 +112,12 @@ def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
             "loss": loss.detach(),
         }
 
+    def set_history_summarization_module(
+        self, value: HistorySummarizationModule
+    ) -> None:
+        self._optimizer.add_param_group({"params": value.parameters()})
+        self._history_summarization_module = value
+
     def act(
         self,
         subjective_state: SubjectiveState,
diff --git a/pearl/policy_learners/contextual_bandits/neural_linear_bandit.py b/pearl/policy_learners/contextual_bandits/neural_linear_bandit.py
index e15c713..5cea847 100644
--- a/pearl/policy_learners/contextual_bandits/neural_linear_bandit.py
+++ b/pearl/policy_learners/contextual_bandits/neural_linear_bandit.py
@@ -18,6 +18,7 @@
 from pearl.api.action_space import ActionSpace
 from pearl.history_summarization_modules.history_summarization_module import (
+    HistorySummarizationModule,
     SubjectiveState,
 )
 from pearl.neural_networks.common.utils import LOSS_TYPES
@@ -146,6 +147,12 @@ def _maybe_apply_discounting(self) -> None:
     def optimizer(self) -> torch.optim.Optimizer:
         return self._optimizer
 
+    def set_history_summarization_module(
+        self, value: HistorySummarizationModule
+    ) -> None:
+        self._optimizer.add_param_group({"params": value.parameters()})
+        self._history_summarization_module = value
+
     def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
         # get scores for logging purpose
diff --git a/pearl/policy_learners/policy_learner.py b/pearl/policy_learners/policy_learner.py
index 1001471..2aa834e 100644
--- a/pearl/policy_learners/policy_learner.py
+++ b/pearl/policy_learners/policy_learner.py
@@ -131,10 +131,11 @@ def exploration_module(self, new_exploration_module: ExplorationModule) -> None:
     def get_action_representation_module(self) -> ActionRepresentationModule:
         return self._action_representation_module
 
+    @abstractmethod
     def set_history_summarization_module(
         self, value: HistorySummarizationModule
     ) -> None:
-        self._history_summarization_module = value
+        pass
 
     def reset(self, action_space: ActionSpace) -> None:
         """Resets policy maker for a new episode. Default implementation does nothing."""
diff --git a/pearl/policy_learners/sequential_decision_making/actor_critic_base.py b/pearl/policy_learners/sequential_decision_making/actor_critic_base.py
index c02074b..7015a15 100644
--- a/pearl/policy_learners/sequential_decision_making/actor_critic_base.py
+++ b/pearl/policy_learners/sequential_decision_making/actor_critic_base.py
@@ -72,6 +72,8 @@ def __init__(
         action_space: Optional[ActionSpace] = None,
         actor_learning_rate: float = 1e-3,
         critic_learning_rate: float = 1e-3,
+        # used only for learnable history summarization module
+        history_summarization_learning_rate: float = 1e-3,
         actor_network_type: Type[ActorNetwork] = VanillaActorNetwork,
         critic_network_type: Union[
             Type[ValueNetwork], Type[QValueNetwork]
         ]
@@ -187,11 +189,23 @@ def __init__(
         self._critic_target: nn.Module = copy.deepcopy(self._critic)
         self._discount_factor = discount_factor
+        self._history_summarization_learning_rate = history_summarization_learning_rate
 
     def set_history_summarization_module(
         self, value: HistorySummarizationModule
     ) -> None:
-        self._actor_optimizer.add_param_group({"params": value.parameters()})
+        """
+        The history summarization module uses its own optimizer.
+        """
+        self._history_summarization_optimizer: optim.Optimizer = optim.AdamW(
+            [
+                {
+                    "params": value.parameters(),
+                    "lr": self._history_summarization_learning_rate,
+                    "amsgrad": True,
+                }
+            ]
+        )
         self._history_summarization_module = value
 
     def act(
         self,
@@ -279,6 +293,7 @@ def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
            Dict[str, Any]: A dictionary containing the loss reports from the critic
               and actor updates. These can be useful to track for debugging purposes.
         """
+        self._history_summarization_optimizer.zero_grad()
         actor_loss = self._actor_loss(batch)
         self._actor_optimizer.zero_grad()
         """
@@ -289,25 +304,15 @@
        After the graph is cleared, critic_loss.backward() fails.
        """
         actor_loss.backward(retain_graph=True)
+        self._actor_optimizer.step()
+        report = {"actor_loss": actor_loss.item()}
         if self._use_critic:
             self._critic_optimizer.zero_grad()
             critic_loss = self._critic_loss(batch)
-            """
-            This backward operation needs to happen before the actor_optimizer.step().
-            This is because actor_optimizer.step() updates the history summarization neural network
-            and critic_loss.backward() fails
-            once parameters involved in critic_loss's computational graph change.
-            """
             critic_loss.backward()
-            self._actor_optimizer.step()
             self._critic_optimizer.step()
-            report = {
-                "actor_loss": actor_loss.item(),
-                "critic_loss": critic_loss.item(),
-            }
-        else:
-            self._actor_optimizer.step()
-            report = {"actor_loss": actor_loss.item()}
+            report["critic_loss"] = critic_loss.item()
+        self._history_summarization_optimizer.step()
 
         if self._use_critic_target:
             update_critic_target_network(
diff --git a/pearl/policy_learners/sequential_decision_making/ddpg.py b/pearl/policy_learners/sequential_decision_making/ddpg.py
index a2f87cb..73be518 100644
--- a/pearl/policy_learners/sequential_decision_making/ddpg.py
+++ b/pearl/policy_learners/sequential_decision_making/ddpg.py
@@ -55,6 +55,7 @@ def __init__(
         exploration_module: Optional[ExplorationModule] = None,
         actor_learning_rate: float = 1e-3,
         critic_learning_rate: float = 1e-3,
+        history_summarization_learning_rate: float = 1e-3,
         actor_network_type: Type[ActorNetwork] = VanillaContinuousActorNetwork,
         critic_network_type: Type[QValueNetwork] = VanillaQValueNetwork,
         actor_soft_update_tau: float = 0.005,
@@ -73,6 +74,7 @@ def __init__(
             critic_hidden_dims=critic_hidden_dims,
             actor_learning_rate=actor_learning_rate,
             critic_learning_rate=critic_learning_rate,
+            history_summarization_learning_rate=history_summarization_learning_rate,
             actor_network_type=actor_network_type,
             critic_network_type=critic_network_type,
             use_actor_target=True,
diff --git a/pearl/policy_learners/sequential_decision_making/implicit_q_learning.py b/pearl/policy_learners/sequential_decision_making/implicit_q_learning.py
index d90d48e..29b5d0e 100644
--- a/pearl/policy_learners/sequential_decision_making/implicit_q_learning.py
+++ b/pearl/policy_learners/sequential_decision_making/implicit_q_learning.py
@@ -87,6 +87,7 @@ def __init__(
         value_critic_learning_rate: float = 1e-3,
         actor_learning_rate: float = 1e-3,
         critic_learning_rate: float = 1e-3,
+        history_summarization_learning_rate: float = 1e-3,
         critic_soft_update_tau: float = 0.05,
         discount_factor: float = 0.99,
         training_rounds: int = 5,
@@ -107,6 +108,7 @@ def __init__(
             critic_hidden_dims=critic_hidden_dims,
             actor_learning_rate=actor_learning_rate,
             critic_learning_rate=critic_learning_rate,
+            history_summarization_learning_rate=history_summarization_learning_rate,
             actor_network_type=actor_network_type,
             critic_network_type=critic_network_type,
             use_actor_target=False,
@@ -147,16 +149,11 @@
             amsgrad=True,
         )
 
-    def set_history_summarization_module(
-        self, value: HistorySummarizationModule
-    ) -> None:
-        self._actor_optimizer.add_param_group({"params": value.parameters()})
-        self._history_summarization_module = value
-
     def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
         value_loss = self._value_loss(batch)
         critic_loss = self._critic_loss(batch)
         actor_loss = self._actor_loss(batch)
+        self._history_summarization_optimizer.zero_grad()
         self._value_network_optimizer.zero_grad()
         self._actor_optimizer.zero_grad()
         self._critic_optimizer.zero_grad()
@@ -165,7 +162,7 @@
         self._value_network_optimizer.step()
         self._actor_optimizer.step()
         self._critic_optimizer.step()
-
+        self._history_summarization_optimizer.step()
         # update critic and target Twin networks;
         update_target_networks(
             self._critic_target._critic_networks_combined,
diff --git a/pearl/policy_learners/sequential_decision_making/ppo.py b/pearl/policy_learners/sequential_decision_making/ppo.py
index db0e01a..6e94566 100644
--- a/pearl/policy_learners/sequential_decision_making/ppo.py
+++ b/pearl/policy_learners/sequential_decision_making/ppo.py
@@ -98,6 +98,7 @@ def __init__(
         critic_hidden_dims: Optional[List[int]] = None,
         actor_learning_rate: float = 1e-4,
         critic_learning_rate: float = 1e-4,
+        history_summarization_learning_rate: float = 1e-4,
         exploration_module: Optional[ExplorationModule] = None,
         actor_network_type: Type[ActorNetwork] = VanillaActorNetwork,
         critic_network_type: Type[ValueNetwork] = VanillaValueNetwork,
@@ -119,6 +120,7 @@ def __init__(
             critic_hidden_dims=critic_hidden_dims,
             actor_learning_rate=actor_learning_rate,
             critic_learning_rate=critic_learning_rate,
+            history_summarization_learning_rate=history_summarization_learning_rate,
             actor_network_type=actor_network_type,
             critic_network_type=critic_network_type,
             use_actor_target=False,
diff --git a/pearl/policy_learners/sequential_decision_making/reinforce.py b/pearl/policy_learners/sequential_decision_making/reinforce.py
index b591f0f..e3741fd 100644
--- a/pearl/policy_learners/sequential_decision_making/reinforce.py
+++ b/pearl/policy_learners/sequential_decision_making/reinforce.py
@@ -97,6 +97,7 @@ def __init__(
         action_space: Optional[ActionSpace] = None,
         actor_learning_rate: float = 1e-4,
         critic_learning_rate: float = 1e-4,
+        history_summarization_learning_rate: float = 1e-4,
         actor_network_type: Type[ActorNetwork] = VanillaActorNetwork,
         critic_network_type: Type[ValueNetwork] = VanillaValueNetwork,
         exploration_module: Optional[ExplorationModule] = None,
@@ -115,6 +116,7 @@ def __init__(
             critic_hidden_dims=critic_hidden_dims,
             actor_learning_rate=actor_learning_rate,
             critic_learning_rate=critic_learning_rate,
+            history_summarization_learning_rate=history_summarization_learning_rate,
             actor_network_type=actor_network_type,
             critic_network_type=critic_network_type,
             use_actor_target=False,
diff --git a/pearl/policy_learners/sequential_decision_making/soft_actor_critic.py b/pearl/policy_learners/sequential_decision_making/soft_actor_critic.py
index 7dc56ab..5a2f2bc 100644
--- a/pearl/policy_learners/sequential_decision_making/soft_actor_critic.py
+++ b/pearl/policy_learners/sequential_decision_making/soft_actor_critic.py
@@ -59,6 +59,7 @@ def __init__(
         critic_hidden_dims: Optional[List[int]] = None,
         actor_learning_rate: float = 1e-4,
         critic_learning_rate: float = 1e-4,
+        history_summarization_learning_rate: float = 1e-4,
         actor_network_type: Type[ActorNetwork] = VanillaActorNetwork,
         critic_network_type: Type[QValueNetwork] = VanillaQValueNetwork,
         critic_soft_update_tau: float = 0.005,
@@ -78,6 +79,7 @@ def __init__(
             critic_hidden_dims=critic_hidden_dims,
             actor_learning_rate=actor_learning_rate,
             critic_learning_rate=critic_learning_rate,
+            history_summarization_learning_rate=history_summarization_learning_rate,
             actor_network_type=actor_network_type,
             critic_network_type=critic_network_type,
             use_actor_target=False,
diff --git a/pearl/policy_learners/sequential_decision_making/soft_actor_critic_continuous.py b/pearl/policy_learners/sequential_decision_making/soft_actor_critic_continuous.py
index bfac055..cb59396 100644
--- a/pearl/policy_learners/sequential_decision_making/soft_actor_critic_continuous.py
+++ b/pearl/policy_learners/sequential_decision_making/soft_actor_critic_continuous.py
@@ -52,6 +52,7 @@ def __init__(
         critic_hidden_dims: Optional[List[int]] = None,
         actor_learning_rate: float = 1e-3,
         critic_learning_rate: float = 1e-3,
+        history_summarization_learning_rate: float = 1e-3,
         actor_network_type: Type[ActorNetwork] = GaussianActorNetwork,
         critic_network_type: Type[QValueNetwork] = VanillaQValueNetwork,
         critic_soft_update_tau: float = 0.005,
@@ -72,6 +73,7 @@ def __init__(
             critic_hidden_dims=critic_hidden_dims,
             actor_learning_rate=actor_learning_rate,
             critic_learning_rate=critic_learning_rate,
+            history_summarization_learning_rate=history_summarization_learning_rate,
             actor_network_type=actor_network_type,
             critic_network_type=critic_network_type,
             use_actor_target=False,
diff --git a/pearl/policy_learners/sequential_decision_making/tabular_q_learning.py b/pearl/policy_learners/sequential_decision_making/tabular_q_learning.py
index 07d78ca..3e532ed 100644
--- a/pearl/policy_learners/sequential_decision_making/tabular_q_learning.py
+++ b/pearl/policy_learners/sequential_decision_making/tabular_q_learning.py
@@ -13,6 +13,7 @@
 from pearl.api.action_space import ActionSpace
 from pearl.api.reward import Reward, Value
 from pearl.history_summarization_modules.history_summarization_module import (
+    HistorySummarizationModule,
     SubjectiveState,
 )
 from pearl.policy_learners.exploration_modules.common.epsilon_greedy_exploration import (
@@ -75,6 +76,12 @@ def reset(self, action_space: ActionSpace) -> None:
             f"action.item() == action's index. "
         )
 
+    def set_history_summarization_module(
+        self, value: HistorySummarizationModule
+    ) -> None:
+        # tabular q learning is assumed to not update parameters of the history summarization module
+        self._history_summarization_module = value
+
     def act(
         self,
         subjective_state: SubjectiveState,
diff --git a/pearl/policy_learners/sequential_decision_making/td3.py b/pearl/policy_learners/sequential_decision_making/td3.py
index 8f333e7..5a0aaf8 100644
--- a/pearl/policy_learners/sequential_decision_making/td3.py
+++ b/pearl/policy_learners/sequential_decision_making/td3.py
@@ -104,20 +104,21 @@ def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
         self._critic_update_count += 1
         report = {}
         # delayed actor update
-        self._actor_optimizer.zero_grad()
+        self._history_summarization_optimizer.zero_grad()
         if self._critic_update_count % self._actor_update_freq == 0:
-            # see ddpg base class for actor update details
+            self._actor_optimizer.zero_grad()
             actor_loss = self._actor_loss(batch)
             actor_loss.backward(retain_graph=True)
+            self._actor_optimizer.step()
             self._last_actor_loss = actor_loss.item()
             report["actor_loss"] = self._last_actor_loss
 
         self._critic_optimizer.zero_grad()
         critic_loss = self._critic_loss(batch)  # critic update
         critic_loss.backward()
-        self._actor_optimizer.step()
         self._critic_optimizer.step()
         report["critic_loss"] = critic_loss.item()
+        self._history_summarization_optimizer.step()
 
         if self._critic_update_count % self._actor_update_freq == 0:
             # update targets of critics using soft updates
diff --git a/pearl/utils/scripts/benchmark.py b/pearl/utils/scripts/benchmark.py
index 2575345..65c5eff 100644
--- a/pearl/utils/scripts/benchmark.py
+++ b/pearl/utils/scripts/benchmark.py
@@ -164,10 +164,19 @@ def evaluate_single(
     ):
         if (
             method["history_summarization_module"].__name__
-            == "StackHistorySummarizationModule"
+            == "StackingHistorySummarizationModule"
         ):
+            method["history_summarization_module_args"]["observation_dim"] = (
+                env.observation_space.shape[0]
+            )
+            method["history_summarization_module_args"]["action_dim"] = (
+                policy_learner_args["action_representation_module"].representation_dim
+                if "action_representation_module" in policy_learner_args
+                else env.action_space.action_dim
+            )
             policy_learner_args["state_dim"] = (
-                env.observation_space.shape[0] + env.action_space.n
+                method["history_summarization_module_args"]["observation_dim"]
+                + method["history_summarization_module_args"]["action_dim"]
             ) * method["history_summarization_module_args"]["history_length"]
         elif (
            method["history_summarization_module"].__name__