Changes to Deep Q learning to expose network type and network instance
Summary: This diff exposes some important input arguments for dqn and adds documentation.

Reviewed By: rodrigodesalvobraz

Differential Revision: D53602430

fbshipit-source-id: e953395b8558b3dfa40c5d9c7dd5c931da760e17
jb3618 authored and facebook-github-bot committed Feb 26, 2024
1 parent b3b6005 commit 73890da
Showing 6 changed files with 194 additions and 28 deletions.
@@ -5,13 +5,18 @@
# LICENSE file in the root directory of this source tree.
#

from typing import Any, Optional, Tuple
from typing import Any, List, Optional, Tuple, Type

import torch
from pearl.action_representation_modules.action_representation_module import (
ActionRepresentationModule,
)
from pearl.api.action_space import ActionSpace

from pearl.neural_networks.sequential_decision_making.q_value_networks import (
QValueNetwork,
VanillaQValueNetwork,
)
from pearl.policy_learners.exploration_modules.common.epsilon_greedy_exploration import (
EGreedyExploration,
)
@@ -33,38 +38,109 @@ class DeepQLearning(DeepTDLearning):
def __init__(
self,
state_dim: int,
learning_rate: float = 0.001,
action_space: Optional[ActionSpace] = None,
hidden_dims: Optional[List[int]] = None,
exploration_module: Optional[ExplorationModule] = None,
soft_update_tau: float = 1.0, # no soft update
learning_rate: float = 0.001,
discount_factor: float = 0.99,
training_rounds: int = 10,
batch_size: int = 128,
target_update_freq: int = 10,
soft_update_tau: float = 0.75, # a value of 1 indicates no soft updates
is_conservative: bool = False,
conservative_alpha: Optional[float] = 2.0,
network_type: Type[QValueNetwork] = VanillaQValueNetwork,
action_representation_module: Optional[ActionRepresentationModule] = None,
network_instance: Optional[QValueNetwork] = None,
**kwargs: Any,
) -> None:
"""Constructs a DeepQLearning policy learner. DeepQLearning is based on DeepTDLearning
class and uses `act` and `learn_batch` methods of that class. We only implement the
`get_next_state_values` function to compute the bellman targets using Q-learning.
Args:
state_dim: Dimension of the observation space.
action_space (ActionSpace, optional): Action space of the problem. It is kept optional
to allow for the use of dynamic action spaces (used by both the `learn_batch` and `act`
functions). Defaults to None.
hidden_dims (List[int], optional): Hidden dimensions of the default `QValueNetwork`
(taken to be `VanillaQValueNetwork`). Defaults to None.
exploration_module (ExplorationModule, optional): Optional exploration module to
trade-off between exploitation and exploration. Defaults to None.
learning_rate (float): Learning rate for the AdamW optimizer. Defaults to 0.001.
Note: We use AdamW by default for all value-based methods.
discount_factor (float): Discount factor for TD updates. Defaults to 0.99.
training_rounds (int): Number of gradient updates per environment step.
Defaults to 10.
batch_size (int): Sample size for mini-batch gradient updates. Defaults to 128.
target_update_freq (int): Frequency at which the target network is updated.
Defaults to 10.
soft_update_tau (float): Coefficient for soft updates to the target networks.
Defaults to 0.75 (a value of 1 indicates no soft updates).
is_conservative (bool): Whether to use conservative updates for offline learning
with conservative Q-learning (CQL). Defaults to False.
conservative_alpha (float, optional): Alpha parameter for CQL. Defaults to 2.0.
network_type (Type[QValueNetwork]): Network type for the Q-value network. Defaults to
`VanillaQValueNetwork`. This means that by default, an instance of the class
`VanillaQValueNetwork` (or the specified `network_type` class) is created and used
for learning.
action_representation_module (ActionRepresentationModule, optional): Optional module to
represent actions as a feature vector. Typically specified at the agent level.
Defaults to None.
network_instance (QValueNetwork, optional): A network instance to be used as the
Q-value network. Defaults to None.
Note: This is an alternative to specifying a `network_type`. If provided, the
specified `network_type` is ignored and the input `network_instance` is used for
learning. Allows for custom implementations of Q-value networks.
"""

super(DeepQLearning, self).__init__(
exploration_module=exploration_module
if exploration_module is not None
else EGreedyExploration(0.05),
on_policy=False,
state_dim=state_dim,
action_space=action_space,
hidden_dims=hidden_dims,
learning_rate=learning_rate,
soft_update_tau=soft_update_tau,
network_type=network_type,
action_representation_module=action_representation_module,
discount_factor=discount_factor,
training_rounds=training_rounds,
batch_size=batch_size,
network_instance=network_instance,
**kwargs,
)
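An illustrative aside, not part of this commit: a minimal sketch of the two ways the newly
exposed arguments can be used. The `deep_q_learning` import path and the exact
`VanillaQValueNetwork` constructor arguments are assumptions, inferred from the repository
layout and from the `DuelingQValueNetwork` call in the integration test further down.

from pearl.neural_networks.sequential_decision_making.q_value_networks import (
    VanillaQValueNetwork,
)
from pearl.policy_learners.sequential_decision_making.deep_q_learning import (
    DeepQLearning,
)

# Option 1: let the learner build its Q-network from `network_type` and `hidden_dims`.
dqn = DeepQLearning(
    state_dim=4,
    hidden_dims=[64, 64],
    network_type=VanillaQValueNetwork,  # the default; any QValueNetwork subclass works
)

# Option 2: pass a pre-built `network_instance`; `network_type` is then ignored.
q_network = VanillaQValueNetwork(
    state_dim=4, action_dim=2, hidden_dims=[64, 64], output_dim=1
)
dqn = DeepQLearning(
    state_dim=4,
    network_instance=q_network,
    soft_update_tau=0.75,
)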

@torch.no_grad()
def _get_next_state_values(
def get_next_state_values(
self, batch: TransitionBatch, batch_size: int
) -> torch.Tensor:
"""
Computes the maximum Q-value over all available actions in the next state using the target
network. Note: Q-learning is designed to work with discrete action spaces.
Args:
batch (TransitionBatch): Batch of transitions. For Q-learning, each transition must have
the 'next_state', 'next_available_actions' and the 'next_unavailable_actions_mask'
fields set. The 'next_available_actions' and 'next_unavailable_actions_mask' fields
implement dynamic action spaces in Pearl.
batch_size (int): Size of the batch.
Returns:
torch.Tensor: Maximum Q-value over all available actions in the next state.
"""

(
next_state,
next_available_actions,
next_unavailable_actions_mask,
next_state, # (batch_size x action_space_size x state_dim)
next_available_actions, # (batch_size x action_space_size x action_dim)
next_unavailable_actions_mask, # (batch_size x action_space_size)
) = self._prepare_next_state_action_batch(batch)

assert next_available_actions is not None

# Get Q values for each (state, action), where action \in {available_actions}
next_state_action_values = self._Q_target.get_q_values(
next_state, next_available_actions
).view(batch_size, -1)
@@ -79,6 +155,12 @@ def _get_next_state_values(
def _prepare_next_state_action_batch(
self, batch: TransitionBatch
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:

# This function outputs tensors:
# - next_state_batch: (batch_size x action_space_size x state_dim).
# - next_available_actions_batch: (batch_size x action_space_size x action_dim).
# - next_unavailable_actions_mask_batch: (batch_size x action_space_size).

next_state_batch = batch.next_state # (batch_size x state_dim)
assert next_state_batch is not None

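The `get_next_state_values` hunk above stops just before the masking step (the remaining
lines are collapsed). As an illustrative aside, not the collapsed code itself, the standard
pattern for taking a maximum over only the available next-state actions is:

import torch

batch_size, action_space_size = 128, 5
# Target-network Q-values for every (next_state, candidate action) pair.
q_next = torch.randn(batch_size, action_space_size)  # (batch_size x action_space_size)
# Mask marking padded / unavailable actions (here: the last action is unavailable everywhere).
unavailable_mask = torch.zeros(batch_size, action_space_size, dtype=torch.bool)
unavailable_mask[:, -1] = True

# Push unavailable actions to -inf so they can never be the maximum, then reduce.
q_next = q_next.masked_fill(unavailable_mask, -float("inf"))
max_next_q = q_next.max(dim=-1).values  # (batch_size,)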
@@ -50,7 +50,7 @@ def __init__(
)

@torch.no_grad()
def _get_next_state_values(
def get_next_state_values(
self, batch: TransitionBatch, batch_size: int
) -> torch.Tensor:
"""
101 changes: 90 additions & 11 deletions pearl/policy_learners/sequential_decision_making/deep_td_learning.py
@@ -44,7 +44,7 @@
# of production stack on this param.
class DeepTDLearning(PolicyLearner):
"""
An Abstract Class for Deep Temporal Difference learning policy learner.
An Abstract Class for Deep Temporal Difference learning.
"""

def __init__(
@@ -71,6 +71,46 @@ def __init__(
action_representation_module: Optional[ActionRepresentationModule] = None,
**kwargs: Any,
) -> None:
"""Constructs a DeepTDLearning based policy learner. DeepTDLearning is the base class
for all value based (i.e. temporal difference learning based) algorithms.
Args:
state_dim: Dimension of the state space.
exploration_module (ExplorationModule, optional): Optional exploration module used by
the `act` function to trade-off between exploitation and exploration.
Defaults to None.
action_space (ActionSpace, optional): Action space of the problem. It is kept optional
to allow for the use of dynamic action spaces (see `learn_batch` and `act`
functions). Defaults to None.
hidden_dims (List[int], optional): Hidden dimensions of the default `QValueNetwork`
(taken to be `VanillaQValueNetwork`). Defaults to None.
learning_rate (float): Learning rate for the optimizer. Defaults to 0.001.
Note: We use AdamW by default for all value-based methods.
discount_factor (float): Discount factor for TD updates. Defaults to 0.99.
training_rounds (int): Number of gradient updates per environment step.
Defaults to 100.
batch_size (int): Sample size for mini-batch gradient updates. Defaults to 128.
target_update_freq (int): Frequency at which the target network is updated.
Defaults to 100.
soft_update_tau (float): Coefficient for soft updates to the target networks.
Defaults to 0.01.
is_conservative (bool): Whether to use conservative updates for offline learning.
Defaults to False.
conservative_alpha (float, optional): Alpha parameter for conservative updates.
Defaults to 2.0.
network_type (Type[QValueNetwork]): Network type for the Q-value network. Defaults to
`VanillaQValueNetwork`. This means that by default, an instance of the class
`VanillaQValueNetwork` (or the specified `network_type` class) is created and used
for learning.
network_instance (QValueNetwork, optional): A network instance to be used as the
Q-value network. Defaults to None.
Note: This is an alternative to specifying a `network_type`. If provided, the
specified `network_type` is ignored and the input `network_instance` is used for
learning. Allows for custom implementations of Q-value networks.
action_representation_module (ActionRepresentationModule, optional): Optional module to
represent actions as a feature vector. Typically specified at the agent level.
Defaults to None.
"""
super(DeepTDLearning, self).__init__(
training_rounds=training_rounds,
batch_size=batch_size,
@@ -142,6 +182,23 @@ def act(
available_action_space: ActionSpace,
exploit: bool = False,
) -> Action:
"""
Selects an action from the available action space, balancing exploration and
exploitation.
This action can be (i) an 'exploit action', i.e. the optimal action given the current
estimate of the Q-values, or (ii) an 'exploratory action' obtained using the specified
`exploration_module`.
Args:
subjective_state (SubjectiveState): Current subjective state.
available_action_space (ActionSpace): Available action space at the current state.
Note that Pearl allows for action spaces to change dynamically.
exploit (bool): When set to True, we output the exploit action (no exploration).
When set to False, the specified `exploration_module` is used to balance
between exploration and exploitation. Defaults to False.
Returns:
Action: An action from the available action space.
"""
# TODO: Assumes gym action space.
# Fix the available action space.
assert isinstance(available_action_space, DiscreteActionSpace)
@@ -167,45 +224,67 @@ def act(
if exploit:
return exploit_action

assert self._exploration_module is not None
return self._exploration_module.act(
subjective_state,
available_action_space,
exploit_action,
q_values,
subjective_state=subjective_state,
action_space=available_action_space,
exploit_action=exploit_action,
values=q_values,
)
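An illustrative aside: the default exploration module is `EGreedyExploration(0.05)` (see the
`DeepQLearning` constructor above). A hedged sketch of what an epsilon-greedy choice over the
estimated Q-values amounts to, not Pearl's actual implementation:

import torch

def epsilon_greedy_action(q_values: torch.Tensor, epsilon: float = 0.05) -> int:
    # q_values: (num_available_actions,) Q-value estimates for the current state.
    if torch.rand(1).item() < epsilon:
        # Explore: pick a uniformly random available action.
        return int(torch.randint(q_values.shape[0], (1,)).item())
    # Exploit: pick the action with the highest estimated Q-value.
    return int(torch.argmax(q_values).item())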

@abstractmethod
def _get_next_state_values(
def get_next_state_values(
self, batch: TransitionBatch, batch_size: int
) -> torch.Tensor:
"""
For a given batch of transitions, returns Q-value targets for the Bellman equation.
Child classes should implement this method.
For example, this method in DQN returns
"max_{action in available_action_space} Q(next_state, action)".
"""
pass

def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
"""
Batch learning with TD(0)-style updates. Different implementations of the
`get_next_state_values` function correspond to different RL algorithms,
for example TD learning, DQN, Double DQN, Dueling DQN, etc.
Args:
batch (TransitionBatch): batch of transitions
Returns:
Dict[str, Any]: A dictionary with the loss, given as the mean Bellman error across the batch.
"""
state_batch = batch.state # (batch_size x state_dim)
action_batch = batch.action
# (batch_size x action_dim)
action_batch = batch.action # (batch_size x action_dim)
reward_batch = batch.reward # (batch_size)
done_batch = batch.done # (batch_size)

batch_size = state_batch.shape[0]
# sanity check they have same batch_size
assert reward_batch.shape[0] == batch_size
assert done_batch.shape[0] == batch_size

state_action_values = self._Q.get_q_values(
state_batch=state_batch,
action_batch=action_batch,
curr_available_actions_batch=batch.curr_available_actions,
) # for duelling, this takes care of the mean subtraction for advantage estimation
)
# for duelling dqn, specifying the `curr_available_actions_batch` field takes care of
# the mean subtraction for advantage estimation

# Compute the Bellman Target
expected_state_action_values = (
self._get_next_state_values(batch, batch_size)
self.get_next_state_values(batch, batch_size)
* self._discount_factor
* (1 - done_batch.float())
) + reward_batch # (batch_size), r + gamma * V(s)

criterion = torch.nn.MSELoss()
bellman_loss = criterion(state_action_values, expected_state_action_values)

# Conservative TD updates for offline learning.
if self._is_conservative:
cql_loss = compute_cql_loss(self._Q, batch, batch_size)
loss = self._conservative_alpha * cql_loss + bellman_loss
@@ -217,7 +296,7 @@ def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
loss.backward()
self._optimizer.step()

# Target Network Update
# Target network update
if (self._training_steps + 1) % self._target_update_freq == 0:
update_target_network(self._Q_target, self._Q, self._soft_update_tau)

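An illustrative aside on the Bellman target assembled in `learn_batch` above: the target is
r + gamma * (1 - done) * V(s'), where V(s') comes from `get_next_state_values`. A small
numeric sketch with made-up values:

import torch

discount_factor = 0.99
reward_batch = torch.tensor([1.0, 0.0, 1.0])       # (batch_size,)
done_batch = torch.tensor([0.0, 0.0, 1.0])         # 1.0 marks terminal transitions
next_state_values = torch.tensor([2.0, 3.0, 4.0])  # output of get_next_state_values

# r + gamma * V(s') for non-terminal transitions; just r for terminal ones.
target = reward_batch + discount_factor * (1.0 - done_batch) * next_state_values
# -> tensor([2.9800, 2.9700, 1.0000])

# The Bellman loss is then the MSE between Q(s, a) and this target.
state_action_values = torch.tensor([2.5, 2.5, 1.5])  # made-up Q(s, a) estimates
bellman_loss = torch.nn.MSELoss()(state_action_values, target)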
@@ -23,7 +23,7 @@ class DoubleDQN(DeepQLearning):
"""

@torch.no_grad()
def _get_next_state_values(
def get_next_state_values(
self, batch: TransitionBatch, batch_size: int
) -> torch.Tensor:
next_state_batch = batch.next_state # (batch_size x state_dim)
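An illustrative aside: `DoubleDQN` only overrides how the next-state value is formed. Assuming
the standard Double DQN formulation (the action is selected by the online network and evaluated
by the target network), the difference from DQN can be sketched as:

import torch

q_online_next = torch.randn(128, 5)  # online network:  Q(s', .)
q_target_next = torch.randn(128, 5)  # target network:  Q_target(s', .)

# DQN:        max_a Q_target(s', a)
dqn_value = q_target_next.max(dim=-1).values  # (128,)

# Double DQN: Q_target(s', argmax_a Q_online(s', a))
best_actions = q_online_next.argmax(dim=-1, keepdim=True)              # (128, 1)
double_dqn_value = q_target_next.gather(1, best_actions).squeeze(-1)   # (128,)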
15 changes: 10 additions & 5 deletions test/integration/test_integration.py
@@ -265,17 +265,22 @@ def test_dueling_dqn(
env = GymEnvironment("CartPole-v1")
assert isinstance(env.action_space, DiscreteActionSpace)
num_actions = env.action_space.n
# We use a one-hot representation for actions, so we take
# action_dim = num_actions.
q_network = DuelingQValueNetwork(
state_dim=env.observation_space.shape[0],
action_dim=num_actions,
hidden_dims=[64],
output_dim=1,
state_dim=env.observation_space.shape[
0
], # dimension of state representation
action_dim=num_actions, # dimension of the action representation
hidden_dims=[64, 64], # dimensions of the intermediate layers
output_dim=1, # set to 1 (Q values are scalars)
)
agent = PearlAgent(
policy_learner=DeepQLearning(
state_dim=env.observation_space.shape[0],
action_space=env.action_space,
training_rounds=20,
training_rounds=10,
soft_update_tau=0.75,
network_instance=q_network,
batch_size=batch_size,
action_representation_module=OneHotActionTensorRepresentationModule(
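An illustrative aside: the test above passes `soft_update_tau=0.75`. A hedged sketch of a
Polyak-style soft target update, consistent with the inline comment that a value of 1 means no
soft update (i.e. a full hard copy); Pearl's `update_target_network` is assumed to follow this
convention:

import torch

def soft_update(target_net: torch.nn.Module, online_net: torch.nn.Module, tau: float) -> None:
    # target <- tau * online + (1 - tau) * target; tau = 1.0 reduces to a hard copy.
    # Assumes both networks share the same architecture (and hence parameter ordering).
    with torch.no_grad():
        for t_param, o_param in zip(target_net.parameters(), online_net.parameters()):
            t_param.copy_(tau * o_param + (1.0 - tau) * t_param)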
6 changes: 3 additions & 3 deletions test/unit/with_pytorch/test_deep_td_learning.py
@@ -79,11 +79,11 @@ def test_double_dqn(self) -> None:
batch2 = dqn.preprocess_batch(copy.deepcopy(self.batch))
double_dqn._Q.apply(init_weights)
double_dqn._Q_target.apply(init_weights)
double_value = double_dqn._get_next_state_values(batch1, self.batch_size)
double_value = double_dqn.get_next_state_values(batch1, self.batch_size)

dqn._Q.load_state_dict(double_dqn._Q.state_dict())
dqn._Q_target.load_state_dict(double_dqn._Q_target.state_dict())
vanilla_value = dqn._get_next_state_values(batch2, self.batch_size)
vanilla_value = dqn.get_next_state_values(batch2, self.batch_size)
self.assertEqual(double_value.shape, vanilla_value.shape)
differ = torch.any(double_value != vanilla_value)
if differ:
@@ -99,7 +99,7 @@ def test_sarsa(self) -> None:
exploration_module=EGreedyExploration(0.05),
action_representation_module=self.action_representation_module,
)
sa_value = sarsa._get_next_state_values(
sa_value = sarsa.get_next_state_values(
batch=sarsa.preprocess_batch(self.batch), batch_size=self.batch_size
)
self.assertEqual(sa_value.shape, (self.batch_size,))
