taking utility functions out of the actor critic base class
Summary:
The actor-critic base class contained four utility functions, `make_critic`, `update_critic_target_network`, `single_critic_state_value_loss`, and `twin_critic_action_value_loss`, which are used by the reward-constrained safety module in addition to the actor-critic algorithms. This diff moves them to a separate file, `critic_utils.py`, in the utils folder, since they are general utility functions that can be used by different modules in Pearl.
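Call sites now import these helpers from the new location, for example (an illustrative before/after that matches the import changes shown in the diffs below):

# Before
from pearl.policy_learners.sequential_decision_making.actor_critic_base import (
    twin_critic_action_value_loss,
)

# After
from pearl.utils.functional_utils.learning.critic_utils import (
    twin_critic_action_value_loss,
)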

I added documentation for the `make_critic` and `update_critic_target_network` functions. I also modified `update_critic_target_network` so that it no longer takes `use_twin_critic` as an input argument, since whether a `TwinCritic` is used can be inferred from the type of `target_network`.
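A minimal sketch of the new signature, reconstructed from the implementation removed from `actor_critic_base.py` below; the actual version in `critic_utils.py` may differ in details:

from torch import nn

from pearl.neural_networks.common.utils import (
    update_target_network,
    update_target_networks,
)
from pearl.neural_networks.sequential_decision_making.twin_critic import TwinCritic


def update_critic_target_network(
    target_network: nn.Module, network: nn.Module, tau: float
) -> None:
    # Whether a twin critic is in use is inferred from the type of the target
    # network rather than passed in via a separate `use_twin_critic` flag.
    if isinstance(target_network, TwinCritic):
        update_target_networks(
            target_network._critic_networks_combined,
            network._critic_networks_combined,
            tau=tau,
        )
    else:
        update_target_network(target_network._model, network._model, tau=tau)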

Reviewed By: rodrigodesalvobraz

Differential Revision: D55161398

fbshipit-source-id: a598b24ee1d748195513b29ffe0078790d04b05c
jb3618 authored and facebook-github-bot committed Mar 24, 2024
1 parent bca09bd commit abb6308
Showing 11 changed files with 233 additions and 161 deletions.
150 changes: 8 additions & 142 deletions pearl/policy_learners/sequential_decision_making/actor_critic_base.py
@@ -9,7 +9,7 @@

import copy
from abc import abstractmethod
from typing import Any, cast, Dict, Iterable, List, Optional, Type, Union
from typing import Any, Dict, List, Optional, Type, Union

import torch

@@ -24,15 +24,8 @@
from pearl.history_summarization_modules.history_summarization_module import (
HistorySummarizationModule,
)
from pearl.neural_networks.common.utils import (
init_weights,
update_target_network,
update_target_networks,
)
from pearl.neural_networks.common.value_networks import (
ValueNetwork,
VanillaValueNetwork,
)
from pearl.neural_networks.common.utils import init_weights, update_target_network
from pearl.neural_networks.common.value_networks import ValueNetwork
from pearl.neural_networks.sequential_decision_making.actor_networks import (
ActorNetwork,
DynamicActionActorNetwork,
@@ -42,12 +35,16 @@
QValueNetwork,
VanillaQValueNetwork,
)
from pearl.neural_networks.sequential_decision_making.twin_critic import TwinCritic

from pearl.policy_learners.exploration_modules.exploration_module import (
ExplorationModule,
)
from pearl.policy_learners.policy_learner import PolicyLearner
from pearl.replay_buffers.transition import TransitionBatch
from pearl.utils.functional_utils.learning.critic_utils import (
make_critic,
update_critic_target_network,
)

from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace
from torch import nn, optim
@@ -192,7 +189,6 @@ def __init__(
update_critic_target_network(
self._critic_target,
self._critic,
use_twin_critic,
1,
)

@@ -322,7 +318,6 @@ def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
update_critic_target_network(
self._critic_target,
self._critic,
self._use_twin_critic,
self._critic_soft_update_tau,
)
if self._use_actor_target:
@@ -372,132 +367,3 @@ def _critic_loss(self, batch: TransitionBatch) -> torch.Tensor:
loss (Tensor): The critic loss.
"""
pass


def make_critic(
state_dim: int,
hidden_dims: Optional[Iterable[int]],
use_twin_critic: bool,
network_type: Union[Type[ValueNetwork], Type[QValueNetwork]],
action_dim: Optional[int] = None,
) -> nn.Module:
if use_twin_critic:
assert action_dim is not None
assert hidden_dims is not None
assert issubclass(
network_type, QValueNetwork
), "network_type must be a subclass of QValueNetwork when use_twin_critic is True"

# cast network_type to get around static Pyre type checking; the runtime check with
# `issubclass` ensures the network type is a subclass of QValueNetwork
network_type = cast(Type[QValueNetwork], network_type)

return TwinCritic(
state_dim=state_dim,
action_dim=action_dim,
hidden_dims=hidden_dims,
network_type=network_type,
init_fn=init_weights,
)
else:
if network_type == VanillaQValueNetwork:
# pyre-ignore[45]:
# Pyre does not know that `network_type` is asserted to be concrete
return network_type(
state_dim=state_dim,
action_dim=action_dim,
hidden_dims=hidden_dims,
output_dim=1,
)
elif network_type == VanillaValueNetwork:
# pyre-ignore[45]:
# Pyre does not know that `network_type` is asserted to be concrete
return network_type(
input_dim=state_dim, hidden_dims=hidden_dims, output_dim=1
)
else:
raise NotImplementedError(
f"Type {network_type} cannot be used to instantiate a critic network."
)


def update_critic_target_network(
target_network: nn.Module, network: nn.Module, use_twin_critic: bool, tau: float
) -> None:
if use_twin_critic:
update_target_networks(
target_network._critic_networks_combined,
network._critic_networks_combined,
tau=tau,
)
else:
update_target_network(
target_network._model,
network._model,
tau=tau,
)


def single_critic_state_value_loss(
state_batch: torch.Tensor,
expected_target_batch: torch.Tensor,
critic: nn.Module,
) -> torch.Tensor:
"""
Performs a single optimization step on a (value) critic network using the input batch of states.
This method calculates the mean squared error loss between the predicted state values from the
critic network and the input target estimates. It then updates the critic network using the
provided optimizer.
Args:
state_batch (torch.Tensor): A batch of states with expected shape
`(batch_size, state_dim)`.
expected_target_batch (torch.Tensor): The batch of target estimates
(i.e., RHS of the Bellman equation) with shape `(batch_size)`.
critic (nn.Module): The critic network to update.
Returns:
loss (torch.Tensor): The mean squared error loss for state-value prediction
"""
if not isinstance(critic, ValueNetwork):
raise TypeError(
"critic in the `single_critic_state_value_update` method must be an instance of "
"ValueNetwork"
)
vs = critic(state_batch)
criterion = torch.nn.MSELoss()
loss = criterion(
vs.reshape_as(expected_target_batch), expected_target_batch.detach()
)
return loss


def twin_critic_action_value_loss(
state_batch: torch.Tensor,
action_batch: torch.Tensor,
expected_target_batch: torch.Tensor,
critic: TwinCritic,
) -> torch.Tensor:
"""
Performs a single optimization step on the twin critic networks using the input
batch of states and actions.
This method calculates the mean squared error loss between the predicted Q-values from both
critic networks and the input target estimates. It then updates the critic networks using the
provided optimizer.
Args:
state_batch (torch.Tensor): A batch of states with expected shape
`(batch_size, state_dim)`.
action_batch (torch.Tensor): A batch of actions with expected shape
`(batch_size, action_dim)`.
expected_target_batch (torch.Tensor): The batch of target estimates
(i.e. RHS of the Bellman equation) with shape `(batch_size)`.
critic (TwinCritic): The twin critic network to update.
Returns:
loss (torch.Tensor): The mean squared error loss for action-value prediction
"""

criterion = torch.nn.MSELoss()
q_1, q_2 = critic.get_q_values(state_batch, action_batch)
loss = criterion(
q_1.reshape_as(expected_target_batch), expected_target_batch.detach()
) + criterion(q_2.reshape_as(expected_target_batch), expected_target_batch.detach())
return loss
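
For context, a minimal usage sketch of the relocated utilities, assuming they keep the signatures shown in the removed code above (the dimensions, batch size, and soft-update rate are illustrative placeholders, not Pearl defaults):

import copy

import torch

from pearl.neural_networks.common.value_networks import VanillaValueNetwork
from pearl.utils.functional_utils.learning.critic_utils import (
    make_critic,
    single_critic_state_value_loss,
    update_critic_target_network,
)

# Build a single state-value critic and a target copy of it.
critic = make_critic(
    state_dim=8,
    hidden_dims=[64, 64],
    use_twin_critic=False,
    network_type=VanillaValueNetwork,
)
critic_target = copy.deepcopy(critic)

# Compute the value loss on a dummy batch of states and Bellman targets.
state_batch = torch.randn(32, 8)
expected_target_batch = torch.randn(32)
loss = single_critic_state_value_loss(state_batch, expected_target_batch, critic)
loss.backward()

# Soft-update the target network; `use_twin_critic` is no longer an argument.
update_critic_target_network(critic_target, critic, 0.005)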
4 changes: 3 additions & 1 deletion pearl/policy_learners/sequential_decision_making/ddpg.py
@@ -32,9 +32,11 @@
)
from pearl.policy_learners.sequential_decision_making.actor_critic_base import (
ActorCriticBase,
twin_critic_action_value_loss,
)
from pearl.replay_buffers.transition import TransitionBatch
from pearl.utils.functional_utils.learning.critic_utils import (
twin_critic_action_value_loss,
)
from torch import nn


(diff for another changed file; file path not shown)
@@ -44,10 +44,12 @@
)
from pearl.policy_learners.sequential_decision_making.actor_critic_base import (
ActorCriticBase,
twin_critic_action_value_loss,
)

from pearl.replay_buffers.transition import TransitionBatch
from pearl.utils.functional_utils.learning.critic_utils import (
twin_critic_action_value_loss,
)
from torch import optim


5 changes: 4 additions & 1 deletion pearl/policy_learners/sequential_decision_making/ppo.py
@@ -31,7 +31,6 @@
)
from pearl.policy_learners.sequential_decision_making.actor_critic_base import (
ActorCriticBase,
single_critic_state_value_loss,
)
from pearl.replay_buffers.replay_buffer import ReplayBuffer
from pearl.replay_buffers.sequential_decision_making.on_policy_replay_buffer import (
@@ -40,6 +39,10 @@
OnPolicyTransitionBatch,
)
from pearl.replay_buffers.transition import TransitionBatch

from pearl.utils.functional_utils.learning.critic_utils import (
single_critic_state_value_loss,
)
from torch import nn


(diff for another changed file; file path not shown)
@@ -37,7 +37,6 @@
)
from pearl.policy_learners.sequential_decision_making.actor_critic_base import (
ActorCriticBase,
single_critic_state_value_loss,
)
from pearl.replay_buffers.replay_buffer import ReplayBuffer
from pearl.replay_buffers.sequential_decision_making.on_policy_replay_buffer import (
@@ -46,6 +45,9 @@
OnPolicyTransitionBatch,
)
from pearl.replay_buffers.transition import TransitionBatch
from pearl.utils.functional_utils.learning.critic_utils import (
single_critic_state_value_loss,
)


class REINFORCE(ActorCriticBase):
(diff for another changed file; file path not shown)
@@ -33,9 +33,11 @@
)
from pearl.policy_learners.sequential_decision_making.actor_critic_base import (
ActorCriticBase,
twin_critic_action_value_loss,
)
from pearl.replay_buffers.transition import TransitionBatch
from pearl.utils.functional_utils.learning.critic_utils import (
twin_critic_action_value_loss,
)
from torch import nn, optim


(diff for another changed file; file path not shown)
@@ -30,9 +30,11 @@
)
from pearl.policy_learners.sequential_decision_making.actor_critic_base import (
ActorCriticBase,
twin_critic_action_value_loss,
)
from pearl.replay_buffers.transition import TransitionBatch
from pearl.utils.functional_utils.learning.critic_utils import (
twin_critic_action_value_loss,
)
from pearl.utils.instantiations.spaces.box import BoxSpace
from torch import nn, optim

9 changes: 4 additions & 5 deletions pearl/policy_learners/sequential_decision_making/td3.py
@@ -27,14 +27,14 @@
from pearl.policy_learners.exploration_modules.exploration_module import (
ExplorationModule,
)
from pearl.policy_learners.sequential_decision_making.actor_critic_base import (
twin_critic_action_value_loss,
update_critic_target_network,
)
from pearl.policy_learners.sequential_decision_making.ddpg import (
DeepDeterministicPolicyGradient,
)
from pearl.replay_buffers.transition import TransitionBatch
from pearl.utils.functional_utils.learning.critic_utils import (
twin_critic_action_value_loss,
update_critic_target_network,
)
from pearl.utils.instantiations.spaces.box_action import BoxActionSpace
from torch import nn

@@ -122,7 +122,6 @@ def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
update_critic_target_network(
self._critic_target,
self._critic,
self._use_twin_critic,
self._critic_soft_update_tau,
)
# update target of actor network using soft updates
13 changes: 6 additions & 7 deletions pearl/safety_modules/reward_constrained_safety_module.py
@@ -20,14 +20,14 @@
)
from pearl.neural_networks.sequential_decision_making.twin_critic import TwinCritic
from pearl.policy_learners.policy_learner import PolicyLearner
from pearl.policy_learners.sequential_decision_making.actor_critic_base import (
from pearl.replay_buffers.replay_buffer import ReplayBuffer
from pearl.replay_buffers.transition import TransitionBatch
from pearl.safety_modules.safety_module import SafetyModule
from pearl.utils.functional_utils.learning.critic_utils import (
make_critic,
twin_critic_action_value_loss,
update_critic_target_network,
)
from pearl.replay_buffers.replay_buffer import ReplayBuffer
from pearl.replay_buffers.transition import TransitionBatch
from pearl.safety_modules.safety_module import SafetyModule
from torch import nn, optim


@@ -103,7 +103,6 @@ def __init__(
update_critic_target_network(
self.target_of_cost_critic,
self.cost_critic,
self.use_twin_critic,
1,
)

@@ -133,7 +132,8 @@ def constraint_lambda_update(
cost_critic: nn.Module,
) -> None:
"""
Update the lambda constraint based on the cost critic via a projected gradient descent update rule.
Update the lambda constraint based on the cost critic via a projected gradient descent
update rule.
"""

with torch.no_grad():
@@ -191,7 +191,6 @@ def cost_critic_learn_batch(
update_critic_target_network(
self.target_of_cost_critic,
self.cost_critic,
self.use_twin_critic,
self.critic_soft_update_tau,
)

(remaining changed file diffs not shown)
