Add network instance attributes to several policy learners
Summary: Currently, some policy learners do not take a network instance as input. This diff adds network instance arguments to those learners.

Reviewed By: rodrigodesalvobraz

Differential Revision: D65839046

fbshipit-source-id: b93b7d995536b5a3db03a86fdf3d3e877f4ba6bd
yiwan-rl authored and facebook-github-bot committed Nov 13, 2024
1 parent 4687b47 commit 3ade5e5
Showing 2 changed files with 18 additions and 13 deletions.
@@ -77,9 +77,9 @@ def __init__(
         self,
         state_dim: int,
         action_space: ActionSpace,
-        actor_hidden_dims: List[int],
-        critic_hidden_dims: List[int],
-        value_critic_hidden_dims: List[int],
+        actor_hidden_dims: Optional[List[int]] = None,
+        critic_hidden_dims: Optional[List[int]] = None,
+        value_critic_hidden_dims: Optional[List[int]] = None,
         exploration_module: Optional[ExplorationModule] = None,
         actor_network_type: Type[ActorNetwork] = VanillaActorNetwork,
         critic_network_type: Type[QValueNetwork] = VanillaQValueNetwork,
@@ -97,9 +97,8 @@ def __init__(
         advantage_clamp: float = 100.0,
         action_representation_module: Optional[ActionRepresentationModule] = None,
         actor_network_instance: Optional[ActorNetwork] = None,
-        critic_network_instance: Optional[
-            Union[ValueNetwork, QValueNetwork, torch.nn.Module]
-        ] = None,
+        critic_network_instance: Optional[QValueNetwork] = None,
+        value_network_instance: Optional[ValueNetwork] = None,
     ) -> None:
         super(ImplicitQLearning, self).__init__(
             state_dim=state_dim,
@@ -138,11 +137,14 @@ def __init__(
         )
         self._advantage_clamp = advantage_clamp
         # iql uses both q and v approximators
-        self._value_network: ValueNetwork = value_network_type(
-            input_dim=state_dim,
-            hidden_dims=value_critic_hidden_dims,
-            output_dim=1,
-        )
+        if value_network_instance is not None:
+            self._value_network = value_network_instance
+        else:
+            self._value_network: ValueNetwork = value_network_type(
+                input_dim=state_dim,
+                hidden_dims=value_critic_hidden_dims,
+                output_dim=1,
+            )
         self._value_network_optimizer = optim.AdamW(
             self._value_network.parameters(),
             lr=value_critic_learning_rate,
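For illustration, here is a minimal sketch of constructing ImplicitQLearning with a pre-built value network via the new value_network_instance argument. This snippet is not part of the commit: the import paths, the DiscreteActionSpace construction, and the dimensions are assumptions based on the surrounding Pearl codebase; only the keyword arguments visible in the hunks above come from the source.

import torch
from pearl.neural_networks.common.value_networks import VanillaValueNetwork
from pearl.policy_learners.sequential_decision_making.implicit_q_learning import (
    ImplicitQLearning,
)
from pearl.utils.instantiations.spaces.discrete_action_space import DiscreteActionSpace

state_dim = 8
# Assumed: a small discrete action space built from one-element action tensors.
action_space = DiscreteActionSpace(actions=[torch.tensor([i]) for i in range(3)])

# Build the state-value network up front; per the hunk above, the learner uses
# this instance directly instead of constructing one from value_critic_hidden_dims.
value_network = VanillaValueNetwork(
    input_dim=state_dim,
    hidden_dims=[64, 64],
    output_dim=1,
)

# Other hyperparameters are left at their (assumed) defaults.
iql = ImplicitQLearning(
    state_dim=state_dim,
    action_space=action_space,
    actor_hidden_dims=[64, 64],
    critic_hidden_dims=[64, 64],
    value_network_instance=value_network,
)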
@@ -7,7 +7,7 @@
 
 # pyre-strict
 
-from typing import List, Optional
+from typing import List, Optional, Type
 
 import torch
 from pearl.action_representation_modules.action_representation_module import (
@@ -58,6 +58,8 @@ def __init__(
         target_update_freq: int = 10,
         soft_update_tau: float = 0.05,
         action_representation_module: Optional[ActionRepresentationModule] = None,
+        network_type: Type[QuantileQValueNetwork] = QuantileQValueNetwork,
+        network_instance: Optional[QuantileQValueNetwork] = None,
     ) -> None:
         assert isinstance(action_space, DiscreteActionSpace)
         super(QuantileRegressionDeepQLearning, self).__init__(
@@ -77,7 +79,8 @@ def __init__(
             batch_size=batch_size,
             target_update_freq=target_update_freq,
             soft_update_tau=soft_update_tau,
-            network_type=QuantileQValueNetwork,  # enforced to be of the type QuantileQValueNetwork
+            network_type=network_type,
+            network_instance=network_instance,
             action_representation_module=action_representation_module,
         )
 
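Similarly, a minimal sketch of passing a pre-built QuantileQValueNetwork through the new network_instance argument. Again, this is illustrative only: the import paths and the QuantileQValueNetwork constructor arguments shown here (action_dim, hidden_dims, num_quantiles) are assumptions, not part of this commit.

import torch
from pearl.neural_networks.sequential_decision_making.q_value_networks import (
    QuantileQValueNetwork,
)
from pearl.policy_learners.sequential_decision_making.quantile_regression_deep_q_learning import (
    QuantileRegressionDeepQLearning,
)
from pearl.utils.instantiations.spaces.discrete_action_space import DiscreteActionSpace

state_dim = 4
action_space = DiscreteActionSpace(actions=[torch.tensor([i]) for i in range(2)])

# Assumed constructor arguments for the quantile Q-network.
q_network = QuantileQValueNetwork(
    state_dim=state_dim,
    action_dim=1,
    hidden_dims=[64, 64],
    num_quantiles=10,
)

# Per the hunks above, network_instance is forwarded to the parent learner, so
# this network is used directly rather than one built from hidden dims.
qrdqn = QuantileRegressionDeepQLearning(
    state_dim=state_dim,
    action_space=action_space,
    network_instance=q_network,
)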
