diff --git a/pearl/policy_learners/sequential_decision_making/deep_q_learning.py b/pearl/policy_learners/sequential_decision_making/deep_q_learning.py
index 335a9634..3c8afefe 100644
--- a/pearl/policy_learners/sequential_decision_making/deep_q_learning.py
+++ b/pearl/policy_learners/sequential_decision_making/deep_q_learning.py
@@ -7,7 +7,7 @@
 
 # pyre-strict
 
-from typing import Any, List, Optional, Tuple, Type
+from typing import Any, Dict, List, Optional, Tuple, Type
 
 import torch
 from pearl.action_representation_modules.action_representation_module import (
@@ -94,8 +94,13 @@ class and uses `act` and `learn_batch` methods of that class. We only implement
                 Note: This is an alternative to specifying a `network_type`. If provided, the
                 specified `network_type` is ignored and the input `network_instance` is used
                 for learning. Allows for custom implementations of Q-value networks.
-        """
+            **kwargs: Additional arguments to be passed when using `TwoTowerNetwork`
+                class as the QValueNetwork. This includes {state_output_dim (int),
+                action_output_dim (int), state_hidden_dims (List[int]),
+                action_hidden_dims (List[int])}, all of which are used to instantiate a
+                `TwoTowerNetwork` object.
+        """
         super(DeepQLearning, self).__init__(
             exploration_module=(
                 exploration_module
@@ -108,11 +113,14 @@ class and uses `act` and `learn_batch` methods of that class. We only implement
             hidden_dims=hidden_dims,
             learning_rate=learning_rate,
             soft_update_tau=soft_update_tau,
+            is_conservative=is_conservative,
+            conservative_alpha=conservative_alpha,
             network_type=network_type,
             action_representation_module=action_representation_module,
             discount_factor=discount_factor,
             training_rounds=training_rounds,
             batch_size=batch_size,
+            target_update_freq=target_update_freq,
             network_instance=network_instance,
             **kwargs,
         )
diff --git a/pearl/policy_learners/sequential_decision_making/deep_td_learning.py b/pearl/policy_learners/sequential_decision_making/deep_td_learning.py
index 5903d37a..c95e60d7 100644
--- a/pearl/policy_learners/sequential_decision_making/deep_td_learning.py
+++ b/pearl/policy_learners/sequential_decision_making/deep_td_learning.py
@@ -63,7 +63,7 @@ def __init__(
         target_update_freq: int = 10,
         soft_update_tau: float = 0.1,
         is_conservative: bool = False,
-        conservative_alpha: float = 2.0,
+        conservative_alpha: Optional[float] = 2.0,
         network_type: Type[QValueNetwork] = VanillaQValueNetwork,
         state_output_dim: Optional[int] = None,
         action_output_dim: Optional[int] = None,
@@ -71,7 +71,6 @@ def __init__(
         action_hidden_dims: Optional[List[int]] = None,
         network_instance: Optional[QValueNetwork] = None,
         action_representation_module: Optional[ActionRepresentationModule] = None,
-        **kwargs: Any,
     ) -> None:
         """Constructs a DeepTDLearning based policy learner. DeepTDLearning is the base class
         for all value based (i.e. temporal difference learning based) algorithms.
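
For reviewers, the sketch below illustrates how the `**kwargs` now documented on `DeepQLearning` are intended to flow: they are forwarded through `super().__init__` to `DeepTDLearning`, which accepts `state_output_dim`, `action_output_dim`, `state_hidden_dims`, and `action_hidden_dims` as explicit parameters for building a two-tower Q-value network. Everything beyond those keyword names is an assumption made for illustration (the `TwoTowerQValueNetwork` import path and class name, the `DiscreteActionSpace` construction, and all dimension values), not something taken from this diff.

```python
import torch

# Assumed import paths for illustration; the exact module layout of Pearl
# may differ from what is shown here.
from pearl.policy_learners.sequential_decision_making.deep_q_learning import (
    DeepQLearning,
)
from pearl.neural_networks.sequential_decision_making.q_value_networks import (
    TwoTowerQValueNetwork,  # assumed name of the two-tower QValueNetwork class
)
from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace

# A small discrete action space with 4 actions, used only to make the
# example self-contained.
action_space = DiscreteActionSpace(actions=[torch.tensor([i]) for i in range(4)])

policy_learner = DeepQLearning(
    state_dim=8,
    action_space=action_space,
    hidden_dims=[64, 64],
    network_type=TwoTowerQValueNetwork,
    # Keyword arguments described by the new docstring; DeepQLearning forwards
    # them via **kwargs to DeepTDLearning, which uses them to instantiate the
    # two-tower network.
    state_output_dim=32,
    action_output_dim=32,
    state_hidden_dims=[64],
    action_hidden_dims=[64],
)
```

Note that after this change the forwarded keywords must match explicit parameters of `DeepTDLearning.__init__`, since its catch-all `**kwargs: Any` is removed; unrecognized keyword arguments will now raise a `TypeError` instead of being silently ignored.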