From 447970d59b06ed3127786ea0b262b03c94e93878 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 24 Jul 2024 08:59:10 +0100 Subject: [PATCH] [Doc] Add doc about vmap randomness ghstack-source-id: f755100911ed5271de39a9500e1ffd69754dfad5 Pull Request resolved: https://github.com/pytorch/rl/pull/2316 --- docs/source/reference/objectives.rst | 16 ++++++++++++++++ torchrl/objectives/common.py | 16 ++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/docs/source/reference/objectives.rst b/docs/source/reference/objectives.rst index 537d4542910..18cb6886914 100644 --- a/docs/source/reference/objectives.rst +++ b/docs/source/reference/objectives.rst @@ -38,6 +38,22 @@ The main characteristics of TorchRL losses are: If the parameters are modified out-of-place, :meth:`~torchrl.objectives.LossModule.from_stateful_net` can be used to reset the parameters in the loss to the new value. +torch.vmap and randomness +------------------------- + +TorchRL loss modules have plenty of calls to :func:`~torch.vmap` to amortize the cost of calling multiple similar models +in a loop, and instead vectorize these operations. `vmap` needs to be told explicitly what to do when random numbers +need to be generated within the call. To do this, a randomness mode need to be set and must be one of `"error"` (default, +errors when dealing with pseudo-random functions), `"same"` (replicates the results across the batch) or `"different"` +(each element of the batch is treated separately). +Relying on the default will typically result in an error such as this one: + + >>> RuntimeError: vmap: called random operation while in randomness error mode. + +Since the calls to `vmap` are buried down the loss modules, TorchRL +provides an interface to set that vmap mode from the outside through `loss.vmap_randomness = str_value`, see +:meth:`~torchrl.objectives.LossModule.vmap_randomness` for more information. + Training value functions ------------------------ diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py index c62fd485e28..3d7b3df94cf 100644 --- a/torchrl/objectives/common.py +++ b/torchrl/objectives/common.py @@ -539,6 +539,18 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams @property def vmap_randomness(self): + """Vmap random mode. + + The vmap randomness mode controls what :func:`~torch.vmap` should do when dealing with + functions with a random outcome such as :func:`~torch.randn` and :func:`~torch.rand`. + If `"error"`, any random function will raise an exception indicating that `vmap` does not + know how to handle the random call. + + If `"different"`, every element of the batch along which vmap is being called will + behave differently. If `"same"`, vmaps will copy the same result across all elements. + + This property supports setting its value. + """ if self._vmap_randomness is None: main_modules = list(self.__dict__.values()) + list(self.children()) modules = ( @@ -557,6 +569,10 @@ def vmap_randomness(self): return self._vmap_randomness def set_vmap_randomness(self, value): + if value not in ("error", "same", "different"): + raise ValueError( + "Wrong vmap randomness, should be one of 'error', 'same' or 'different'." + ) self._vmap_randomness = value self._make_vmap()