[Minor,Feature] group_optimizers
ghstack-source-id: 81a94ed641544a420bb1c455921ca6a17ecd6a22
Pull Request resolved: #2577
vmoens committed Nov 18, 2024
1 parent 7bc84d1 commit 7829bd3
Showing 3 changed files with 24 additions and 4 deletions.
9 changes: 5 additions & 4 deletions docs/source/reference/objectives.rst
@@ -311,11 +311,12 @@ Utils
     :toctree: generated/
     :template: rl_template_noinherit.rst
 
+    HardUpdate
+    SoftUpdate
+    ValueEstimators
+    default_value_kwargs
     distance_loss
+    group_optimizers
     hold_out_net
     hold_out_params
     next_state_value
-    SoftUpdate
-    HardUpdate
-    ValueEstimators
-    default_value_kwargs
1 change: 1 addition & 0 deletions torchrl/objectives/__init__.py
@@ -23,6 +23,7 @@
 from .utils import (
     default_value_kwargs,
     distance_loss,
+    group_optimizers,
     HardUpdate,
     hold_out_net,
     hold_out_params,
18 changes: 18 additions & 0 deletions torchrl/objectives/utils.py
@@ -590,3 +590,21 @@ def _clip_value_loss(
     # Choose the most pessimistic value prediction between clipped and non-clipped
     loss_value = torch.max(loss_value, loss_value_clipped)
     return loss_value, clip_fraction
+
+
+def group_optimizers(*optimizers: torch.optim.Optimizer) -> torch.optim.Optimizer:
+    """Groups multiple optimizers into a single one.
+
+    All optimizers are expected to have the same type.
+    """
+    cls = None
+    params = []
+    for optimizer in optimizers:
+        if optimizer is None:
+            continue
+        if cls is None:
+            cls = type(optimizer)
+        if cls is not type(optimizer):
+            raise ValueError("Cannot group optimizers of different type.")
+        params.extend(optimizer.param_groups)
+    return cls(params)
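For context, a minimal usage sketch of the new helper (not part of this commit; the actor/critic modules, learning rates, and tensor shapes below are illustrative assumptions). It merges two Adam optimizers into a single Adam instance so one step()/zero_grad() call drives both:

import torch
from torch import nn

from torchrl.objectives import group_optimizers

# Hypothetical actor and critic networks, each built with its own optimizer.
actor = nn.Linear(4, 2)
critic = nn.Linear(4, 1)
actor_optim = torch.optim.Adam(actor.parameters(), lr=3e-4)
critic_optim = torch.optim.Adam(critic.parameters(), lr=1e-3)

# Both optimizers are Adam, so they can be grouped; each original param_group
# (with its own lr) is carried over into the merged optimizer.
optim = group_optimizers(actor_optim, critic_optim)

loss = actor(torch.randn(4)).sum() + critic(torch.randn(4)).sum()
loss.backward()
optim.step()       # updates actor and critic parameters together
optim.zero_grad()

Because only param_groups are forwarded to the new optimizer, any per-parameter state already accumulated by the original optimizers (e.g. Adam's running moments) is not transferred, so grouping is best done before training starts.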
