From 73890da7bb1dd7d00ec7589a8a0d16845def0bb3 Mon Sep 17 00:00:00 2001
From: Jalaj Bhandari
Date: Mon, 26 Feb 2024 15:13:05 -0800
Subject: [PATCH] Changes to Deep Q learning to expose network type and network instance

Summary: This diff exposes important input arguments for DQN and adds documentation.

Reviewed By: rodrigodesalvobraz

Differential Revision: D53602430

fbshipit-source-id: e953395b8558b3dfa40c5d9c7dd5c931da760e17
---
 .../deep_q_learning.py                        |  96 +++++++++++++++--
 .../sequential_decision_making/deep_sarsa.py  |   2 +-
 .../deep_td_learning.py                       | 101 ++++++++++++++++--
 .../sequential_decision_making/double_dqn.py  |   2 +-
 test/integration/test_integration.py          |  15 ++-
 .../with_pytorch/test_deep_td_learning.py     |   6 +-
 6 files changed, 194 insertions(+), 28 deletions(-)

diff --git a/pearl/policy_learners/sequential_decision_making/deep_q_learning.py b/pearl/policy_learners/sequential_decision_making/deep_q_learning.py
index 3b7d2114..bc42df65 100644
--- a/pearl/policy_learners/sequential_decision_making/deep_q_learning.py
+++ b/pearl/policy_learners/sequential_decision_making/deep_q_learning.py
@@ -5,13 +5,18 @@
 # LICENSE file in the root directory of this source tree.
 #
 
-from typing import Any, Optional, Tuple
+from typing import Any, List, Optional, Tuple, Type
 
 import torch
 from pearl.action_representation_modules.action_representation_module import (
     ActionRepresentationModule,
 )
 from pearl.api.action_space import ActionSpace
+
+from pearl.neural_networks.sequential_decision_making.q_value_networks import (
+    QValueNetwork,
+    VanillaQValueNetwork,
+)
 from pearl.policy_learners.exploration_modules.common.epsilon_greedy_exploration import (
     EGreedyExploration,
 )
@@ -33,13 +38,62 @@ class DeepQLearning(DeepTDLearning):
     def __init__(
         self,
         state_dim: int,
-        learning_rate: float = 0.001,
         action_space: Optional[ActionSpace] = None,
+        hidden_dims: Optional[List[int]] = None,
         exploration_module: Optional[ExplorationModule] = None,
-        soft_update_tau: float = 1.0,  # no soft update
+        learning_rate: float = 0.001,
+        discount_factor: float = 0.99,
+        training_rounds: int = 10,
+        batch_size: int = 128,
+        target_update_freq: int = 10,
+        soft_update_tau: float = 0.75,  # a value of 1 indicates no soft updates
+        is_conservative: bool = False,
+        conservative_alpha: Optional[float] = 2.0,
+        network_type: Type[QValueNetwork] = VanillaQValueNetwork,
         action_representation_module: Optional[ActionRepresentationModule] = None,
+        network_instance: Optional[QValueNetwork] = None,
         **kwargs: Any,
     ) -> None:
+        """Constructs a DeepQLearning policy learner. DeepQLearning is based on the DeepTDLearning
+        class and uses the `act` and `learn_batch` methods of that class. We only implement the
+        `get_next_state_values` method to compute the Bellman targets using Q-learning.
+
+        Args:
+            state_dim: Dimension of the observation space.
+            action_space (ActionSpace, optional): Action space of the problem. It is kept optional
+                to allow for the use of dynamic action spaces (see the `learn_batch` and `act`
+                functions). Defaults to None.
+            hidden_dims (List[int], optional): Hidden dimensions of the default `QValueNetwork`
+                (taken to be `VanillaQValueNetwork`). Defaults to None.
+            exploration_module (ExplorationModule, optional): Optional exploration module to
+                trade off between exploitation and exploration. Defaults to None.
+            learning_rate (float): Learning rate for the AdamW optimizer. Defaults to 0.001.
+                Note: We use AdamW by default for all value-based methods.
+            discount_factor (float): Discount factor for TD updates. Defaults to 0.99.
+            training_rounds (int): Number of gradient updates per environment step.
+                Defaults to 10.
+            batch_size (int): Sample size for mini-batch gradient updates. Defaults to 128.
+            target_update_freq (int): Frequency at which the target network is updated.
+                Defaults to 10.
+            soft_update_tau (float): Coefficient for soft updates to the target networks.
+                Defaults to 0.75.
+            is_conservative (bool): Whether to use conservative updates for offline learning
+                with conservative Q-learning (CQL). Defaults to False.
+            conservative_alpha (float, optional): Alpha parameter for CQL. Defaults to 2.0.
+            network_type (Type[QValueNetwork]): Network type for the Q-value network. Defaults to
+                `VanillaQValueNetwork`. This means that by default, an instance of the class
+                `VanillaQValueNetwork` (or the specified `network_type` class) is created and used
+                for learning.
+            action_representation_module (ActionRepresentationModule, optional): Optional module to
+                represent actions as a feature vector. Typically specified at the agent level.
+                Defaults to None.
+            network_instance (QValueNetwork, optional): A network instance to be used as the
+                Q-value network. Defaults to None.
+                Note: This is an alternative to specifying a `network_type`. If provided, the
+                specified `network_type` is ignored and the input `network_instance` is used for
+                learning. Allows for custom implementations of Q-value networks.
+        """
+
         super(DeepQLearning, self).__init__(
             exploration_module=exploration_module
             if exploration_module is not None
@@ -47,24 +101,46 @@ def __init__(
             on_policy=False,
             state_dim=state_dim,
             action_space=action_space,
+            hidden_dims=hidden_dims,
             learning_rate=learning_rate,
             soft_update_tau=soft_update_tau,
+            network_type=network_type,
             action_representation_module=action_representation_module,
+            discount_factor=discount_factor,
+            training_rounds=training_rounds,
+            batch_size=batch_size,
+            network_instance=network_instance,
             **kwargs,
         )
 
     @torch.no_grad()
-    def _get_next_state_values(
+    def get_next_state_values(
         self, batch: TransitionBatch, batch_size: int
     ) -> torch.Tensor:
+        """
+        Computes the maximum Q-value over all available actions in the next state using the target
+        network. Note: Q-learning is designed to work with discrete action spaces.
+
+        Args:
+            batch (TransitionBatch): Batch of transitions. For Q-learning, any transition must have
+                the 'next_state', 'next_available_actions' and the 'next_unavailable_actions_mask'
+                fields set. The 'next_available_actions' and 'next_unavailable_actions_mask' fields
+                implement dynamic action spaces in Pearl.
+            batch_size (int): Size of the batch.
+
+        Returns:
+            torch.Tensor: Maximum Q-value over all available actions in the next state.
+        """
+
         (
-            next_state,
-            next_available_actions,
-            next_unavailable_actions_mask,
+            next_state,  # (batch_size x action_space_size x state_dim)
+            next_available_actions,  # (batch_size x action_space_size x action_dim)
+            next_unavailable_actions_mask,  # (batch_size x action_space_size)
         ) = self._prepare_next_state_action_batch(batch)
 
         assert next_available_actions is not None
 
+        # Get Q values for each (state, action), where action \in {available_actions}
         next_state_action_values = self._Q_target.get_q_values(
             next_state, next_available_actions
         ).view(batch_size, -1)
@@ -79,6 +155,12 @@ def _prepare_next_state_action_batch(
         self, batch: TransitionBatch
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+
+        # This function outputs tensors:
+        # - next_state_batch: (batch_size x action_space_size x state_dim).
+        # - next_available_actions_batch: (batch_size x action_space_size x action_dim).
+        # - next_unavailable_actions_mask_batch: (batch_size x action_space_size).
+
         next_state_batch = batch.next_state  # (batch_size x state_dim)
         assert next_state_batch is not None
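Usage sketch (illustrative, not part of this patch): the two ways of specifying the Q-value
network that this diff exposes. The import path for GymEnvironment and the constructor
arguments of VanillaQValueNetwork are assumptions, mirroring DuelingQValueNetwork as used in
test_integration.py below; the other imports follow the modules touched in this diff.

    from pearl.neural_networks.sequential_decision_making.q_value_networks import (
        VanillaQValueNetwork,
    )
    from pearl.policy_learners.sequential_decision_making.deep_q_learning import (
        DeepQLearning,
    )
    from pearl.utils.instantiations.environments.gym_environment import (  # path assumed
        GymEnvironment,
    )

    env = GymEnvironment("CartPole-v1")
    state_dim = env.observation_space.shape[0]
    num_actions = env.action_space.n

    # Option 1: let DeepQLearning instantiate the default network_type
    # (VanillaQValueNetwork) from hidden_dims.
    dqn_default = DeepQLearning(
        state_dim=state_dim,
        action_space=env.action_space,
        hidden_dims=[64, 64],
        training_rounds=10,
        soft_update_tau=0.75,
    )

    # Option 2: pass a pre-built network_instance; network_type is then ignored.
    q_network = VanillaQValueNetwork(  # constructor arguments assumed
        state_dim=state_dim,
        action_dim=num_actions,  # assumes a one-hot action representation
        hidden_dims=[64, 64],
        output_dim=1,  # Q-values are scalars
    )
    dqn_custom = DeepQLearning(
        state_dim=state_dim,
        action_space=env.action_space,
        network_instance=q_network,
        training_rounds=10,
    )
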
diff --git a/pearl/policy_learners/sequential_decision_making/deep_sarsa.py b/pearl/policy_learners/sequential_decision_making/deep_sarsa.py
index d56bea2c..dc95fda5 100644
--- a/pearl/policy_learners/sequential_decision_making/deep_sarsa.py
+++ b/pearl/policy_learners/sequential_decision_making/deep_sarsa.py
@@ -50,7 +50,7 @@ def __init__(
         )
 
     @torch.no_grad()
-    def _get_next_state_values(
+    def get_next_state_values(
         self, batch: TransitionBatch, batch_size: int
     ) -> torch.Tensor:
         """
diff --git a/pearl/policy_learners/sequential_decision_making/deep_td_learning.py b/pearl/policy_learners/sequential_decision_making/deep_td_learning.py
index f93d3b71..a2002d58 100644
--- a/pearl/policy_learners/sequential_decision_making/deep_td_learning.py
+++ b/pearl/policy_learners/sequential_decision_making/deep_td_learning.py
@@ -44,7 +44,7 @@
 # of production stack on this param.
 class DeepTDLearning(PolicyLearner):
     """
-    An Abstract Class for Deep Temporal Difference learning policy learner.
+    An Abstract Class for Deep Temporal Difference learning.
     """
 
     def __init__(
@@ -71,6 +71,46 @@ def __init__(
         action_representation_module: Optional[ActionRepresentationModule] = None,
         **kwargs: Any,
     ) -> None:
+        """Constructs a DeepTDLearning based policy learner. DeepTDLearning is the base class
+        for all value-based (i.e. temporal difference learning based) algorithms.
+
+        Args:
+            state_dim: Dimension of the state space.
+            exploration_module (ExplorationModule, optional): Optional exploration module used by
+                the `act` function to trade off between exploitation and exploration.
+                Defaults to None.
+            action_space (ActionSpace, optional): Action space of the problem. It is kept optional
+                to allow for the use of dynamic action spaces (see the `learn_batch` and `act`
+                functions). Defaults to None.
+            hidden_dims (List[int], optional): Hidden dimensions of the default `QValueNetwork`
+                (taken to be `VanillaQValueNetwork`). Defaults to None.
+            learning_rate (float): Learning rate for the optimizer. Defaults to 0.001.
+                Note: We use AdamW by default for all value-based methods.
+            discount_factor (float): Discount factor for TD updates. Defaults to 0.99.
+            training_rounds (int): Number of gradient updates per environment step.
+                Defaults to 100.
+            batch_size (int): Sample size for mini-batch gradient updates. Defaults to 128.
+            target_update_freq (int): Frequency at which the target network is updated.
+                Defaults to 100.
+            soft_update_tau (float): Coefficient for soft updates to the target networks.
+                Defaults to 0.01.
+            is_conservative (bool): Whether to use conservative updates for offline learning.
+                Defaults to False.
+            conservative_alpha (float, optional): Alpha parameter for conservative updates.
+                Defaults to 2.0.
+            network_type (Type[QValueNetwork]): Network type for the Q-value network. Defaults to
+                `VanillaQValueNetwork`. This means that by default, an instance of the class
+                `VanillaQValueNetwork` (or the specified `network_type` class) is created and used
+                for learning.
+            network_instance (QValueNetwork, optional): A network instance to be used as the
+                Q-value network. Defaults to None.
+                Note: This is an alternative to specifying a `network_type`. If provided, the
+                specified `network_type` is ignored and the input `network_instance` is used for
+                learning. Allows for custom implementations of Q-value networks.
+            action_representation_module (ActionRepresentationModule, optional): Optional module to
+                represent actions as a feature vector. Typically specified at the agent level.
+                Defaults to None.
+        """
         super(DeepTDLearning, self).__init__(
             training_rounds=training_rounds,
             batch_size=batch_size,
@@ -142,6 +182,23 @@ def act(
         available_action_space: ActionSpace,
         exploit: bool = False,
     ) -> Action:
+        """
+        Selects an action from the available action space, balancing between exploration and
+        exploitation.
+        This action can be (i) an 'exploit action', i.e. the optimal action given the estimate of
+        the Q-values, or (ii) an 'exploratory action' obtained using the specified `exploration_module`.
+
+        Args:
+            subjective_state (SubjectiveState): Current subjective state.
+            available_action_space (ActionSpace): Available action space at the current state.
+                Note that Pearl allows for action spaces to change dynamically.
+            exploit (bool): When set to True, we output the exploit action (no exploration).
+                When set to False, the specified `exploration_module` is used to balance
+                between exploration and exploitation. Defaults to False.
+
+        Returns:
+            Action: An action from the available action space.
+        """
         # TODO: Assumes gym action space.
         # Fix the available action space.
         assert isinstance(available_action_space, DiscreteActionSpace)
@@ -167,23 +224,40 @@ def act(
         if exploit:
             return exploit_action
 
+        assert self._exploration_module is not None
         return self._exploration_module.act(
-            subjective_state,
-            available_action_space,
-            exploit_action,
-            q_values,
+            subjective_state=subjective_state,
+            action_space=available_action_space,
+            exploit_action=exploit_action,
+            values=q_values,
         )
 
     @abstractmethod
-    def _get_next_state_values(
+    def get_next_state_values(
         self, batch: TransitionBatch, batch_size: int
     ) -> torch.Tensor:
+        """
+        For a given batch of transitions, returns the Q-value targets for the Bellman equation.
+        Child classes should implement this method.
+
+        For example, this method in DQN returns
+        "max_{action in available_action_space} Q(next_state, action)".
+        """
         pass
 
     def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
+        """
+        Batch learning with TD(0)-style updates. Different implementations of the
+        `get_next_state_values` function correspond to different RL algorithm implementations,
+        for example TD learning, DQN, Double DQN, Dueling DQN, etc.
+
+        Args:
+            batch (TransitionBatch): Batch of transitions.
+        Returns:
+            Dict[str, Any]: Dictionary with the loss as the mean Bellman error (across the batch).
+        """
         state_batch = batch.state  # (batch_size x state_dim)
-        action_batch = batch.action
-        # (batch_size x action_dim)
+        action_batch = batch.action  # (batch_size x action_dim)
         reward_batch = batch.reward  # (batch_size)
         done_batch = batch.done  # (batch_size)
 
@@ -191,21 +265,26 @@ def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
         # sanity check they have same batch_size
         assert reward_batch.shape[0] == batch_size
         assert done_batch.shape[0] == batch_size
+
         state_action_values = self._Q.get_q_values(
             state_batch=state_batch,
             action_batch=action_batch,
             curr_available_actions_batch=batch.curr_available_actions,
-        )  # for duelling, this takes care of the mean subtraction for advantage estimation
+        )
+        # for dueling DQN, specifying the `curr_available_actions_batch` field takes care of
+        # the mean subtraction for advantage estimation
 
         # Compute the Bellman Target
         expected_state_action_values = (
-            self._get_next_state_values(batch, batch_size)
+            self.get_next_state_values(batch, batch_size)
             * self._discount_factor
             * (1 - done_batch.float())
         ) + reward_batch  # (batch_size), r + gamma * V(s)
 
         criterion = torch.nn.MSELoss()
         bellman_loss = criterion(state_action_values, expected_state_action_values)
+
+        # Conservative TD updates for offline learning.
         if self._is_conservative:
             cql_loss = compute_cql_loss(self._Q, batch, batch_size)
             loss = self._conservative_alpha * cql_loss + bellman_loss
@@ -217,7 +296,7 @@ def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
         loss.backward()
         self._optimizer.step()
 
-        # Target Network Update
+        # Target network update
         if (self._training_steps + 1) % self._target_update_freq == 0:
             update_target_network(self._Q_target, self._Q, self._soft_update_tau)
 
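Since `get_next_state_values` is now the public extension point of DeepTDLearning, a new
TD-style learner only needs to override it, exactly as DoubleDQN does in the next file. The
following is a hypothetical sketch (not part of this patch) of a variant that bootstraps from
the mean rather than the max of the next-state Q-values; the TransitionBatch import path is an
assumption, and only attributes visible in this diff (`_Q_target`,
`_prepare_next_state_action_batch`) are reused.

    import torch

    from pearl.policy_learners.sequential_decision_making.deep_q_learning import (
        DeepQLearning,
    )
    from pearl.replay_buffers.transition import TransitionBatch  # import path assumed


    class MeanBootstrapDQN(DeepQLearning):
        """Hypothetical variant: expected (uniform) value over available next actions."""

        @torch.no_grad()
        def get_next_state_values(
            self, batch: TransitionBatch, batch_size: int
        ) -> torch.Tensor:
            (
                next_state,
                next_available_actions,
                next_unavailable_actions_mask,
            ) = self._prepare_next_state_action_batch(batch)
            assert next_available_actions is not None
            assert next_unavailable_actions_mask is not None

            # Q(s', a) for every action slot: (batch_size x action_space_size).
            values = self._Q_target.get_q_values(
                next_state, next_available_actions
            ).view(batch_size, -1)

            # Zero out unavailable actions and average over the available ones.
            values = values.masked_fill(next_unavailable_actions_mask, 0.0)
            num_available = (~next_unavailable_actions_mask).sum(dim=-1).clamp(min=1)
            return values.sum(dim=-1) / num_available  # (batch_size)
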
diff --git a/pearl/policy_learners/sequential_decision_making/double_dqn.py b/pearl/policy_learners/sequential_decision_making/double_dqn.py
index ead40314..58c5489d 100644
--- a/pearl/policy_learners/sequential_decision_making/double_dqn.py
+++ b/pearl/policy_learners/sequential_decision_making/double_dqn.py
@@ -23,7 +23,7 @@ class DoubleDQN(DeepQLearning):
     """
 
     @torch.no_grad()
-    def _get_next_state_values(
+    def get_next_state_values(
         self, batch: TransitionBatch, batch_size: int
     ) -> torch.Tensor:
         next_state_batch = batch.next_state  # (batch_size x state_dim)
diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py
index 17be9607..eaf73a33 100644
--- a/test/integration/test_integration.py
+++ b/test/integration/test_integration.py
@@ -265,17 +265,22 @@ def test_dueling_dqn(
         env = GymEnvironment("CartPole-v1")
         assert isinstance(env.action_space, DiscreteActionSpace)
         num_actions = env.action_space.n
+        # We use a one-hot representation for actions, so take
+        # action_dim = num_actions.
         q_network = DuelingQValueNetwork(
-            state_dim=env.observation_space.shape[0],
-            action_dim=num_actions,
-            hidden_dims=[64],
-            output_dim=1,
+            state_dim=env.observation_space.shape[
+                0
+            ],  # dimension of the state representation
+            action_dim=num_actions,  # dimension of the action representation
+            hidden_dims=[64, 64],  # dimensions of the intermediate layers
+            output_dim=1,  # set to 1 (Q values are scalars)
         )
         agent = PearlAgent(
             policy_learner=DeepQLearning(
                 state_dim=env.observation_space.shape[0],
                 action_space=env.action_space,
-                training_rounds=20,
+                training_rounds=10,
+                soft_update_tau=0.75,
                 network_instance=q_network,
                 batch_size=batch_size,
                 action_representation_module=OneHotActionTensorRepresentationModule(
diff --git a/test/unit/with_pytorch/test_deep_td_learning.py b/test/unit/with_pytorch/test_deep_td_learning.py
index 58e629af..3e22da67 100644
--- a/test/unit/with_pytorch/test_deep_td_learning.py
+++ b/test/unit/with_pytorch/test_deep_td_learning.py
@@ -79,11 +79,11 @@ def test_double_dqn(self) -> None:
         batch2 = dqn.preprocess_batch(copy.deepcopy(self.batch))
         double_dqn._Q.apply(init_weights)
         double_dqn._Q_target.apply(init_weights)
-        double_value = double_dqn._get_next_state_values(batch1, self.batch_size)
+        double_value = double_dqn.get_next_state_values(batch1, self.batch_size)
 
         dqn._Q.load_state_dict(double_dqn._Q.state_dict())
         dqn._Q_target.load_state_dict(double_dqn._Q_target.state_dict())
-        vanilla_value = dqn._get_next_state_values(batch2, self.batch_size)
+        vanilla_value = dqn.get_next_state_values(batch2, self.batch_size)
         self.assertEqual(double_value.shape, vanilla_value.shape)
         differ = torch.any(double_value != vanilla_value)
         if differ:
@@ -99,7 +99,7 @@ def test_sarsa(self) -> None:
             exploration_module=EGreedyExploration(0.05),
             action_representation_module=self.action_representation_module,
         )
-        sa_value = sarsa._get_next_state_values(
+        sa_value = sarsa.get_next_state_values(
             batch=sarsa.preprocess_batch(self.batch), batch_size=self.batch_size
         )
         self.assertEqual(sa_value.shape, (self.batch_size,))
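End-to-end sketch (illustrative, not part of this patch), mirroring the updated integration
test above. The interaction-loop methods (env.reset, agent.reset, agent.act, agent.observe,
agent.learn), the replay-buffer choice, the OneHotActionTensorRepresentationModule argument
name, and the import paths for PearlAgent, GymEnvironment and FIFOOffPolicyReplayBuffer follow
Pearl's standard usage and are assumptions rather than something this diff introduces.

    from pearl.action_representation_modules.one_hot_action_representation_module import (
        OneHotActionTensorRepresentationModule,  # import path assumed
    )
    from pearl.neural_networks.sequential_decision_making.q_value_networks import (
        DuelingQValueNetwork,
    )
    from pearl.pearl_agent import PearlAgent  # import path assumed
    from pearl.policy_learners.sequential_decision_making.deep_q_learning import (
        DeepQLearning,
    )
    from pearl.replay_buffers.sequential_decision_making.fifo_off_policy_replay_buffer import (
        FIFOOffPolicyReplayBuffer,  # import path assumed
    )
    from pearl.utils.instantiations.environments.gym_environment import (
        GymEnvironment,  # import path assumed
    )

    env = GymEnvironment("CartPole-v1")
    num_actions = env.action_space.n
    q_network = DuelingQValueNetwork(
        state_dim=env.observation_space.shape[0],
        action_dim=num_actions,
        hidden_dims=[64, 64],
        output_dim=1,
    )
    agent = PearlAgent(
        policy_learner=DeepQLearning(
            state_dim=env.observation_space.shape[0],
            action_space=env.action_space,
            training_rounds=10,
            soft_update_tau=0.75,
            network_instance=q_network,
            batch_size=128,
            action_representation_module=OneHotActionTensorRepresentationModule(
                max_number_actions=num_actions,  # argument name assumed
            ),
        ),
        replay_buffer=FIFOOffPolicyReplayBuffer(10_000),
    )

    # One episode of online learning; the agent/env loop follows Pearl's documented API.
    observation, action_space = env.reset()
    agent.reset(observation, action_space)
    done = False
    while not done:
        action = agent.act(exploit=False)
        action_result = env.step(action)
        agent.observe(action_result)
        agent.learn()
        done = action_result.done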