Skip to content

Commit

Permalink
[Doc] Add doc about vmap randomness
Browse files Browse the repository at this point in the history
ghstack-source-id: f755100911ed5271de39a9500e1ffd69754dfad5
Pull Request resolved: #2316
  • Loading branch information
vmoens committed Jul 24, 2024
1 parent 21b297e commit 447970d
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 0 deletions.
16 changes: 16 additions & 0 deletions docs/source/reference/objectives.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,22 @@ The main characteristics of TorchRL losses are:
If the parameters are modified out-of-place, :meth:`~torchrl.objectives.LossModule.from_stateful_net` can be
used to reset the parameters in the loss to the new value.

torch.vmap and randomness
-------------------------

TorchRL loss modules have plenty of calls to :func:`~torch.vmap` to amortize the cost of calling multiple similar models
in a loop, and instead vectorize these operations. `vmap` needs to be told explicitly what to do when random numbers
need to be generated within the call. To do this, a randomness mode need to be set and must be one of `"error"` (default,
errors when dealing with pseudo-random functions), `"same"` (replicates the results across the batch) or `"different"`
(each element of the batch is treated separately).
Relying on the default will typically result in an error such as this one:

>>> RuntimeError: vmap: called random operation while in randomness error mode.

Since the calls to `vmap` are buried down the loss modules, TorchRL
provides an interface to set that vmap mode from the outside through `loss.vmap_randomness = str_value`, see
:meth:`~torchrl.objectives.LossModule.vmap_randomness` for more information.

Training value functions
------------------------

Expand Down
16 changes: 16 additions & 0 deletions torchrl/objectives/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,18 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams

@property
def vmap_randomness(self):
"""Vmap random mode.
The vmap randomness mode controls what :func:`~torch.vmap` should do when dealing with
functions with a random outcome such as :func:`~torch.randn` and :func:`~torch.rand`.
If `"error"`, any random function will raise an exception indicating that `vmap` does not
know how to handle the random call.
If `"different"`, every element of the batch along which vmap is being called will
behave differently. If `"same"`, vmaps will copy the same result across all elements.
This property supports setting its value.
"""
if self._vmap_randomness is None:
main_modules = list(self.__dict__.values()) + list(self.children())
modules = (
Expand All @@ -557,6 +569,10 @@ def vmap_randomness(self):
return self._vmap_randomness

def set_vmap_randomness(self, value):
if value not in ("error", "same", "different"):
raise ValueError(
"Wrong vmap randomness, should be one of 'error', 'same' or 'different'."
)
self._vmap_randomness = value
self._make_vmap()

Expand Down

0 comments on commit 447970d

Please sign in to comment.