make state_dim optional
Summary: Given that we allow taking network instances, state_dim should be optional.

Reviewed By: rodrigodesalvobraz

Differential Revision: D65923071

fbshipit-source-id: c0f9922a2f05ba1a6455517d3fbfdef6804cf7e6
yiwan-rl authored and facebook-github-bot committed Nov 21, 2024
1 parent 2f16709 commit 96c527b
Showing 12 changed files with 28 additions and 15 deletions.
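To illustrate the pattern this change enables, here is a minimal sketch: once a network instance is passed to the policy learner, state_dim no longer needs to be supplied. The import paths, the VanillaQValueNetwork constructor arguments, and the dimensions below are illustrative assumptions, not taken from this diff.

# Sketch only: import paths and constructor signatures are assumed, not verified here.
import torch

from pearl.neural_networks.sequential_decision_making.q_value_networks import VanillaQValueNetwork
from pearl.policy_learners.sequential_decision_making.deep_q_learning import DeepQLearning
from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace

action_space = DiscreteActionSpace(actions=[torch.tensor(i) for i in range(3)])

# The Q-value network already fixes the input dimension, so the policy learner
# does not need state_dim when a network instance is supplied.
q_network = VanillaQValueNetwork(
    state_dim=16,        # hypothetical dimensions for illustration
    action_dim=1,
    hidden_dims=[64, 64],
    output_dim=1,
)

policy_learner = DeepQLearning(
    action_space=action_space,
    network_instance=q_network,  # state_dim omitted: allowed after this change
)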
@@ -64,8 +64,8 @@ class ActorCriticBase(PolicyLearner):
 
     def __init__(
         self,
-        state_dim: int,
         exploration_module: ExplorationModule,
+        state_dim: Optional[int] = None,
         actor_hidden_dims: Optional[List[int]] = None,
         use_critic: bool = True,
         critic_hidden_dims: Optional[List[int]] = None,
@@ -116,6 +116,10 @@ def __init__(
         if actor_network_instance is not None:
             self._actor: nn.Module = actor_network_instance
         else:
+            assert (
+                state_dim is not None
+            ), f"{self.__class__.__name__} requires parameter state_dim if a parameter \
+actor_network_instance has not been provided."
             assert (
                 actor_hidden_dims is not None
             ), f"{self.__class__.__name__} requires parameter actor_hidden_dims if a parameter \
@@ -161,6 +165,10 @@ def __init__(
         if critic_network_instance is not None:
             self._critic: nn.Module = critic_network_instance
         else:
+            assert (
+                state_dim is not None
+            ), f"{self.__class__.__name__} requires parameter state_dim if a parameter \
+critic_network_instance has not been provided."
             assert (
                 critic_hidden_dims is not None
             ), f"{self.__class__.__name__} requires parameter critic_hidden_dims if a \
pearl/policy_learners/sequential_decision_making/ddpg.py (2 changes: 1 addition & 1 deletion)
@@ -48,8 +48,8 @@ class DeepDeterministicPolicyGradient(ActorCriticBase):
 
     def __init__(
         self,
-        state_dim: int,
         action_space: ActionSpace,
+        state_dim: Optional[int] = None,
         actor_hidden_dims: Optional[List[int]] = None,
         critic_hidden_dims: Optional[List[int]] = None,
         exploration_module: Optional[ExplorationModule] = None,
@@ -39,7 +39,6 @@ class DeepQLearning(DeepTDLearning):
 
     def __init__(
         self,
-        state_dim: int,
         action_space: Optional[ActionSpace] = None,
         hidden_dims: Optional[List[int]] = None,
         exploration_module: Optional[ExplorationModule] = None,
@@ -51,6 +50,7 @@ def __init__(
         soft_update_tau: float = 0.75,  # a value of 1 indicates no soft updates
         is_conservative: bool = False,
         conservative_alpha: Optional[float] = 2.0,
+        state_dim: Optional[int] = None,
         network_type: Type[QValueNetwork] = VanillaQValueNetwork,
         action_representation_module: Optional[ActionRepresentationModule] = None,
         network_instance: Optional[QValueNetwork] = None,
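Note that state_dim is not only optional now; it also moved from the first positional slot to a keyword parameter near the end of the signature, so positional call sites must switch to keyword arguments (the test updates further down do exactly this). A call-pattern sketch for DeepQLearning, with illustrative values and assumed import paths:

# Call-pattern sketch only; import paths and values are illustrative assumptions.
import torch

from pearl.policy_learners.sequential_decision_making.deep_q_learning import DeepQLearning
from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace

# Before this change, state_dim was the first positional parameter:
#   DeepQLearning(16, hidden_dims=[64, 64], ...)
# After this change, pass it by keyword:
policy_learner = DeepQLearning(
    action_space=DiscreteActionSpace(actions=[torch.tensor(i) for i in range(3)]),
    hidden_dims=[64, 64],
    state_dim=16,
)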
@@ -34,7 +34,7 @@ class DeepSARSA(DeepTDLearning):
 
     def __init__(
         self,
-        state_dim: int,
+        state_dim: Optional[int] = None,
         action_space: Optional[ActionSpace] = None,
         exploration_module: Optional[ExplorationModule] = None,
         action_representation_module: Optional[ActionRepresentationModule] = None,
@@ -51,9 +51,9 @@ class DeepTDLearning(PolicyLearner):
 
     def __init__(
         self,
-        state_dim: int,
         exploration_module: ExplorationModule,
         on_policy: bool,
+        state_dim: Optional[int] = None,
         action_space: Optional[ActionSpace] = None,
         hidden_dims: Optional[List[int]] = None,
         learning_rate: float = 0.001,
@@ -130,6 +130,7 @@ def __init__(
         self._conservative_alpha = conservative_alpha
 
         def make_specified_network() -> QValueNetwork:
+            assert state_dim is not None
             assert hidden_dims is not None
             if network_type is TwoTowerQValueNetwork:
                 return network_type(
@@ -75,8 +75,8 @@ class ImplicitQLearning(ActorCriticBase):
 
     def __init__(
         self,
-        state_dim: int,
         action_space: ActionSpace,
+        state_dim: Optional[int] = None,
         actor_hidden_dims: Optional[List[int]] = None,
         critic_hidden_dims: Optional[List[int]] = None,
         value_critic_hidden_dims: Optional[List[int]] = None,
@@ -140,6 +140,8 @@ def __init__(
         if value_network_instance is not None:
             self._value_network = value_network_instance
         else:
+            assert state_dim is not None
+            assert value_critic_hidden_dims is not None
             self._value_network: ValueNetwork = value_network_type(
                 input_dim=state_dim,
                 hidden_dims=value_critic_hidden_dims,
pearl/policy_learners/sequential_decision_making/ppo.py (2 changes: 1 addition & 1 deletion)
@@ -91,8 +91,8 @@ class ProximalPolicyOptimization(ActorCriticBase):
 
     def __init__(
        self,
-        state_dim: int,
         action_space: ActionSpace,
+        state_dim: Optional[int] = None,
         actor_hidden_dims: Optional[List[int]] = None,
         critic_hidden_dims: Optional[List[int]] = None,
         actor_learning_rate: float = 1e-4,
@@ -53,8 +53,8 @@ class SoftActorCritic(ActorCriticBase):
 
     def __init__(
         self,
-        state_dim: int,
         action_space: ActionSpace,
+        state_dim: Optional[int] = None,
         actor_hidden_dims: Optional[List[int]] = None,
         critic_hidden_dims: Optional[List[int]] = None,
         actor_learning_rate: float = 1e-4,
@@ -46,8 +46,8 @@ class ContinuousSoftActorCritic(ActorCriticBase):
 
     def __init__(
         self,
-        state_dim: int,
         action_space: ActionSpace,
+        state_dim: Optional[int] = None,
         actor_hidden_dims: Optional[List[int]] = None,
         critic_hidden_dims: Optional[List[int]] = None,
         actor_learning_rate: float = 1e-3,
pearl/policy_learners/sequential_decision_making/td3.py (4 changes: 2 additions & 2 deletions)
@@ -48,8 +48,8 @@ class TD3(DeepDeterministicPolicyGradient):
 
     def __init__(
         self,
-        state_dim: int,
         action_space: ActionSpace,
+        state_dim: Optional[int] = None,
         actor_hidden_dims: Optional[List[int]] = None,
         critic_hidden_dims: Optional[List[int]] = None,
         exploration_module: Optional[ExplorationModule] = None,
@@ -200,9 +200,9 @@ class TD3BC(TD3):
 
     def __init__(
         self,
-        state_dim: int,
         action_space: ActionSpace,
         behavior_policy: torch.nn.Module,
+        state_dim: Optional[int] = None,
         actor_hidden_dims: Optional[List[int]] = None,
         critic_hidden_dims: Optional[List[int]] = None,
         exploration_module: Optional[ExplorationModule] = None,
test/integration/test_integration.py (4 changes: 2 additions & 2 deletions)
@@ -412,8 +412,8 @@ def test_ppo(self) -> None:
         num_actions = env.action_space.n
         agent = PearlAgent(
             policy_learner=ProximalPolicyOptimization(
-                env.observation_space.shape[0],
-                env.action_space,
+                action_space=env.action_space,
+                state_dim=env.observation_space.shape[0],
                 actor_hidden_dims=[64, 64],
                 critic_hidden_dims=[64, 64],
                 training_rounds=20,
test/unit/with_pytorch/test_ppo.py (6 changes: 4 additions & 2 deletions)
@@ -26,8 +26,10 @@ def test_optimizer_param_count(self) -> None:
         including actor and critic
         """
         policy_learner = ProximalPolicyOptimization(
-            16,
-            DiscreteActionSpace(actions=[torch.tensor(i) for i in range(3)]),
+            action_space=DiscreteActionSpace(
+                actions=[torch.tensor(i) for i in range(3)]
+            ),
+            state_dim=16,
             actor_hidden_dims=[64, 64],
             critic_hidden_dims=[64, 64],
             training_rounds=1,
