diff --git a/pearl/policy_learners/sequential_decision_making/actor_critic_base.py b/pearl/policy_learners/sequential_decision_making/actor_critic_base.py
index d11810b..85e8b77 100644
--- a/pearl/policy_learners/sequential_decision_making/actor_critic_base.py
+++ b/pearl/policy_learners/sequential_decision_making/actor_critic_base.py
@@ -64,8 +64,8 @@ class ActorCriticBase(PolicyLearner):

     def __init__(
         self,
-        state_dim: int,
         exploration_module: ExplorationModule,
+        state_dim: Optional[int] = None,
         actor_hidden_dims: Optional[List[int]] = None,
         use_critic: bool = True,
         critic_hidden_dims: Optional[List[int]] = None,
@@ -116,6 +116,10 @@ def __init__(
         if actor_network_instance is not None:
             self._actor: nn.Module = actor_network_instance
         else:
+            assert (
+                state_dim is not None
+            ), f"{self.__class__.__name__} requires parameter state_dim if a parameter \
+                action_network_instance has not been provided."
             assert (
                 actor_hidden_dims is not None
             ), f"{self.__class__.__name__} requires parameter actor_hidden_dims if a parameter \
@@ -161,6 +165,10 @@ def __init__(
         if critic_network_instance is not None:
             self._critic: nn.Module = critic_network_instance
         else:
+            assert (
+                state_dim is not None
+            ), f"{self.__class__.__name__} requires parameter state_dim if a parameter \
+                critic_network_instance has not been provided."
             assert (
                 critic_hidden_dims is not None
             ), f"{self.__class__.__name__} requires parameter critic_hidden_dims if a \
diff --git a/pearl/policy_learners/sequential_decision_making/ddpg.py b/pearl/policy_learners/sequential_decision_making/ddpg.py
index 2f867a4..e3c00aa 100644
--- a/pearl/policy_learners/sequential_decision_making/ddpg.py
+++ b/pearl/policy_learners/sequential_decision_making/ddpg.py
@@ -48,8 +48,8 @@ class DeepDeterministicPolicyGradient(ActorCriticBase):

     def __init__(
         self,
-        state_dim: int,
         action_space: ActionSpace,
+        state_dim: Optional[int] = None,
         actor_hidden_dims: Optional[List[int]] = None,
         critic_hidden_dims: Optional[List[int]] = None,
         exploration_module: Optional[ExplorationModule] = None,
diff --git a/pearl/policy_learners/sequential_decision_making/deep_q_learning.py b/pearl/policy_learners/sequential_decision_making/deep_q_learning.py
index 72cbc04..6c8164d 100644
--- a/pearl/policy_learners/sequential_decision_making/deep_q_learning.py
+++ b/pearl/policy_learners/sequential_decision_making/deep_q_learning.py
@@ -39,7 +39,6 @@ class DeepQLearning(DeepTDLearning):

     def __init__(
         self,
-        state_dim: int,
         action_space: Optional[ActionSpace] = None,
         hidden_dims: Optional[List[int]] = None,
         exploration_module: Optional[ExplorationModule] = None,
@@ -51,6 +50,7 @@ def __init__(
         soft_update_tau: float = 0.75,  # a value of 1 indicates no soft updates
         is_conservative: bool = False,
         conservative_alpha: Optional[float] = 2.0,
+        state_dim: Optional[int] = None,
         network_type: Type[QValueNetwork] = VanillaQValueNetwork,
         action_representation_module: Optional[ActionRepresentationModule] = None,
         network_instance: Optional[QValueNetwork] = None,
diff --git a/pearl/policy_learners/sequential_decision_making/deep_sarsa.py b/pearl/policy_learners/sequential_decision_making/deep_sarsa.py
index ebbada1..b3eaab0 100644
--- a/pearl/policy_learners/sequential_decision_making/deep_sarsa.py
+++ b/pearl/policy_learners/sequential_decision_making/deep_sarsa.py
@@ -34,7 +34,7 @@ class DeepSARSA(DeepTDLearning):

     def __init__(
         self,
-        state_dim: int,
+        state_dim: Optional[int] = None,
         action_space: Optional[ActionSpace] = None,
         exploration_module: Optional[ExplorationModule] = None,
         action_representation_module: Optional[ActionRepresentationModule] = None,
diff --git a/pearl/policy_learners/sequential_decision_making/deep_td_learning.py b/pearl/policy_learners/sequential_decision_making/deep_td_learning.py
index 5018061..1276afe 100644
--- a/pearl/policy_learners/sequential_decision_making/deep_td_learning.py
+++ b/pearl/policy_learners/sequential_decision_making/deep_td_learning.py
@@ -51,9 +51,9 @@ class DeepTDLearning(PolicyLearner):

     def __init__(
         self,
-        state_dim: int,
         exploration_module: ExplorationModule,
         on_policy: bool,
+        state_dim: Optional[int] = None,
         action_space: Optional[ActionSpace] = None,
         hidden_dims: Optional[List[int]] = None,
         learning_rate: float = 0.001,
@@ -130,6 +130,7 @@ def __init__(
         self._conservative_alpha = conservative_alpha

         def make_specified_network() -> QValueNetwork:
+            assert state_dim is not None
             assert hidden_dims is not None
             if network_type is TwoTowerQValueNetwork:
                 return network_type(
diff --git a/pearl/policy_learners/sequential_decision_making/implicit_q_learning.py b/pearl/policy_learners/sequential_decision_making/implicit_q_learning.py
index 7cb373a..58b1c18 100644
--- a/pearl/policy_learners/sequential_decision_making/implicit_q_learning.py
+++ b/pearl/policy_learners/sequential_decision_making/implicit_q_learning.py
@@ -75,8 +75,8 @@ class ImplicitQLearning(ActorCriticBase):

     def __init__(
         self,
-        state_dim: int,
         action_space: ActionSpace,
+        state_dim: Optional[int] = None,
         actor_hidden_dims: Optional[List[int]] = None,
         critic_hidden_dims: Optional[List[int]] = None,
         value_critic_hidden_dims: Optional[List[int]] = None,
@@ -140,6 +140,8 @@ def __init__(
         if value_network_instance is not None:
             self._value_network = value_network_instance
         else:
+            assert state_dim is not None
+            assert value_critic_hidden_dims is not None
             self._value_network: ValueNetwork = value_network_type(
                 input_dim=state_dim,
                 hidden_dims=value_critic_hidden_dims,
diff --git a/pearl/policy_learners/sequential_decision_making/ppo.py b/pearl/policy_learners/sequential_decision_making/ppo.py
index 2f3097b..0f9642a 100644
--- a/pearl/policy_learners/sequential_decision_making/ppo.py
+++ b/pearl/policy_learners/sequential_decision_making/ppo.py
@@ -91,8 +91,8 @@ class ProximalPolicyOptimization(ActorCriticBase):

     def __init__(
         self,
-        state_dim: int,
         action_space: ActionSpace,
+        state_dim: Optional[int] = None,
         actor_hidden_dims: Optional[List[int]] = None,
         critic_hidden_dims: Optional[List[int]] = None,
         actor_learning_rate: float = 1e-4,
diff --git a/pearl/policy_learners/sequential_decision_making/soft_actor_critic.py b/pearl/policy_learners/sequential_decision_making/soft_actor_critic.py
index 49b200e..cfa8734 100644
--- a/pearl/policy_learners/sequential_decision_making/soft_actor_critic.py
+++ b/pearl/policy_learners/sequential_decision_making/soft_actor_critic.py
@@ -53,8 +53,8 @@ class SoftActorCritic(ActorCriticBase):

     def __init__(
         self,
-        state_dim: int,
         action_space: ActionSpace,
+        state_dim: Optional[int] = None,
         actor_hidden_dims: Optional[List[int]] = None,
         critic_hidden_dims: Optional[List[int]] = None,
         actor_learning_rate: float = 1e-4,
diff --git a/pearl/policy_learners/sequential_decision_making/soft_actor_critic_continuous.py b/pearl/policy_learners/sequential_decision_making/soft_actor_critic_continuous.py
index d5440f2..833eb1c 100644
--- a/pearl/policy_learners/sequential_decision_making/soft_actor_critic_continuous.py
+++ b/pearl/policy_learners/sequential_decision_making/soft_actor_critic_continuous.py
@@ -46,8 +46,8 @@ class ContinuousSoftActorCritic(ActorCriticBase):

     def __init__(
         self,
-        state_dim: int,
         action_space: ActionSpace,
+        state_dim: Optional[int] = None,
         actor_hidden_dims: Optional[List[int]] = None,
         critic_hidden_dims: Optional[List[int]] = None,
         actor_learning_rate: float = 1e-3,
diff --git a/pearl/policy_learners/sequential_decision_making/td3.py b/pearl/policy_learners/sequential_decision_making/td3.py
index 3286dd0..da8bd31 100644
--- a/pearl/policy_learners/sequential_decision_making/td3.py
+++ b/pearl/policy_learners/sequential_decision_making/td3.py
@@ -48,8 +48,8 @@ class TD3(DeepDeterministicPolicyGradient):

     def __init__(
         self,
-        state_dim: int,
         action_space: ActionSpace,
+        state_dim: Optional[int] = None,
         actor_hidden_dims: Optional[List[int]] = None,
         critic_hidden_dims: Optional[List[int]] = None,
         exploration_module: Optional[ExplorationModule] = None,
@@ -200,9 +200,9 @@ class TD3BC(TD3):

     def __init__(
         self,
-        state_dim: int,
         action_space: ActionSpace,
         behavior_policy: torch.nn.Module,
+        state_dim: Optional[int] = None,
         actor_hidden_dims: Optional[List[int]] = None,
         critic_hidden_dims: Optional[List[int]] = None,
         exploration_module: Optional[ExplorationModule] = None,
diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py
index e1852af..a420342 100644
--- a/test/integration/test_integration.py
+++ b/test/integration/test_integration.py
@@ -412,8 +412,8 @@ def test_ppo(self) -> None:
         num_actions = env.action_space.n
         agent = PearlAgent(
             policy_learner=ProximalPolicyOptimization(
-                env.observation_space.shape[0],
-                env.action_space,
+                action_space=env.action_space,
+                state_dim=env.observation_space.shape[0],
                 actor_hidden_dims=[64, 64],
                 critic_hidden_dims=[64, 64],
                 training_rounds=20,
diff --git a/test/unit/with_pytorch/test_ppo.py b/test/unit/with_pytorch/test_ppo.py
index fc2b871..909dd0a 100644
--- a/test/unit/with_pytorch/test_ppo.py
+++ b/test/unit/with_pytorch/test_ppo.py
@@ -26,8 +26,10 @@ def test_optimizer_param_count(self) -> None:
         including actor and critic
         """
         policy_learner = ProximalPolicyOptimization(
-            16,
-            DiscreteActionSpace(actions=[torch.tensor(i) for i in range(3)]),
+            action_space=DiscreteActionSpace(
+                actions=[torch.tensor(i) for i in range(3)]
+            ),
+            state_dim=16,
             actor_hidden_dims=[64, 64],
             critic_hidden_dims=[64, 64],
             training_rounds=1,
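
Illustrative usage after this change (a minimal sketch, not part of the patch): state_dim now follows action_space, defaults to None, and must be passed by keyword at existing call sites; it may be omitted only when actor/critic network instances are supplied, otherwise the new asserts fire. The import path for DiscreteActionSpace and the parameter values below are assumptions for illustration.

import torch
from pearl.policy_learners.sequential_decision_making.ppo import (
    ProximalPolicyOptimization,
)
# Assumed import path for DiscreteActionSpace; adjust to the repo's actual layout.
from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace

# Three-action discrete space, mirroring the updated unit test above.
action_space = DiscreteActionSpace(actions=[torch.tensor(i) for i in range(3)])

# state_dim is passed by keyword; dropping it is only valid when
# actor_network_instance / critic_network_instance are provided instead.
policy_learner = ProximalPolicyOptimization(
    action_space=action_space,
    state_dim=16,
    actor_hidden_dims=[64, 64],
    critic_hidden_dims=[64, 64],
    training_rounds=1,
)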