diff --git a/pearl/policy_learners/sequential_decision_making/actor_critic_base.py b/pearl/policy_learners/sequential_decision_making/actor_critic_base.py
index 66b978df..b9e33dbc 100644
--- a/pearl/policy_learners/sequential_decision_making/actor_critic_base.py
+++ b/pearl/policy_learners/sequential_decision_making/actor_critic_base.py
@@ -7,6 +7,7 @@
 
 # pyre-strict
 
+import copy
 from abc import abstractmethod
 from typing import Any, cast, Dict, Iterable, List, Optional, Type, Union
 
@@ -68,7 +69,8 @@ def __init__(
         self,
         state_dim: int,
         exploration_module: ExplorationModule,
-        actor_hidden_dims: List[int],
+        actor_hidden_dims: Optional[List[int]] = None,
+        use_critic: bool = True,
         critic_hidden_dims: Optional[List[int]] = None,
         action_space: Optional[ActionSpace] = None,
         actor_learning_rate: float = 1e-3,
@@ -88,6 +90,10 @@
         is_action_continuous: bool = False,
         on_policy: bool = False,
         action_representation_module: Optional[ActionRepresentationModule] = None,
+        actor_network_instance: Optional[ActorNetwork] = None,
+        critic_network_instance: Optional[
+            Union[ValueNetwork, QValueNetwork, nn.Module]
+        ] = None,
     ) -> None:
         super(ActorCriticBase, self).__init__(
             on_policy=on_policy,
@@ -98,11 +104,15 @@
             action_representation_module=action_representation_module,
             action_space=action_space,
         )
+        """
+        Constructs a base actor-critic policy learner.
+        """
+
         self._state_dim = state_dim
         self._use_actor_target = use_actor_target
         self._use_critic_target = use_critic_target
         self._use_twin_critic = use_twin_critic
-        self._use_critic: bool = critic_hidden_dims is not None
+        self._use_critic: bool = use_critic
 
         self._action_dim: int = (
             self.action_representation_module.representation_dim
@@ -110,34 +120,16 @@
             else self.action_representation_module.max_number_actions
         )
 
