From 7829bd3f3be166eabf0b3fb53e9dc8fe509ccc79 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Mon, 18 Nov 2024 15:13:03 +0000
Subject: [PATCH] [Minor,Feature] `group_optimizers`

ghstack-source-id: 81a94ed641544a420bb1c455921ca6a17ecd6a22
Pull Request resolved: https://github.com/pytorch/rl/pull/2577
---
 docs/source/reference/objectives.rst |  9 +++++----
 torchrl/objectives/__init__.py       |  1 +
 torchrl/objectives/utils.py          | 18 ++++++++++++++++++
 3 files changed, 24 insertions(+), 4 deletions(-)

diff --git a/docs/source/reference/objectives.rst b/docs/source/reference/objectives.rst
index 98b282767cc..9e7df1bff8f 100644
--- a/docs/source/reference/objectives.rst
+++ b/docs/source/reference/objectives.rst
@@ -311,11 +311,12 @@ Utils
     :toctree: generated/
     :template: rl_template_noinherit.rst
 
+    HardUpdate
+    SoftUpdate
+    ValueEstimators
+    default_value_kwargs
     distance_loss
+    group_optimizers
     hold_out_net
     hold_out_params
     next_state_value
-    SoftUpdate
-    HardUpdate
-    ValueEstimators
-    default_value_kwargs
diff --git a/torchrl/objectives/__init__.py b/torchrl/objectives/__init__.py
index 1ea9ebb5998..01f993e629a 100644
--- a/torchrl/objectives/__init__.py
+++ b/torchrl/objectives/__init__.py
@@ -23,6 +23,7 @@
 from .utils import (
     default_value_kwargs,
     distance_loss,
+    group_optimizers,
     HardUpdate,
     hold_out_net,
     hold_out_params,
diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py
index f47b85075b4..17ab16cfefa 100644
--- a/torchrl/objectives/utils.py
+++ b/torchrl/objectives/utils.py
@@ -590,3 +590,21 @@ def _clip_value_loss(
     # Chose the most pessimistic value prediction between clipped and non-clipped
     loss_value = torch.max(loss_value, loss_value_clipped)
     return loss_value, clip_fraction
+
+
+def group_optimizers(*optimizers: torch.optim.Optimizer) -> torch.optim.Optimizer:
+    """Groups multiple optimizers into a single one.
+
+    All optimizers are expected to have the same type.
+    """
+    cls = None
+    params = []
+    for optimizer in optimizers:
+        if optimizer is None:
+            continue
+        if cls is None:
+            cls = type(optimizer)
+        if cls is not type(optimizer):
+            raise ValueError("Cannot group optimizers of different type.")
+        params.extend(optimizer.param_groups)
+    return cls(params)
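
For reference, a minimal usage sketch of the new helper (not part of the patch). It assumes two hypothetical nn.Linear modules standing in for an actor and a critic, each with its own torch.optim.Adam optimizer:

    import torch
    from torch import nn
    from torchrl.objectives import group_optimizers

    # Hypothetical modules standing in for an actor and a critic network.
    actor = nn.Linear(4, 2)
    critic = nn.Linear(4, 1)

    # Two optimizers of the same type, possibly with different hyperparameters.
    opt_actor = torch.optim.Adam(actor.parameters(), lr=3e-4)
    opt_critic = torch.optim.Adam(critic.parameters(), lr=1e-3)

    # A single Adam optimizer carrying both parameter groups, so one
    # step()/zero_grad() call updates both networks.
    optim = group_optimizers(opt_actor, opt_critic)

    loss = actor(torch.randn(4)).sum() + critic(torch.randn(4)).sum()
    loss.backward()
    optim.step()
    optim.zero_grad()

Note that, as written, only the parameter groups are carried over into the grouped optimizer; any state already accumulated in the original optimizers (e.g. Adam's moment estimates) is not transferred, since the grouped optimizer is constructed fresh from the collected param_groups.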