
Commit

allow passing optimizers to policy learners
Summary: allow passing optimizers to policy learners

Reviewed By: rodrigodesalvobraz

Differential Revision: D65919766

fbshipit-source-id: 2a05b8b78cf9dacc08511836d63cc0f7a55d20be
yiwan-rl authored and facebook-github-bot committed Dec 7, 2024
1 parent 8941dbe commit 6d107ec
Showing 15 changed files with 147 additions and 58 deletions.
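For context: every learner touched here now accepts an already-constructed torch.optim.Optimizer and only builds its usual AdamW default when none is supplied. The sketch below is not Pearl code; TinyLearner is a hypothetical stand-in that mirrors the fallback pattern repeated throughout the diff.

from typing import Optional

import torch.nn as nn
import torch.optim as optim


class TinyLearner:
    """Hypothetical stand-in for a Pearl policy learner; mirrors the new pattern."""

    def __init__(
        self,
        network: nn.Module,
        learning_rate: float = 1e-3,
        optimizer: Optional[optim.Optimizer] = None,
    ) -> None:
        self._network = network
        if optimizer is not None:
            # Caller-supplied optimizer is used as-is.
            self._optimizer: optim.Optimizer = optimizer
        else:
            # Default optimizer, matching the learners' previous hard-coded behavior.
            self._optimizer = optim.AdamW(
                self._network.parameters(), lr=learning_rate, amsgrad=True
            )


net = nn.Linear(4, 2)
default_learner = TinyLearner(net)  # builds AdamW internally, as before
custom_learner = TinyLearner(net, optimizer=optim.SGD(net.parameters(), lr=1e-2))

Because a custom optimizer has to be built over the learner's own network parameters, in practice this pairs with the existing network_instance / actor_network_instance / critic_network_instance arguments, so the caller and the learner share the same network object.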
Changes to ActorCriticBase
@@ -92,6 +92,9 @@ def __init__(
actor_network_instance: ActorNetwork | None = None,
critic_network_instance: None
| (ValueNetwork | QValueNetwork | nn.Module) = None,
actor_optimizer: Optional[optim.Optimizer] = None,
critic_optimizer: Optional[optim.Optimizer] = None,
history_summarization_optimizer: Optional[optim.Optimizer] = None,
) -> None:
super().__init__(
on_policy=on_policy,
@@ -148,15 +151,19 @@ def __init__(
action_space=action_space,
)
self._actor.apply(init_weights)
self._actor_optimizer = optim.AdamW(
[
{
"params": self._actor.parameters(),
"lr": actor_learning_rate,
"amsgrad": True,
},
]
)
if actor_optimizer is not None:
self._actor_optimizer: optim.Optimizer = actor_optimizer
else:
# default actor optimizer
self._actor_optimizer = optim.AdamW(
[
{
"params": self._actor.parameters(),
"lr": actor_learning_rate,
"amsgrad": True,
},
]
)
self._actor_soft_update_tau = actor_soft_update_tau

# make a copy of the actor network to be used as the actor target network
@@ -189,20 +196,27 @@ def __init__(
network_type=critic_network_type,
)

self._critic_optimizer: optim.Optimizer = optim.AdamW(
[
{
"params": self._critic.parameters(),
"lr": critic_learning_rate,
"amsgrad": True,
},
]
)
if critic_optimizer is not None:
self._critic_optimizer: optim.Optimizer = critic_optimizer
else:
            # default critic optimizer
self._critic_optimizer: optim.Optimizer = optim.AdamW(
[
{
"params": self._critic.parameters(),
"lr": critic_learning_rate,
"amsgrad": True,
},
]
)
if self._use_critic_target:
self._critic_target: nn.Module = copy.deepcopy(self._critic)

self._discount_factor = discount_factor
self._history_summarization_optimizer = history_summarization_optimizer
self._history_summarization_learning_rate = history_summarization_learning_rate
self._actor_learning_rate: float = self._actor_optimizer.param_groups[0]["lr"]
self._critic_learning_rate: float = self._critic_optimizer.param_groups[0]["lr"]

def set_history_summarization_module(
self, value: HistorySummarizationModule
@@ -212,15 +226,17 @@ def set_history_summarization_module(
"""
# pyre-fixme[16]: `ActorCriticBase` has no attribute
# `_history_summarization_optimizer`.
self._history_summarization_optimizer: optim.Optimizer = optim.AdamW(
[
{
"params": value.parameters(),
"lr": self._history_summarization_learning_rate,
"amsgrad": True,
}
]
)
if self._history_summarization_optimizer is None:
# default history summarization optimizer
self._history_summarization_optimizer: optim.Optimizer = optim.AdamW(
[
{
"params": value.parameters(),
"lr": self._history_summarization_learning_rate,
"amsgrad": True,
}
]
)
self._history_summarization_module = value

def act(
@@ -311,8 +327,7 @@ def learn_batch(self, batch: TransitionBatch) -> dict[str, Any]:
Dict[str, Any]: A dictionary containing the loss reports from the critic
and actor updates. These can be useful to track for debugging purposes.
"""
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
# `zero_grad`.
assert self._history_summarization_optimizer is not None
self._history_summarization_optimizer.zero_grad()
actor_loss = self._actor_loss(batch)
self._actor_optimizer.zero_grad()
Expand All @@ -332,7 +347,7 @@ def learn_batch(self, batch: TransitionBatch) -> dict[str, Any]:
critic_loss.backward()
self._critic_optimizer.step()
report["critic_loss"] = critic_loss.item()
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `step`.
assert self._history_summarization_optimizer is not None
self._history_summarization_optimizer.step()

if self._use_critic_target:
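Note on the ActorCriticBase change above: set_history_summarization_module now creates its default AdamW only when no history_summarization_optimizer was passed to the constructor, and learn_batch replaces the pyre-fixme suppressions with explicit asserts that the optimizer exists before calling zero_grad() and step() on it. If a caller does supply one, it is their responsibility to have built it over the summarization module's parameters, roughly like this (the GRU is only a stand-in for a HistorySummarizationModule):

import torch.nn as nn
import torch.optim as optim

# Stand-in for the HistorySummarizationModule the agent will actually use.
history_module = nn.GRU(input_size=8, hidden_size=16, batch_first=True)

# Build the optimizer over that module's parameters up front; with a
# history_summarization_optimizer supplied, set_history_summarization_module
# no longer constructs one for you.
history_optimizer = optim.AdamW(history_module.parameters(), lr=3e-4, amsgrad=True)

The constructor also now reads the actor and critic learning rates back from optimizer.param_groups rather than from the learning-rate arguments, so the recorded values stay accurate when custom optimizers are passed.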
Changes to the ensemble Q-learning policy learner
@@ -64,6 +64,7 @@ def __init__(
target_update_freq: int = 10,
soft_update_tau: float = 1.0,
action_representation_module: ActionRepresentationModule | None = None,
optimizer: Optional[optim.Optimizer] = None,
) -> None:
assert isinstance(action_space, DiscreteActionSpace)
if action_representation_module is None:
@@ -89,9 +90,12 @@ def __init__(
self._soft_update_tau = soft_update_tau
self._Q = q_ensemble_network
self._Q_target: EnsembleQValueNetwork = deepcopy(self._Q)
self._optimizer = optim.AdamW(
self._Q.parameters(), lr=self._learning_rate, amsgrad=True
)
if optimizer is not None:
self._optimizer: optim.Optimizer = optimizer
else:
self._optimizer = optim.AdamW(
self._Q.parameters(), lr=self._learning_rate, amsgrad=True
)

@property
def ensemble_size(self) -> int:
8 changes: 7 additions & 1 deletion pearl/policy_learners/sequential_decision_making/ddpg.py
@@ -37,7 +37,7 @@
from pearl.utils.functional_utils.learning.critic_utils import (
twin_critic_action_value_loss,
)
from torch import nn
from torch import nn, optim


class DeepDeterministicPolicyGradient(ActorCriticBase):
@@ -66,6 +66,9 @@ def __init__(
action_representation_module: ActionRepresentationModule | None = None,
actor_network_instance: ActorNetwork | None = None,
critic_network_instance: QValueNetwork | nn.Module | None = None,
actor_optimizer: Optional[optim.Optimizer] = None,
critic_optimizer: Optional[optim.Optimizer] = None,
history_summarization_optimizer: Optional[optim.Optimizer] = None,
) -> None:
super().__init__(
state_dim=state_dim,
@@ -95,6 +98,9 @@ def __init__(
action_representation_module=action_representation_module,
actor_network_instance=actor_network_instance,
critic_network_instance=critic_network_instance,
actor_optimizer=actor_optimizer,
critic_optimizer=critic_optimizer,
history_summarization_optimizer=history_summarization_optimizer,
)

def _actor_loss(self, batch: TransitionBatch) -> torch.Tensor:
Changes to DeepQLearning
@@ -30,6 +30,7 @@
)
from pearl.replay_buffers.transition import TransitionBatch
from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace
from torch import optim


class DeepQLearning(DeepTDLearning):
@@ -54,6 +55,7 @@ def __init__(
network_type: type[QValueNetwork] = VanillaQValueNetwork,
action_representation_module: ActionRepresentationModule | None = None,
network_instance: QValueNetwork | None = None,
optimizer: Optional[optim.Optimizer] = None,
**kwargs: Any,
) -> None:
"""Constructs a DeepQLearning policy learner. DeepQLearning is based on DeepTDLearning
@@ -122,6 +124,7 @@ class as the QValueNetwork. This includes {state_output_dim (int),
batch_size=batch_size,
target_update_freq=target_update_freq,
network_instance=network_instance,
optimizer=optimizer,
**kwargs,
)

Changes to DeepSARSA
@@ -25,6 +25,7 @@
DeepTDLearning,
)
from pearl.replay_buffers.transition import TransitionBatch
from torch import optim


class DeepSARSA(DeepTDLearning):
@@ -38,6 +39,7 @@ def __init__(
action_space: ActionSpace | None = None,
exploration_module: ExplorationModule | None = None,
action_representation_module: ActionRepresentationModule | None = None,
optimizer: Optional[optim.Optimizer] = None,
**kwargs: Any,
) -> None:
super().__init__(
@@ -50,6 +52,7 @@
),
on_policy=True,
action_representation_module=action_representation_module,
optimizer=optimizer,
**kwargs,
)

Changes to DeepTDLearning
@@ -71,6 +71,8 @@ def __init__(
action_hidden_dims: list[int] | None = None,
network_instance: QValueNetwork | None = None,
action_representation_module: ActionRepresentationModule | None = None,
optimizer: Optional[optim.Optimizer] = None,
**kwargs: Any,
) -> None:
"""Constructs a DeepTDLearning based policy learner. DeepTDLearning is the base class
for all value based (i.e. temporal difference learning based) algorithms.
@@ -165,9 +167,12 @@ def make_specified_network() -> QValueNetwork:
self._Q = make_specified_network()

self._Q_target: QValueNetwork = copy.deepcopy(self._Q)
self._optimizer: torch.optim.Optimizer = optim.AdamW(
self._Q.parameters(), lr=learning_rate, amsgrad=True
)
if optimizer is not None:
self._optimizer = optimizer
else:
self._optimizer: torch.optim.Optimizer = optim.AdamW(
self._Q.parameters(), lr=learning_rate, amsgrad=True
)

@property
def optimizer(self) -> torch.optim.Optimizer:
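DeepTDLearning gets the same treatment, and DeepQLearning and DeepSARSA simply forward their optimizer argument to it; the quantile-regression and ensemble learners apply the same pattern to their own Q-networks. A minimal, hypothetical sketch of preparing such an optimizer; the Sequential network here stands in for whatever QValueNetwork would be passed as network_instance:

import torch.nn as nn
import torch.optim as optim

# Generic stand-in for a Q-value network; in Pearl this would be the object
# passed as network_instance so the learner and the optimizer share parameters.
q_network = nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 2))

# Any torch optimizer is accepted; this replaces the default
# AdamW(lr=learning_rate, amsgrad=True) the learner would otherwise build.
custom_optimizer = optim.SGD(q_network.parameters(), lr=1e-3, momentum=0.9)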
Changes to the implicit Q-learning (IQL) policy learner
@@ -99,6 +99,10 @@ def __init__(
actor_network_instance: ActorNetwork | None = None,
critic_network_instance: QValueNetwork | None = None,
value_network_instance: ValueNetwork | None = None,
actor_optimizer: Optional[optim.Optimizer] = None,
critic_optimizer: Optional[optim.Optimizer] = None,
history_summarization_optimizer: Optional[optim.Optimizer] = None,
value_optimizer: Optional[optim.Optimizer] = None,
) -> None:
super().__init__(
state_dim=state_dim,
@@ -127,6 +131,9 @@
action_representation_module=action_representation_module,
actor_network_instance=actor_network_instance,
critic_network_instance=critic_network_instance,
actor_optimizer=actor_optimizer,
critic_optimizer=critic_optimizer,
history_summarization_optimizer=history_summarization_optimizer,
)

self._expectile = expectile
@@ -147,11 +154,14 @@ def __init__(
hidden_dims=value_critic_hidden_dims,
output_dim=1,
)
self._value_network_optimizer = optim.AdamW(
self._value_network.parameters(),
lr=value_critic_learning_rate,
amsgrad=True,
)
if value_optimizer is not None:
self._value_network_optimizer: optim.Optimizer = value_optimizer
else:
self._value_network_optimizer = optim.AdamW(
self._value_network.parameters(),
lr=value_critic_learning_rate,
amsgrad=True,
)

def learn_batch(self, batch: TransitionBatch) -> dict[str, Any]:
value_loss = self._value_loss(batch)
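Implicit Q-learning additionally owns a value network, so it gains a fourth hook, value_optimizer, which falls back to AdamW over the value network's parameters when not given. As with the other hooks, a supplied optimizer only makes sense if it was built over the same value network the learner uses (e.g. one passed in as value_network_instance); a hypothetical sketch:

import torch.nn as nn
import torch.optim as optim

# Stand-in for the value network; in Pearl this would be passed as
# value_network_instance so that value_optimizer tracks its parameters.
value_network = nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 1))
value_optimizer = optim.Adam(value_network.parameters(), lr=1e-3)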
8 changes: 7 additions & 1 deletion pearl/policy_learners/sequential_decision_making/ppo.py
@@ -43,7 +43,7 @@
from pearl.utils.replay_buffer_utils import (
make_replay_buffer_class_for_specific_transition_types,
)
from torch import nn
from torch import nn, optim


@dataclass(frozen=False)
@@ -110,6 +110,9 @@ def __init__(
action_representation_module: ActionRepresentationModule | None = None,
actor_network_instance: ActorNetwork | None = None,
critic_network_instance: ValueNetwork | nn.Module | None = None,
actor_optimizer: Optional[optim.Optimizer] = None,
critic_optimizer: Optional[optim.Optimizer] = None,
history_summarization_optimizer: Optional[optim.Optimizer] = None,
) -> None:
super().__init__(
state_dim=state_dim,
@@ -140,6 +143,9 @@
action_representation_module=action_representation_module,
actor_network_instance=actor_network_instance,
critic_network_instance=critic_network_instance,
actor_optimizer=actor_optimizer,
critic_optimizer=critic_optimizer,
history_summarization_optimizer=history_summarization_optimizer,
)
self._epsilon = epsilon
self._trace_decay_param = trace_decay_param
Changes to QuantileRegressionDeepQLearning
@@ -30,6 +30,7 @@
)
from pearl.replay_buffers.transition import TransitionBatch
from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace
from torch import optim


class QuantileRegressionDeepQLearning(QuantileRegressionDeepTDLearning):
@@ -60,6 +61,7 @@ def __init__(
action_representation_module: ActionRepresentationModule | None = None,
network_type: type[QuantileQValueNetwork] = QuantileQValueNetwork,
network_instance: QuantileQValueNetwork | None = None,
optimizer: Optional[optim.Optimizer] = None,
) -> None:
assert isinstance(action_space, DiscreteActionSpace)
super().__init__(
@@ -82,6 +84,7 @@ def __init__(
network_type=network_type,
network_instance=network_instance,
action_representation_module=action_representation_module,
optimizer=optimizer,
)

# QR-DQN is based on QuantileRegressionDeepTDLearning class.
Changes to QuantileRegressionDeepTDLearning
@@ -67,6 +67,7 @@ def __init__(
] = QuantileQValueNetwork, # C51 might use a different network type; add that later
network_instance: QuantileQValueNetwork | None = None,
action_representation_module: ActionRepresentationModule | None = None,
optimizer: Optional[optim.Optimizer] = None,
) -> None:
assert isinstance(action_space, DiscreteActionSpace)
super().__init__(
@@ -76,6 +77,7 @@
on_policy=on_policy,
is_action_continuous=False,
action_representation_module=action_representation_module,
optimizer=optimizer,
)

if hidden_dims is None:
@@ -111,9 +113,12 @@ def make_specified_network() -> QuantileQValueNetwork:
self._Q: QuantileQValueNetwork = make_specified_network()

self._Q_target: QuantileQValueNetwork = copy.deepcopy(self._Q)
self._optimizer = optim.AdamW(
self._Q.parameters(), lr=learning_rate, amsgrad=True
)
if optimizer is not None:
self._optimizer: optim.Optimizer = optimizer
else:
self._optimizer = optim.AdamW(
self._Q.parameters(), lr=learning_rate, amsgrad=True
)

def set_history_summarization_module(
self, value: HistorySummarizationModule
(Remaining changed files not shown in this view.)
