Adding network instance to actor critic base class
Summary:
This diff makes the following changes:

- Added the option of passing instances of actor and critic networks as arguments to all actor-critic algorithms (see the usage sketch after this list).

- Added integration tests for SAC, TD3, PPO, and REINFORCE to check performance with actor and critic network instances. A test for DDPG is not required since TD3 is based on DDPG.

- These integration tests should become the default once all algorithms have been migrated to work with network instances.

- Added more documentation for the actor-critic base class.

- Removed some duplicate code for instantiating the target network for actor and critic networks.

- Removed a PPO unit test named "test_training_round_setup". This test was useful when PPO was based on REINFORCE; now that PPO is based on the actor-critic base class, it is no longer needed.
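
For reference, here is a minimal usage sketch of the new arguments, based on the PPO signature in this diff. The concrete network classes (VanillaActorNetwork, VanillaValueNetwork), their constructor signatures, and the DiscreteActionSpace import path are assumptions not shown in this diff; only the actor_network_instance, critic_network_instance, and use_critic arguments come from the changes below.

import torch
from pearl.neural_networks.common.value_networks import VanillaValueNetwork  # assumed class
from pearl.neural_networks.sequential_decision_making.actor_networks import (
    VanillaActorNetwork,  # assumed class
)
from pearl.policy_learners.sequential_decision_making.ppo import (
    ProximalPolicyOptimization,
)
from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace  # assumed path

state_dim, num_actions = 8, 4
action_space = DiscreteActionSpace([torch.tensor([i]) for i in range(num_actions)])

# Build the actor and critic explicitly instead of letting the learner
# construct them from hidden-dimension lists.
actor = VanillaActorNetwork(
    input_dim=state_dim,
    hidden_dims=[64, 64],
    output_dim=num_actions,
    action_space=action_space,
)
critic = VanillaValueNetwork(input_dim=state_dim, hidden_dims=[64, 64], output_dim=1)

ppo = ProximalPolicyOptimization(
    state_dim=state_dim,
    action_space=action_space,
    use_critic=True,
    actor_network_instance=actor,    # new argument added in this diff
    critic_network_instance=critic,  # new argument added in this diff
)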

Reviewed By: rodrigodesalvobraz

Differential Revision: D54978158

fbshipit-source-id: e95bea179f2698c53097c9e83cbb55871f2607bc
jb3618 authored and facebook-github-bot committed Mar 23, 2024
1 parent f5b0945 commit bca09bd
Showing 10 changed files with 414 additions and 82 deletions.
100 changes: 55 additions & 45 deletions pearl/policy_learners/sequential_decision_making/actor_critic_base.py
@@ -7,6 +7,7 @@

# pyre-strict

import copy
from abc import abstractmethod
from typing import Any, cast, Dict, Iterable, List, Optional, Type, Union

@@ -68,7 +69,8 @@ def __init__(
self,
state_dim: int,
exploration_module: ExplorationModule,
actor_hidden_dims: List[int],
actor_hidden_dims: Optional[List[int]] = None,
use_critic: bool = True,
critic_hidden_dims: Optional[List[int]] = None,
action_space: Optional[ActionSpace] = None,
actor_learning_rate: float = 1e-3,
@@ -88,6 +90,10 @@ def __init__(
is_action_continuous: bool = False,
on_policy: bool = False,
action_representation_module: Optional[ActionRepresentationModule] = None,
actor_network_instance: Optional[ActorNetwork] = None,
critic_network_instance: Optional[
Union[ValueNetwork, QValueNetwork, nn.Module]
] = None,
) -> None:
super(ActorCriticBase, self).__init__(
on_policy=on_policy,
@@ -98,46 +104,32 @@ def __init__(
action_representation_module=action_representation_module,
action_space=action_space,
)
"""
Constructs a base actor-critic policy learner.
"""

self._state_dim = state_dim
self._use_actor_target = use_actor_target
self._use_critic_target = use_critic_target
self._use_twin_critic = use_twin_critic
self._use_critic: bool = critic_hidden_dims is not None
self._use_critic: bool = use_critic

self._action_dim: int = (
self.action_representation_module.representation_dim
if self.is_action_continuous
else self.action_representation_module.max_number_actions
)

# actor network takes state as input and outputs an action vector
self._actor: nn.Module = actor_network_type(
input_dim=(
state_dim + self._action_dim
if actor_network_type is DynamicActionActorNetwork
else state_dim
),
hidden_dims=actor_hidden_dims,
output_dim=(
1
if actor_network_type is DynamicActionActorNetwork
else self._action_dim
),
action_space=action_space,
)
self._actor.apply(init_weights)
self._actor_optimizer = optim.AdamW(
[
{
"params": self._actor.parameters(),
"lr": actor_learning_rate,
"amsgrad": True,
},
]
)
self._actor_soft_update_tau = actor_soft_update_tau
if self._use_actor_target:
self._actor_target: nn.Module = actor_network_type(
if actor_network_instance is not None:
self._actor: nn.Module = actor_network_instance
else:
assert (
actor_hidden_dims is not None
), f"{self.__class__.__name__} requires parameter actor_hidden_dims if a parameter \
actor_network_instance has not been provided."

# actor network takes state as input and outputs an action vector
self._actor: nn.Module = actor_network_type(
input_dim=(
state_dim + self._action_dim
if actor_network_type is DynamicActionActorNetwork
@@ -151,17 +143,41 @@ def __init__(
),
action_space=action_space,
)
self._actor.apply(init_weights)
self._actor_optimizer = optim.AdamW(
[
{
"params": self._actor.parameters(),
"lr": actor_learning_rate,
"amsgrad": True,
},
]
)
self._actor_soft_update_tau = actor_soft_update_tau

# make a copy of the actor network to be used as the actor target network
if self._use_actor_target:
self._actor_target: nn.Module = copy.deepcopy(self._actor)
update_target_network(self._actor_target, self._actor, tau=1)

self._critic_soft_update_tau = critic_soft_update_tau
if self._use_critic:
self._critic: nn.Module = make_critic(
state_dim=self._state_dim,
action_dim=self._action_dim,
hidden_dims=critic_hidden_dims,
use_twin_critic=use_twin_critic,
network_type=critic_network_type,
)
if critic_network_instance is not None:
self._critic: nn.Module = critic_network_instance
else:
assert (
critic_hidden_dims is not None
), f"{self.__class__.__name__} requires parameter critic_hidden_dims if a \
parameter critic_network_instance has not been provided."

self._critic: nn.Module = make_critic(
state_dim=self._state_dim,
action_dim=self._action_dim,
hidden_dims=critic_hidden_dims,
use_twin_critic=use_twin_critic,
network_type=critic_network_type,
)

self._critic_optimizer: optim.Optimizer = optim.AdamW(
[
{
@@ -172,13 +188,7 @@ def __init__(
]
)
if self._use_critic_target:
self._critic_target: nn.Module = make_critic(
state_dim=self._state_dim,
action_dim=self._action_dim,
hidden_dims=critic_hidden_dims,
use_twin_critic=use_twin_critic,
network_type=critic_network_type,
)
self._critic_target: nn.Module = copy.deepcopy(self._critic)
update_critic_target_network(
self._critic_target,
self._critic,
@@ -407,7 +417,7 @@ def make_critic(
)
else:
raise NotImplementedError(
"Unknown network type. The code needs to be refactored to support this."
f"Type {network_type} cannot be used to instantiate a critic network."
)


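
The key structural change above is that target networks are now deep copies of the online networks rather than fresh instantiations from hyperparameters, which is what lets externally supplied network instances get matching targets. A minimal PyTorch sketch of that pattern (the soft_update helper below is illustrative, not Pearl's update_target_network):

import copy
import torch
from torch import nn

# Online actor, however it was built (from hidden dims or passed in as an instance).
actor = nn.Sequential(nn.Linear(8, 64), nn.ReLU(), nn.Linear(64, 4))

# Target network: a deep copy, so no knowledge of the constructor arguments is needed.
actor_target = copy.deepcopy(actor)

def soft_update(target: nn.Module, source: nn.Module, tau: float) -> None:
    """Polyak averaging: target <- tau * source + (1 - tau) * target."""
    with torch.no_grad():
        for t_param, s_param in zip(target.parameters(), source.parameters()):
            t_param.mul_(1.0 - tau).add_(s_param, alpha=tau)

# tau=1 makes the target an exact copy, mirroring the initial sync in the diff.
soft_update(actor_target, actor, tau=1.0)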
11 changes: 8 additions & 3 deletions pearl/policy_learners/sequential_decision_making/ddpg.py
@@ -7,7 +7,7 @@

# pyre-strict

from typing import List, Optional, Type
from typing import List, Optional, Type, Union

import torch
from pearl.action_representation_modules.action_representation_module import (
@@ -35,6 +35,7 @@
twin_critic_action_value_loss,
)
from pearl.replay_buffers.transition import TransitionBatch
from torch import nn


class DeepDeterministicPolicyGradient(ActorCriticBase):
@@ -47,8 +48,8 @@ def __init__(
self,
state_dim: int,
action_space: ActionSpace,
actor_hidden_dims: List[int],
critic_hidden_dims: List[int],
actor_hidden_dims: Optional[List[int]] = None,
critic_hidden_dims: Optional[List[int]] = None,
exploration_module: Optional[ExplorationModule] = None,
actor_learning_rate: float = 1e-3,
critic_learning_rate: float = 1e-3,
@@ -60,6 +61,8 @@ def __init__(
training_rounds: int = 1,
batch_size: int = 256,
action_representation_module: Optional[ActionRepresentationModule] = None,
actor_network_instance: Optional[ActorNetwork] = None,
critic_network_instance: Optional[Union[QValueNetwork, nn.Module]] = None,
) -> None:
super(DeepDeterministicPolicyGradient, self).__init__(
state_dim=state_dim,
@@ -86,6 +89,8 @@ def __init__(
is_action_continuous=True,
on_policy=False,
action_representation_module=action_representation_module,
actor_network_instance=actor_network_instance,
critic_network_instance=critic_network_instance,
)

def _actor_loss(self, batch: TransitionBatch) -> torch.Tensor:
13 changes: 10 additions & 3 deletions pearl/policy_learners/sequential_decision_making/ppo.py
@@ -7,7 +7,7 @@

# pyre-strict

from typing import Any, Dict, List, Optional, Type
from typing import Any, Dict, List, Optional, Type, Union

import torch
from pearl.action_representation_modules.action_representation_module import (
@@ -40,6 +40,7 @@
OnPolicyTransitionBatch,
)
from pearl.replay_buffers.transition import TransitionBatch
from torch import nn


class ProximalPolicyOptimization(ActorCriticBase):
@@ -51,8 +52,9 @@ def __init__(
self,
state_dim: int,
action_space: ActionSpace,
actor_hidden_dims: List[int],
critic_hidden_dims: Optional[List[int]],
use_critic: bool,
actor_hidden_dims: Optional[List[int]] = None,
critic_hidden_dims: Optional[List[int]] = None,
actor_learning_rate: float = 1e-4,
critic_learning_rate: float = 1e-4,
exploration_module: Optional[ExplorationModule] = None,
@@ -65,11 +67,14 @@
trace_decay_param: float = 0.95,
entropy_bonus_scaling: float = 0.01,
action_representation_module: Optional[ActionRepresentationModule] = None,
actor_network_instance: Optional[ActorNetwork] = None,
critic_network_instance: Optional[Union[ValueNetwork, nn.Module]] = None,
) -> None:
super(ProximalPolicyOptimization, self).__init__(
state_dim=state_dim,
action_space=action_space,
actor_hidden_dims=actor_hidden_dims,
use_critic=use_critic,
critic_hidden_dims=critic_hidden_dims,
actor_learning_rate=actor_learning_rate,
critic_learning_rate=critic_learning_rate,
@@ -91,6 +96,8 @@ def __init__(
is_action_continuous=False,
on_policy=True,
action_representation_module=action_representation_module,
actor_network_instance=actor_network_instance,
critic_network_instance=critic_network_instance,
)
self._epsilon = epsilon
self._trace_decay_param = trace_decay_param
11 changes: 9 additions & 2 deletions pearl/policy_learners/sequential_decision_making/reinforce.py
@@ -7,7 +7,7 @@

# pyre-strict

from typing import Any, Dict, List, Optional, Type
from typing import Any, Dict, List, Optional, Type, Union

from pearl.action_representation_modules.action_representation_module import (
ActionRepresentationModule,
@@ -16,6 +16,7 @@
from pearl.neural_networks.common.value_networks import ValueNetwork

from pearl.neural_networks.sequential_decision_making.actor_networks import ActorNetwork
from torch import nn

try:
import gymnasium as gym
@@ -57,7 +58,8 @@ class REINFORCE(ActorCriticBase):
def __init__(
self,
state_dim: int,
actor_hidden_dims: List[int],
actor_hidden_dims: Optional[List[int]] = None,
use_critic: bool = False,
critic_hidden_dims: Optional[List[int]] = None,
action_space: Optional[ActionSpace] = None,
actor_learning_rate: float = 1e-4,
@@ -69,11 +71,14 @@
training_rounds: int = 8,
batch_size: int = 64,
action_representation_module: Optional[ActionRepresentationModule] = None,
actor_network_instance: Optional[ActorNetwork] = None,
critic_network_instance: Optional[Union[ValueNetwork, nn.Module]] = None,
) -> None:
super(REINFORCE, self).__init__(
state_dim=state_dim,
action_space=action_space,
actor_hidden_dims=actor_hidden_dims,
use_critic=use_critic,
critic_hidden_dims=critic_hidden_dims,
actor_learning_rate=actor_learning_rate,
critic_learning_rate=critic_learning_rate,
@@ -95,6 +100,8 @@ def __init__(
is_action_continuous=False,
on_policy=True,
action_representation_module=action_representation_module,
actor_network_instance=actor_network_instance,
critic_network_instance=critic_network_instance,
)

def _actor_loss(self, batch: TransitionBatch) -> torch.Tensor:
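
Since the critic is now gated by the explicit use_critic flag rather than by whether critic_hidden_dims was given, an actor-only REINFORCE learner can be set up as follows (a sketch under the same assumed class names and import paths as the earlier PPO example):

import torch
from pearl.neural_networks.sequential_decision_making.actor_networks import (
    VanillaActorNetwork,  # assumed class, as in the earlier sketch
)
from pearl.policy_learners.sequential_decision_making.reinforce import REINFORCE
from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace  # assumed path

state_dim, num_actions = 8, 4
action_space = DiscreteActionSpace([torch.tensor([i]) for i in range(num_actions)])

actor = VanillaActorNetwork(
    input_dim=state_dim,
    hidden_dims=[64, 64],
    output_dim=num_actions,
    action_space=action_space,
)

reinforce = REINFORCE(
    state_dim=state_dim,
    action_space=action_space,
    use_critic=False,              # actor-only: no critic dims or instance needed
    actor_network_instance=actor,  # new argument added in this diff
)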
@@ -7,7 +7,7 @@

# pyre-strict

from typing import List, Optional, Type
from typing import List, Optional, Type, Union

import torch
from pearl.action_representation_modules.action_representation_module import (
@@ -36,7 +36,7 @@
twin_critic_action_value_loss,
)
from pearl.replay_buffers.transition import TransitionBatch
from torch import optim
from torch import nn, optim


# Currently available actions is not used. Needs to be updated once we know the input
@@ -53,8 +53,8 @@ def __init__(
self,
state_dim: int,
action_space: ActionSpace,
actor_hidden_dims: List[int],
critic_hidden_dims: List[int],
actor_hidden_dims: Optional[List[int]] = None,
critic_hidden_dims: Optional[List[int]] = None,
actor_learning_rate: float = 1e-4,
critic_learning_rate: float = 1e-4,
actor_network_type: Type[ActorNetwork] = VanillaActorNetwork,
@@ -66,6 +66,8 @@
batch_size: int = 128,
entropy_coef: float = 0.2,
action_representation_module: Optional[ActionRepresentationModule] = None,
actor_network_instance: Optional[ActorNetwork] = None,
critic_network_instance: Optional[Union[QValueNetwork, nn.Module]] = None,
) -> None:
super(SoftActorCritic, self).__init__(
state_dim=state_dim,
@@ -92,6 +94,8 @@
is_action_continuous=False,
on_policy=False,
action_representation_module=action_representation_module,
actor_network_instance=actor_network_instance,
critic_network_instance=critic_network_instance,
)

# This is needed to avoid actor softmax overflow issue.