-        # actor network takes state as input and outputs an action vector
-        self._actor: nn.Module = actor_network_type(
-            input_dim=(
-                state_dim + self._action_dim
-                if actor_network_type is DynamicActionActorNetwork
-                else state_dim
-            ),
-            hidden_dims=actor_hidden_dims,
-            output_dim=(
-                1
-                if actor_network_type is DynamicActionActorNetwork
-                else self._action_dim
-            ),
-            action_space=action_space,
-        )
-        self._actor.apply(init_weights)
-        self._actor_optimizer = optim.AdamW(
-            [
-                {
-                    "params": self._actor.parameters(),
-                    "lr": actor_learning_rate,
-                    "amsgrad": True,
-                },
-            ]
-        )
-        self._actor_soft_update_tau = actor_soft_update_tau
-        if self._use_actor_target:
-            self._actor_target: nn.Module = actor_network_type(
+        if actor_network_instance is not None:
+            self._actor: nn.Module = actor_network_instance
+        else:
+            assert (
+                actor_hidden_dims is not None
+            ), f"{self.__class__.__name__} requires parameter actor_hidden_dims if a parameter \
+                actor_network_instance has not been provided."
+ + # actor network takes state as input and outputs an action vector + self._actor: nn.Module = actor_network_type( input_dim=( state_dim + self._action_dim if actor_network_type is DynamicActionActorNetwork @@ -151,17 +143,41 @@ def __init__( ), action_space=action_space, ) + self._actor.apply(init_weights) + self._actor_optimizer = optim.AdamW( + [ + { + "params": self._actor.parameters(), + "lr": actor_learning_rate, + "amsgrad": True, + }, + ] + ) + self._actor_soft_update_tau = actor_soft_update_tau + + # make a copy of the actor network to be used as the actor target network + if self._use_actor_target: + self._actor_target: nn.Module = copy.deepcopy(self._actor) update_target_network(self._actor_target, self._actor, tau=1) self._critic_soft_update_tau = critic_soft_update_tau if self._use_critic: - self._critic: nn.Module = make_critic( - state_dim=self._state_dim, - action_dim=self._action_dim, - hidden_dims=critic_hidden_dims, - use_twin_critic=use_twin_critic, - network_type=critic_network_type, - ) + if critic_network_instance is not None: + self._critic: nn.Module = critic_network_instance + else: + assert ( + critic_hidden_dims is not None + ), f"{self.__class__.__name__} requires parameter critic_hidden_dims if a \ + parameter critic_network_instance has not been provided." + + self._critic: nn.Module = make_critic( + state_dim=self._state_dim, + action_dim=self._action_dim, + hidden_dims=critic_hidden_dims, + use_twin_critic=use_twin_critic, + network_type=critic_network_type, + ) + self._critic_optimizer: optim.Optimizer = optim.AdamW( [ { @@ -172,13 +188,7 @@ def __init__( ] ) if self._use_critic_target: - self._critic_target: nn.Module = make_critic( - state_dim=self._state_dim, - action_dim=self._action_dim, - hidden_dims=critic_hidden_dims, - use_twin_critic=use_twin_critic, - network_type=critic_network_type, - ) + self._critic_target: nn.Module = copy.deepcopy(self._critic) update_critic_target_network( self._critic_target, self._critic, @@ -407,7 +417,7 @@ def make_critic( ) else: raise NotImplementedError( - "Unknown network type. The code needs to be refactored to support this." + f"Type {network_type} cannot be used to instantiate a critic network." 
) diff --git a/pearl/policy_learners/sequential_decision_making/ddpg.py b/pearl/policy_learners/sequential_decision_making/ddpg.py index 90d96eca..9820f553 100644 --- a/pearl/policy_learners/sequential_decision_making/ddpg.py +++ b/pearl/policy_learners/sequential_decision_making/ddpg.py @@ -7,7 +7,7 @@ # pyre-strict -from typing import List, Optional, Type +from typing import List, Optional, Type, Union import torch from pearl.action_representation_modules.action_representation_module import ( @@ -35,6 +35,7 @@ twin_critic_action_value_loss, ) from pearl.replay_buffers.transition import TransitionBatch +from torch import nn class DeepDeterministicPolicyGradient(ActorCriticBase): @@ -47,8 +48,8 @@ def __init__( self, state_dim: int, action_space: ActionSpace, - actor_hidden_dims: List[int], - critic_hidden_dims: List[int], + actor_hidden_dims: Optional[List[int]] = None, + critic_hidden_dims: Optional[List[int]] = None, exploration_module: Optional[ExplorationModule] = None, actor_learning_rate: float = 1e-3, critic_learning_rate: float = 1e-3, @@ -60,6 +61,8 @@ def __init__( training_rounds: int = 1, batch_size: int = 256, action_representation_module: Optional[ActionRepresentationModule] = None, + actor_network_instance: Optional[ActorNetwork] = None, + critic_network_instance: Optional[Union[QValueNetwork, nn.Module]] = None, ) -> None: super(DeepDeterministicPolicyGradient, self).__init__( state_dim=state_dim, @@ -86,6 +89,8 @@ def __init__( is_action_continuous=True, on_policy=False, action_representation_module=action_representation_module, + actor_network_instance=actor_network_instance, + critic_network_instance=critic_network_instance, ) def _actor_loss(self, batch: TransitionBatch) -> torch.Tensor: diff --git a/pearl/policy_learners/sequential_decision_making/ppo.py b/pearl/policy_learners/sequential_decision_making/ppo.py index 945a9b8a..75a6233a 100644 --- a/pearl/policy_learners/sequential_decision_making/ppo.py +++ b/pearl/policy_learners/sequential_decision_making/ppo.py @@ -7,7 +7,7 @@ # pyre-strict -from typing import Any, Dict, List, Optional, Type +from typing import Any, Dict, List, Optional, Type, Union import torch from pearl.action_representation_modules.action_representation_module import ( @@ -40,6 +40,7 @@ OnPolicyTransitionBatch, ) from pearl.replay_buffers.transition import TransitionBatch +from torch import nn class ProximalPolicyOptimization(ActorCriticBase): @@ -51,8 +52,9 @@ def __init__( self, state_dim: int, action_space: ActionSpace, - actor_hidden_dims: List[int], - critic_hidden_dims: Optional[List[int]], + use_critic: bool, + actor_hidden_dims: Optional[List[int]] = None, + critic_hidden_dims: Optional[List[int]] = None, actor_learning_rate: float = 1e-4, critic_learning_rate: float = 1e-4, exploration_module: Optional[ExplorationModule] = None, @@ -65,11 +67,14 @@ def __init__( trace_decay_param: float = 0.95, entropy_bonus_scaling: float = 0.01, action_representation_module: Optional[ActionRepresentationModule] = None, + actor_network_instance: Optional[ActorNetwork] = None, + critic_network_instance: Optional[Union[ValueNetwork, nn.Module]] = None, ) -> None: super(ProximalPolicyOptimization, self).__init__( state_dim=state_dim, action_space=action_space, actor_hidden_dims=actor_hidden_dims, + use_critic=use_critic, critic_hidden_dims=critic_hidden_dims, actor_learning_rate=actor_learning_rate, critic_learning_rate=critic_learning_rate, @@ -91,6 +96,8 @@ def __init__( is_action_continuous=False, on_policy=True, 
action_representation_module=action_representation_module, + actor_network_instance=actor_network_instance, + critic_network_instance=critic_network_instance, ) self._epsilon = epsilon self._trace_decay_param = trace_decay_param diff --git a/pearl/policy_learners/sequential_decision_making/reinforce.py b/pearl/policy_learners/sequential_decision_making/reinforce.py index 309f76c6..7b5be7b7 100644 --- a/pearl/policy_learners/sequential_decision_making/reinforce.py +++ b/pearl/policy_learners/sequential_decision_making/reinforce.py @@ -7,7 +7,7 @@ # pyre-strict -from typing import Any, Dict, List, Optional, Type +from typing import Any, Dict, List, Optional, Type, Union from pearl.action_representation_modules.action_representation_module import ( ActionRepresentationModule, @@ -16,6 +16,7 @@ from pearl.neural_networks.common.value_networks import ValueNetwork from pearl.neural_networks.sequential_decision_making.actor_networks import ActorNetwork +from torch import nn try: import gymnasium as gym @@ -57,7 +58,8 @@ class REINFORCE(ActorCriticBase): def __init__( self, state_dim: int, - actor_hidden_dims: List[int], + actor_hidden_dims: Optional[List[int]] = None, + use_critic: bool = False, critic_hidden_dims: Optional[List[int]] = None, action_space: Optional[ActionSpace] = None, actor_learning_rate: float = 1e-4, @@ -69,11 +71,14 @@ def __init__( training_rounds: int = 8, batch_size: int = 64, action_representation_module: Optional[ActionRepresentationModule] = None, + actor_network_instance: Optional[ActorNetwork] = None, + critic_network_instance: Optional[Union[ValueNetwork, nn.Module]] = None, ) -> None: super(REINFORCE, self).__init__( state_dim=state_dim, action_space=action_space, actor_hidden_dims=actor_hidden_dims, + use_critic=use_critic, critic_hidden_dims=critic_hidden_dims, actor_learning_rate=actor_learning_rate, critic_learning_rate=critic_learning_rate, @@ -95,6 +100,8 @@ def __init__( is_action_continuous=False, on_policy=True, action_representation_module=action_representation_module, + actor_network_instance=actor_network_instance, + critic_network_instance=critic_network_instance, ) def _actor_loss(self, batch: TransitionBatch) -> torch.Tensor: diff --git a/pearl/policy_learners/sequential_decision_making/soft_actor_critic.py b/pearl/policy_learners/sequential_decision_making/soft_actor_critic.py index f9679435..821839dd 100644 --- a/pearl/policy_learners/sequential_decision_making/soft_actor_critic.py +++ b/pearl/policy_learners/sequential_decision_making/soft_actor_critic.py @@ -7,7 +7,7 @@ # pyre-strict -from typing import List, Optional, Type +from typing import List, Optional, Type, Union import torch from pearl.action_representation_modules.action_representation_module import ( @@ -36,7 +36,7 @@ twin_critic_action_value_loss, ) from pearl.replay_buffers.transition import TransitionBatch -from torch import optim +from torch import nn, optim # Currently available actions is not used. 
Needs to be updated once we know the input @@ -53,8 +53,8 @@ def __init__( self, state_dim: int, action_space: ActionSpace, - actor_hidden_dims: List[int], - critic_hidden_dims: List[int], + actor_hidden_dims: Optional[List[int]] = None, + critic_hidden_dims: Optional[List[int]] = None, actor_learning_rate: float = 1e-4, critic_learning_rate: float = 1e-4, actor_network_type: Type[ActorNetwork] = VanillaActorNetwork, @@ -66,6 +66,8 @@ def __init__( batch_size: int = 128, entropy_coef: float = 0.2, action_representation_module: Optional[ActionRepresentationModule] = None, + actor_network_instance: Optional[ActorNetwork] = None, + critic_network_instance: Optional[Union[QValueNetwork, nn.Module]] = None, ) -> None: super(SoftActorCritic, self).__init__( state_dim=state_dim, @@ -92,6 +94,8 @@ def __init__( is_action_continuous=False, on_policy=False, action_representation_module=action_representation_module, + actor_network_instance=actor_network_instance, + critic_network_instance=critic_network_instance, ) # This is needed to avoid actor softmax overflow issue. diff --git a/pearl/policy_learners/sequential_decision_making/soft_actor_critic_continuous.py b/pearl/policy_learners/sequential_decision_making/soft_actor_critic_continuous.py index 69ef8084..9c311c79 100644 --- a/pearl/policy_learners/sequential_decision_making/soft_actor_critic_continuous.py +++ b/pearl/policy_learners/sequential_decision_making/soft_actor_critic_continuous.py @@ -7,7 +7,7 @@ # pyre-strict -from typing import Any, Dict, List, Optional, Type +from typing import Any, Dict, List, Optional, Type, Union import torch from pearl.action_representation_modules.action_representation_module import ( @@ -34,7 +34,7 @@ ) from pearl.replay_buffers.transition import TransitionBatch from pearl.utils.instantiations.spaces.box import BoxSpace -from torch import optim +from torch import nn, optim class ContinuousSoftActorCritic(ActorCriticBase): @@ -46,8 +46,8 @@ def __init__( self, state_dim: int, action_space: ActionSpace, - actor_hidden_dims: List[int], - critic_hidden_dims: List[int], + actor_hidden_dims: Optional[List[int]] = None, + critic_hidden_dims: Optional[List[int]] = None, actor_learning_rate: float = 1e-3, critic_learning_rate: float = 1e-3, actor_network_type: Type[ActorNetwork] = GaussianActorNetwork, @@ -60,6 +60,8 @@ def __init__( entropy_coef: float = 0.2, entropy_autotune: bool = True, action_representation_module: Optional[ActionRepresentationModule] = None, + actor_network_instance: Optional[ActorNetwork] = None, + critic_network_instance: Optional[Union[QValueNetwork, nn.Module]] = None, ) -> None: super(ContinuousSoftActorCritic, self).__init__( state_dim=state_dim, @@ -86,6 +88,8 @@ def __init__( is_action_continuous=True, on_policy=False, action_representation_module=action_representation_module, + actor_network_instance=actor_network_instance, + critic_network_instance=critic_network_instance, ) self._entropy_autotune = entropy_autotune diff --git a/pearl/policy_learners/sequential_decision_making/td3.py b/pearl/policy_learners/sequential_decision_making/td3.py index 79d3b66c..ca40b252 100644 --- a/pearl/policy_learners/sequential_decision_making/td3.py +++ b/pearl/policy_learners/sequential_decision_making/td3.py @@ -7,7 +7,7 @@ # pyre-strict -from typing import Any, Dict, List, Optional, Type +from typing import Any, Dict, List, Optional, Type, Union import torch from pearl.action_representation_modules.action_representation_module import ( @@ -36,6 +36,7 @@ ) from pearl.replay_buffers.transition import 
TransitionBatch from pearl.utils.instantiations.spaces.box_action import BoxActionSpace +from torch import nn class TD3(DeepDeterministicPolicyGradient): @@ -49,8 +50,8 @@ def __init__( self, state_dim: int, action_space: ActionSpace, - actor_hidden_dims: List[int], - critic_hidden_dims: List[int], + actor_hidden_dims: Optional[List[int]] = None, + critic_hidden_dims: Optional[List[int]] = None, exploration_module: Optional[ExplorationModule] = None, actor_learning_rate: float = 1e-3, critic_learning_rate: float = 1e-3, @@ -65,6 +66,8 @@ def __init__( actor_update_noise: float = 0.2, actor_update_noise_clip: float = 0.5, action_representation_module: Optional[ActionRepresentationModule] = None, + actor_network_instance: Optional[ActorNetwork] = None, + critic_network_instance: Optional[Union[QValueNetwork, nn.Module]] = None, ) -> None: assert isinstance(action_space, BoxActionSpace) super(TD3, self).__init__( @@ -83,6 +86,8 @@ def __init__( training_rounds=training_rounds, batch_size=batch_size, action_representation_module=action_representation_module, + actor_network_instance=actor_network_instance, + critic_network_instance=critic_network_instance, ) self._action_space: BoxActionSpace = action_space self._actor_update_freq = actor_update_freq diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index bc39f49f..d2e4df98 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -7,6 +7,17 @@ # pyre-strict +from pearl.neural_networks.common.utils import init_weights +from pearl.neural_networks.common.value_networks import VanillaValueNetwork +from pearl.neural_networks.sequential_decision_making.actor_networks import ( + GaussianActorNetwork, + VanillaActorNetwork, + VanillaContinuousActorNetwork, +) +from pearl.neural_networks.sequential_decision_making.q_value_networks import ( + VanillaQValueNetwork, +) +from pearl.neural_networks.sequential_decision_making.twin_critic import TwinCritic from pearl.utils.instantiations.spaces.discrete import DiscreteSpace try: @@ -239,6 +250,7 @@ def test_reinforce(self) -> None: state_dim=env.observation_space.shape[0], action_space=env.action_space, actor_hidden_dims=[64, 64], + use_critic=True, critic_hidden_dims=[64, 64], training_rounds=8, batch_size=64, @@ -260,6 +272,59 @@ def test_reinforce(self) -> None: ) ) + def test_reinforce_network_instance(self) -> None: + """ + This test checks for performance of REINFORCE when actor and critic network instances are + passed as input arguments. The performance metric is if REINFORCE will eventually obtain + an episodic return of 500 for CartPole-v1. + """ + env = GymEnvironment("CartPole-v1") + assert isinstance(env.action_space, DiscreteActionSpace) + num_actions = env.action_space.n + action_representation_module = OneHotActionTensorRepresentationModule( + max_number_actions=num_actions + ) + + # VanillaActorNetwork outputs a probability distribution over all actions, + # so the output_dim is taken to be num_actions. 
+ actor_network_instance = VanillaActorNetwork( + input_dim=env.observation_space.shape[0], + hidden_dims=[64, 64], + output_dim=num_actions, + ) + + # REINFORCE uses a VanillaValueNetwork by default + critic_network_instance = VanillaValueNetwork( + input_dim=env.observation_space.shape[0], + hidden_dims=[64, 64], + output_dim=1, + ) + + agent = PearlAgent( + policy_learner=REINFORCE( + state_dim=env.observation_space.shape[0], + action_space=env.action_space, + use_critic=True, + training_rounds=8, + batch_size=64, + action_representation_module=action_representation_module, + actor_network_instance=actor_network_instance, + critic_network_instance=critic_network_instance, + ), + replay_buffer=OnPolicyReplayBuffer(10_000), + ) + self.assertTrue( + target_return_is_reached( + target_return=500, + max_episodes=10_000, + agent=agent, + env=env, + learn=True, + learn_after_episode=True, + exploit=False, + ) + ) + def test_dueling_dqn( self, batch_size: int = 128, @@ -351,6 +416,7 @@ def test_ppo(self) -> None: env.observation_space.shape[0], env.action_space, actor_hidden_dims=[64, 64], + use_critic=True, critic_hidden_dims=[64, 64], training_rounds=20, batch_size=32, @@ -374,6 +440,63 @@ def test_ppo(self) -> None: ) ) + def test_ppo_network_instance(self) -> None: + """ + This test checks for performance of PPO when instances of actor and critic networks are + passed as input arguments. The performance metric is if PPO can eventually attain an + episodic return of 500. + + Note: Pearl currently only supports PPO for discrete action spaces. + """ + env = GymEnvironment("CartPole-v1") + assert isinstance(env.action_space, DiscreteActionSpace) + num_actions = env.action_space.n + action_representation_module = OneHotActionTensorRepresentationModule( + max_number_actions=num_actions + ) + + # VanillaActorNetwork outputs a probability distribution over all actions, + # so the output_dim is taken to be num_actions. + actor_network_instance = VanillaActorNetwork( + input_dim=env.observation_space.shape[0], + hidden_dims=[64, 64], + output_dim=num_actions, + ) + + # PPO uses a VanillaValueNetwork by default + critic_network_instance = VanillaValueNetwork( + input_dim=env.observation_space.shape[0], + hidden_dims=[64, 64], + output_dim=1, + ) + + agent = PearlAgent( + policy_learner=ProximalPolicyOptimization( + state_dim=env.observation_space.shape[0], + action_space=env.action_space, + use_critic=True, + training_rounds=20, + batch_size=32, + epsilon=0.1, + action_representation_module=action_representation_module, + actor_network_instance=actor_network_instance, + critic_network_instance=critic_network_instance, + ), + replay_buffer=OnPolicyReplayBuffer(10_000), + ) + self.assertTrue( + target_return_is_reached( + agent=agent, + env=env, + target_return=500, + max_episodes=1000, + learn=True, + learn_every_k_steps=200, + learn_after_episode=False, + exploit=False, + ) + ) + def test_sac(self) -> None: """ This test is checking if SAC will eventually get to 500 return for CartPole-v1 @@ -403,7 +526,67 @@ def test_sac(self) -> None: agent=agent, env=env, target_return=500, - max_episodes=1_000, + max_episodes=1000, + learn=True, + learn_after_episode=True, + exploit=False, + ) + ) + + def test_sac_network_instance(self) -> None: + """ + This test checks the performance of discrete SAC when actor and critic network instances + are passed as arguments. The performance metric is if SAC is able to eventually get to + episodic return of 500 for CartPole-v1. 
+ + In this test, we use a discrete policy network (which outputs action probabilities) + and twin critics. + """ + env = GymEnvironment("CartPole-v1") + assert isinstance(env.action_space, DiscreteActionSpace) + num_actions = env.action_space.n + action_representation_module = OneHotActionTensorRepresentationModule( + max_number_actions=num_actions + ) + + # VanillaActorNetwork outputs a probability distribution over all actions, + # so the output_dim is taken to be num_actions. + actor_network = VanillaActorNetwork( + input_dim=env.observation_space.shape[0], + hidden_dims=[64, 64, 64], + output_dim=num_actions, + ) + + # we use twin critics of the type VanillaQValueNetwork. + twin_critic_network = TwinCritic( + state_dim=env.observation_space.shape[0], + action_dim=action_representation_module.max_number_actions, + hidden_dims=[64, 64, 64], + network_type=VanillaQValueNetwork, + init_fn=init_weights, + ) + + agent = PearlAgent( + policy_learner=SoftActorCritic( + state_dim=env.observation_space.shape[0], + action_space=env.action_space, + training_rounds=100, + batch_size=100, + entropy_coef=0.1, + actor_learning_rate=0.0001, + critic_learning_rate=0.0003, + action_representation_module=action_representation_module, + actor_network_instance=actor_network, + critic_network_instance=twin_critic_network, + ), + replay_buffer=FIFOOffPolicyReplayBuffer(50000), + ) + self.assertTrue( + target_return_is_reached( + agent=agent, + env=env, + target_return=500, + max_episodes=1000, learn=True, learn_after_episode=True, exploit=False, @@ -443,6 +626,62 @@ def test_continuous_sac(self) -> None: ) ) + def test_continuous_sac_network_instance(self) -> None: + """ + This test checks the performance of continuous SAC when actor and critic network instances + are passed as arguments. The performance metric is if SAC is able to eventually get to + a moving average episodic return -250 or less for Pendulum-v1. + + This test uses a Gaussian policy network and twin critics. + """ + env = GymEnvironment("Pendulum-v1") + + # for continuous action spaces, Pearl currently only supports + # IdentityActionRepresentationModule as the action representation module. So, the output_dim + # argument of the GaussianActorNetwork is the same as the action space dimension. Also, + # the action_dim argument for critic networks is the same as the action space dimension. + actor_network_instance = GaussianActorNetwork( + input_dim=env.observation_space.shape[0], + hidden_dims=[64, 64], + output_dim=env.action_space.action_dim, + action_space=env.action_space, + ) + + # SAC uses twin critics of the type VanillaQValueNetwork by default. 
+ twin_critic_network = TwinCritic( + state_dim=env.observation_space.shape[0], + action_dim=env.action_space.action_dim, + hidden_dims=[64, 64], + network_type=VanillaQValueNetwork, + init_fn=init_weights, + ) + + agent = PearlAgent( + policy_learner=ContinuousSoftActorCritic( + state_dim=env.observation_space.shape[0], + action_space=env.action_space, + training_rounds=50, + batch_size=100, + entropy_coef=0.1, + actor_learning_rate=0.001, + critic_learning_rate=0.001, + actor_network_instance=actor_network_instance, + critic_network_instance=twin_critic_network, + ), + replay_buffer=FIFOOffPolicyReplayBuffer(100000), + ) + self.assertTrue( + target_return_is_reached( + agent=agent, + env=env, + target_return=-250, + max_episodes=1500, + learn=True, + learn_after_episode=True, + exploit=False, + ) + ) + def test_cql_online(self) -> None: """ This test is checking if DQN with conservative loss will eventually get to 500 return for @@ -554,6 +793,71 @@ def test_td3(self) -> None: ) ) + def test_td3_network_instance(self) -> None: + """ + This test checks the performance of TD3 when instances of actor and critic networks are + passed as input arguments. The performance metric is if TD3 is able to eventually get to + a moving average episodic return of -250 or less for Pendulum-v1. + + This test uses a deterministic policy network and twin critics. + """ + env = GymEnvironment("Pendulum-v1") + + # Note: for continuous action spaces, Pearl currently only supports + # IdentityActionRepresentationModule as the action representation module. So, the output_dim + # argument of the VanillaContinuousActorNetwork is the same as the action space dimension. + # For this reason, the action_dim argument for critic networks is the same as the action + # space dimension. + + # td3 uses a deterministic policy network (e.g. VanillaContinuousActorNetwork) by default. 
+ actor_network_instance = VanillaContinuousActorNetwork( + input_dim=env.observation_space.shape[0], + hidden_dims=[400, 300], + output_dim=env.action_space.action_dim, + action_space=env.action_space, + ) + + twin_critic_network = TwinCritic( + state_dim=env.observation_space.shape[0], + action_dim=env.action_space.action_dim, + hidden_dims=[400, 300], + network_type=VanillaQValueNetwork, + init_fn=init_weights, + ) + + agent = PearlAgent( + policy_learner=TD3( + state_dim=env.observation_space.shape[0], + action_space=env.action_space, + actor_hidden_dims=[400, 300], + critic_hidden_dims=[400, 300], + critic_learning_rate=1e-2, + actor_learning_rate=1e-3, + training_rounds=5, + actor_soft_update_tau=0.05, + critic_soft_update_tau=0.05, + exploration_module=NormalDistributionExploration( + mean=0, + std_dev=0.2, + ), + actor_network_instance=actor_network_instance, + critic_network_instance=twin_critic_network, + ), + replay_buffer=FIFOOffPolicyReplayBuffer(50000), + ) + self.assertTrue( + target_return_is_reached( + agent=agent, + env=env, + target_return=-250, + max_episodes=1000, + learn=True, + learn_after_episode=True, + exploit=False, + check_moving_average=True, + ) + ) + def test_cql_offline_training(self) -> None: """ This test is checking if DQN with conservative loss will eventually get to > 50 return for diff --git a/test/unit/with_pytorch/test_gpu_usage.py b/test/unit/with_pytorch/test_gpu_usage.py index ab833fec..eb0b9f9f 100644 --- a/test/unit/with_pytorch/test_gpu_usage.py +++ b/test/unit/with_pytorch/test_gpu_usage.py @@ -30,7 +30,6 @@ from pearl.utils.functional_utils.train_and_eval.online_learning import online_learning from pearl.utils.instantiations.environments.gym_environment import GymEnvironment -from pearl.utils.instantiations.spaces.box import BoxSpace from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace @@ -78,6 +77,7 @@ def test_pg_based_gpu_usage(self) -> None: state_dim=env.observation_space.shape[0], action_space=env.action_space, actor_hidden_dims=[64, 64], + use_critic=True, critic_hidden_dims=[64, 64], training_rounds=20, batch_size=500, diff --git a/test/unit/with_pytorch/test_ppo.py b/test/unit/with_pytorch/test_ppo.py index deefee62..004bc92a 100644 --- a/test/unit/with_pytorch/test_ppo.py +++ b/test/unit/with_pytorch/test_ppo.py @@ -30,6 +30,7 @@ def test_optimizer_param_count(self) -> None: 16, DiscreteActionSpace(actions=[torch.tensor(i) for i in range(3)]), actor_hidden_dims=[64, 64], + use_critic=True, critic_hidden_dims=[64, 64], training_rounds=1, batch_size=500, @@ -45,22 +46,6 @@ def test_optimizer_param_count(self) -> None: ) self.assertEqual(optimizer_params_count, model_params_count) - def test_training_round_setup(self) -> None: - """ - PPO inherit from PG and overwrite training_rounds - This test is to ensure it indeed overwrite - """ - policy_learner = ProximalPolicyOptimization( - 16, - DiscreteActionSpace(actions=[torch.tensor(i) for i in range(3)]), - actor_hidden_dims=[64, 64], - critic_hidden_dims=[64, 64], - training_rounds=10, - batch_size=500, - epsilon=0.1, - ) - self.assertEqual(10, policy_learner._training_rounds) - def test_preprocess_replay_buffer(self) -> None: """ PPO computes generalized advantage estimation and truncated lambda return @@ -72,6 +57,7 @@ def test_preprocess_replay_buffer(self) -> None: state_dim=state_dim, action_space=action_space, actor_hidden_dims=[64, 64], + use_critic=True, critic_hidden_dims=[64, 64], training_rounds=10, batch_size=500,