From 3ade5e504a262de7d96f52de467a5545d20ca083 Mon Sep 17 00:00:00 2001
From: Yi Wan
Date: Wed, 13 Nov 2024 15:09:33 -0800
Subject: [PATCH] Add network instance attributes to several policy learners

Summary: Currently, some policy learners do not take network instances as input. Add them here.

Reviewed By: rodrigodesalvobraz

Differential Revision: D65839046

fbshipit-source-id: b93b7d995536b5a3db03a86fdf3d3e877f4ba6bd
---
 .../implicit_q_learning.py                    | 24 ++++++++++---------
 .../quantile_regression_deep_q_learning.py    |  7 ++++--
 2 files changed, 18 insertions(+), 13 deletions(-)

diff --git a/pearl/policy_learners/sequential_decision_making/implicit_q_learning.py b/pearl/policy_learners/sequential_decision_making/implicit_q_learning.py
index 29b5d0e3..c4b5c88e 100644
--- a/pearl/policy_learners/sequential_decision_making/implicit_q_learning.py
+++ b/pearl/policy_learners/sequential_decision_making/implicit_q_learning.py
@@ -77,9 +77,9 @@ def __init__(
         self,
         state_dim: int,
         action_space: ActionSpace,
-        actor_hidden_dims: List[int],
-        critic_hidden_dims: List[int],
-        value_critic_hidden_dims: List[int],
+        actor_hidden_dims: Optional[List[int]] = None,
+        critic_hidden_dims: Optional[List[int]] = None,
+        value_critic_hidden_dims: Optional[List[int]] = None,
         exploration_module: Optional[ExplorationModule] = None,
         actor_network_type: Type[ActorNetwork] = VanillaActorNetwork,
         critic_network_type: Type[QValueNetwork] = VanillaQValueNetwork,
@@ -97,9 +97,8 @@ def __init__(
         advantage_clamp: float = 100.0,
         action_representation_module: Optional[ActionRepresentationModule] = None,
         actor_network_instance: Optional[ActorNetwork] = None,
-        critic_network_instance: Optional[
-            Union[ValueNetwork, QValueNetwork, torch.nn.Module]
-        ] = None,
+        critic_network_instance: Optional[QValueNetwork] = None,
+        value_network_instance: Optional[ValueNetwork] = None,
     ) -> None:
         super(ImplicitQLearning, self).__init__(
             state_dim=state_dim,
@@ -138,11 +137,14 @@ def __init__(
         )
         self._advantage_clamp = advantage_clamp
         # iql uses both q and v approximators
-        self._value_network: ValueNetwork = value_network_type(
-            input_dim=state_dim,
-            hidden_dims=value_critic_hidden_dims,
-            output_dim=1,
-        )
+        if value_network_instance is not None:
+            self._value_network = value_network_instance
+        else:
+            self._value_network: ValueNetwork = value_network_type(
+                input_dim=state_dim,
+                hidden_dims=value_critic_hidden_dims,
+                output_dim=1,
+            )
         self._value_network_optimizer = optim.AdamW(
             self._value_network.parameters(),
             lr=value_critic_learning_rate,
diff --git a/pearl/policy_learners/sequential_decision_making/quantile_regression_deep_q_learning.py b/pearl/policy_learners/sequential_decision_making/quantile_regression_deep_q_learning.py
index 3affbc34..e8a2b587 100644
--- a/pearl/policy_learners/sequential_decision_making/quantile_regression_deep_q_learning.py
+++ b/pearl/policy_learners/sequential_decision_making/quantile_regression_deep_q_learning.py
@@ -7,7 +7,7 @@

 # pyre-strict

-from typing import List, Optional
+from typing import List, Optional, Type

 import torch
 from pearl.action_representation_modules.action_representation_module import (
@@ -58,6 +58,8 @@ def __init__(
         target_update_freq: int = 10,
         soft_update_tau: float = 0.05,
         action_representation_module: Optional[ActionRepresentationModule] = None,
+        network_type: Type[QuantileQValueNetwork] = QuantileQValueNetwork,
+        network_instance: Optional[QuantileQValueNetwork] = None,
     ) -> None:
         assert isinstance(action_space, DiscreteActionSpace)
         super(QuantileRegressionDeepQLearning, self).__init__(
@@ -77,7 +79,8 @@ def __init__(
             batch_size=batch_size,
             target_update_freq=target_update_freq,
             soft_update_tau=soft_update_tau,
-            network_type=QuantileQValueNetwork,  # enforced to be of the type QuantileQValueNetwork
+            network_type=network_type,
+            network_instance=network_instance,
             action_representation_module=action_representation_module,
         )
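
Usage sketch (illustrative, not part of the patch). It shows how a caller might now inject prebuilt networks instead of letting the learners construct them from hidden dims. Only the new keyword arguments (critic_network_instance / value_network_instance on ImplicitQLearning, network_type / network_instance on QuantileRegressionDeepQLearning) come from this diff; the import paths, network constructor signatures, placeholder dimensions, and hidden-dim choices below are assumptions about the surrounding Pearl API and may need adjusting.

    import torch

    # Assumed import paths; the actual locations in Pearl may differ.
    from pearl.neural_networks.common.value_networks import VanillaValueNetwork
    from pearl.neural_networks.sequential_decision_making.actor_networks import (
        VanillaActorNetwork,
    )
    from pearl.neural_networks.sequential_decision_making.q_value_networks import (
        QuantileQValueNetwork,
        VanillaQValueNetwork,
    )
    from pearl.policy_learners.sequential_decision_making.implicit_q_learning import (
        ImplicitQLearning,
    )
    from pearl.policy_learners.sequential_decision_making.quantile_regression_deep_q_learning import (
        QuantileRegressionDeepQLearning,
    )
    from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace

    state_dim, action_dim = 8, 4  # placeholder sizes
    action_space = DiscreteActionSpace(
        actions=[torch.tensor([i]) for i in range(action_dim)]
    )

    # ImplicitQLearning: pass prebuilt actor/critic/value networks instead of hidden dims.
    iql = ImplicitQLearning(
        state_dim=state_dim,
        action_space=action_space,
        actor_network_instance=VanillaActorNetwork(  # assumed signature
            input_dim=state_dim,
            hidden_dims=[64, 64],
            output_dim=action_dim,
        ),
        critic_network_instance=VanillaQValueNetwork(  # assumed signature
            state_dim=state_dim,
            action_dim=action_dim,
            hidden_dims=[64, 64],
            output_dim=1,
        ),
        value_network_instance=VanillaValueNetwork(  # new argument in this diff
            input_dim=state_dim,
            hidden_dims=[64, 64],
            output_dim=1,
        ),
    )

    # QuantileRegressionDeepQLearning: inject a QuantileQValueNetwork directly.
    qr_dqn = QuantileRegressionDeepQLearning(
        state_dim=state_dim,
        action_space=action_space,
        network_instance=QuantileQValueNetwork(  # assumed signature
            state_dim=state_dim,
            action_dim=action_dim,
            hidden_dims=[64, 64],
            num_quantiles=10,
        ),
    )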