From c5fd64c88d8277b0f8e8f2fe8326e385bd945c46 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 24 Jul 2024 08:24:22 +0100 Subject: [PATCH] Update [ghstack-poisoned] --- docs/source/reference/objectives.rst | 4 ++++ torchrl/objectives/common.py | 11 +++++++++++ 2 files changed, 15 insertions(+) diff --git a/docs/source/reference/objectives.rst b/docs/source/reference/objectives.rst index 18cb6886914..1d92c390a4e 100644 --- a/docs/source/reference/objectives.rst +++ b/docs/source/reference/objectives.rst @@ -54,6 +54,10 @@ Since the calls to `vmap` are buried down the loss modules, TorchRL provides an interface to set that vmap mode from the outside through `loss.vmap_randomness = str_value`, see :meth:`~torchrl.objectives.LossModule.vmap_randomness` for more information. +``LossModule.vmap_randomness`` defaults to `"error"` if no random module is detected, and to `"different"` in +other cases. By default, only a limited number of modules are listed as random, but the list can be extended +using the :func:`~torchrl.objectives.common.add_random_module` function. + Training value functions ------------------------ diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py index 998eeaedc15..bce82286992 100644 --- a/torchrl/objectives/common.py +++ b/torchrl/objectives/common.py @@ -550,7 +550,12 @@ def vmap_randomness(self): If `"different"`, every element of the batch along which vmap is being called will behave differently. If `"same"`, vmaps will copy the same result across all elements. + ``vmap_randomness`` defaults to `"error"` if no random module is detected, and to `"different"` in + other cases. By default, only a limited number of modules are listed as random, but the list can be extended + using the :func:`~torchrl.objectives.common.add_random_module` function. + This property supports setting its value. + """ if self._vmap_randomness is None: do_break = False @@ -603,3 +608,9 @@ def __call__(self, x): x.data.clone() if self.clone else x.data, requires_grad=False ) return x.data.clone() if self.clone else x.data + + +def add_ramdom_module(module): + """Adds a random module to the list of modules that will be detected by :meth:`~torchrl.objectives.LossModule.vmap_randomness` as random.""" + global RANDOM_MODULE_LIST + RANDOM_MODULE_LIST = RANDOM_MODULE_LIST + (module,)