From 9d2c5619711066c44b2118efd7a641911b0ae1ee Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 24 Jul 2024 08:24:22 +0100 Subject: [PATCH] Update (base update) [ghstack-poisoned] --- docs/source/reference/objectives.rst | 16 ++++++++++++++++ torchrl/objectives/common.py | 22 ++++++++++++++++++++++ torchrl/objectives/cql.py | 4 +++- torchrl/objectives/crossq.py | 5 ++++- torchrl/objectives/deprecated.py | 5 +++-- torchrl/objectives/iql.py | 5 ++++- torchrl/objectives/redq.py | 2 ++ torchrl/objectives/sac.py | 10 ++++++++-- torchrl/objectives/td3.py | 5 ++++- torchrl/objectives/td3_bc.py | 5 ++++- 10 files changed, 70 insertions(+), 9 deletions(-) diff --git a/docs/source/reference/objectives.rst b/docs/source/reference/objectives.rst index 537d4542910..18cb6886914 100644 --- a/docs/source/reference/objectives.rst +++ b/docs/source/reference/objectives.rst @@ -38,6 +38,22 @@ The main characteristics of TorchRL losses are: If the parameters are modified out-of-place, :meth:`~torchrl.objectives.LossModule.from_stateful_net` can be used to reset the parameters in the loss to the new value. +torch.vmap and randomness +------------------------- + +TorchRL loss modules have plenty of calls to :func:`~torch.vmap` to amortize the cost of calling multiple similar models +in a loop, and instead vectorize these operations. `vmap` needs to be told explicitly what to do when random numbers +need to be generated within the call. To do this, a randomness mode need to be set and must be one of `"error"` (default, +errors when dealing with pseudo-random functions), `"same"` (replicates the results across the batch) or `"different"` +(each element of the batch is treated separately). +Relying on the default will typically result in an error such as this one: + + >>> RuntimeError: vmap: called random operation while in randomness error mode. + +Since the calls to `vmap` are buried down the loss modules, TorchRL +provides an interface to set that vmap mode from the outside through `loss.vmap_randomness = str_value`, see +:meth:`~torchrl.objectives.LossModule.vmap_randomness` for more information. + Training value functions ------------------------ diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py index f2b02825005..998eeaedc15 100644 --- a/torchrl/objectives/common.py +++ b/torchrl/objectives/common.py @@ -540,6 +540,18 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams @property def vmap_randomness(self): + """Vmap random mode. + + The vmap randomness mode controls what :func:`~torch.vmap` should do when dealing with + functions with a random outcome such as :func:`~torch.randn` and :func:`~torch.rand`. + If `"error"`, any random function will raise an exception indicating that `vmap` does not + know how to handle the random call. + + If `"different"`, every element of the batch along which vmap is being called will + behave differently. If `"same"`, vmaps will copy the same result across all elements. + + This property supports setting its value. + """ if self._vmap_randomness is None: do_break = False for val in self.__dict__.values(): @@ -558,7 +570,12 @@ def vmap_randomness(self): return self._vmap_randomness def set_vmap_randomness(self, value): + if value not in ("error", "same", "different"): + raise ValueError( + "Wrong vmap randomness, should be one of 'error', 'same' or 'different'." + ) self._vmap_randomness = value + self._make_vmap() @staticmethod def _make_meta_params(param): @@ -570,6 +587,11 @@ def _make_meta_params(param): pd = nn.Parameter(pd, requires_grad=False) return pd + def _make_vmap(self): + raise NotImplementedError( + f"_make_vmap has been called but is not implemented for loss of type {type(self).__name__}." + ) + class _make_target_param: def __init__(self, clone): diff --git a/torchrl/objectives/cql.py b/torchrl/objectives/cql.py index 0d2d869d1e1..96f37225fd8 100644 --- a/torchrl/objectives/cql.py +++ b/torchrl/objectives/cql.py @@ -374,13 +374,15 @@ def __init__( torch.nn.Parameter(torch.tensor(math.log(1.0), device=device)), ) + self.reduction = reduction + + def _make_vmap(self): self._vmap_qvalue_networkN0 = _vmap_func( self.qvalue_network, (None, 0), randomness=self.vmap_randomness ) self._vmap_qvalue_network00 = _vmap_func( self.qvalue_network, randomness=self.vmap_randomness ) - self.reduction = reduction @property def target_entropy(self): diff --git a/torchrl/objectives/crossq.py b/torchrl/objectives/crossq.py index 355a33a4682..05499cb227d 100644 --- a/torchrl/objectives/crossq.py +++ b/torchrl/objectives/crossq.py @@ -338,10 +338,13 @@ def __init__( self._target_entropy = target_entropy self._action_spec = action_spec + self._make_vmap() + self.reduction = reduction + + def _make_vmap(self): self._vmap_qnetworkN0 = _vmap_func( self.qvalue_network, (None, 0), randomness=self.vmap_randomness ) - self.reduction = reduction @property def target_entropy_buffer(self): diff --git a/torchrl/objectives/deprecated.py b/torchrl/objectives/deprecated.py index 9e7115ac601..b54e96eb32f 100644 --- a/torchrl/objectives/deprecated.py +++ b/torchrl/objectives/deprecated.py @@ -223,11 +223,12 @@ def __init__( self.gSDE = gSDE self.reduction = reduction - self._vmap_qvalue_networkN0 = _vmap_func(self.qvalue_network, (None, 0)) - if gamma is not None: raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) + def _make_vmap(self): + self._vmap_qvalue_networkN0 = _vmap_func(self.qvalue_network, (None, 0)) + @property def target_entropy(self): target_entropy = self.target_entropy_buffer diff --git a/torchrl/objectives/iql.py b/torchrl/objectives/iql.py index 7fab95a95ed..013435c9079 100644 --- a/torchrl/objectives/iql.py +++ b/torchrl/objectives/iql.py @@ -318,10 +318,13 @@ def __init__( self.loss_function = loss_function if gamma is not None: raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) + self._make_vmap() + self.reduction = reduction + + def _make_vmap(self): self._vmap_qvalue_networkN0 = _vmap_func( self.qvalue_network, (None, 0), randomness=self.vmap_randomness ) - self.reduction = reduction @property def device(self) -> torch.device: diff --git a/torchrl/objectives/redq.py b/torchrl/objectives/redq.py index a0aaa96f7c5..db05063535a 100644 --- a/torchrl/objectives/redq.py +++ b/torchrl/objectives/redq.py @@ -336,7 +336,9 @@ def __init__( self.gSDE = gSDE if gamma is not None: raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) + self._make_vmap() + def _make_vmap(self): self._vmap_qvalue_network00 = _vmap_func( self.qvalue_network, randomness=self.vmap_randomness ) diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index 67ab7d7d8ce..d03014da7bd 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -403,6 +403,10 @@ def __init__( ) if gamma is not None: raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) + self._make_vmap() + self.reduction = reduction + + def _make_vmap(self): self._vmap_qnetworkN0 = _vmap_func( self.qvalue_network, (None, 0), randomness=self.vmap_randomness ) @@ -410,7 +414,6 @@ def __init__( self._vmap_qnetwork00 = _vmap_func( qvalue_network, randomness=self.vmap_randomness ) - self.reduction = reduction @property def target_entropy_buffer(self): @@ -1101,10 +1104,13 @@ def __init__( self.register_buffer( "target_entropy", torch.tensor(target_entropy, device=device) ) + self._make_vmap() + self.reduction = reduction + + def _make_vmap(self): self._vmap_qnetworkN0 = _vmap_func( self.qvalue_network, (None, 0), randomness=self.vmap_randomness ) - self.reduction = reduction def _forward_value_estimator_keys(self, **kwargs) -> None: if self._value_estimator is not None: diff --git a/torchrl/objectives/td3.py b/torchrl/objectives/td3.py index db99237d39e..b0026b0158d 100644 --- a/torchrl/objectives/td3.py +++ b/torchrl/objectives/td3.py @@ -317,13 +317,16 @@ def __init__( self.register_buffer("min_action", low) if gamma is not None: raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) + self._make_vmap() + self.reduction = reduction + + def _make_vmap(self): self._vmap_qvalue_network00 = _vmap_func( self.qvalue_network, randomness=self.vmap_randomness ) self._vmap_actor_network00 = _vmap_func( self.actor_network, randomness=self.vmap_randomness ) - self.reduction = reduction def _forward_value_estimator_keys(self, **kwargs) -> None: if self._value_estimator is not None: diff --git a/torchrl/objectives/td3_bc.py b/torchrl/objectives/td3_bc.py index d5529e0b859..bea101f4038 100644 --- a/torchrl/objectives/td3_bc.py +++ b/torchrl/objectives/td3_bc.py @@ -331,13 +331,16 @@ def __init__( high = high.to(device) self.register_buffer("max_action", high) self.register_buffer("min_action", low) + self._make_vmap() + self.reduction = reduction + + def _make_vmap(self): self._vmap_qvalue_network00 = _vmap_func( self.qvalue_network, randomness=self.vmap_randomness ) self._vmap_actor_network00 = _vmap_func( self.actor_network, randomness=self.vmap_randomness ) - self.reduction = reduction def _forward_value_estimator_keys(self, **kwargs) -> None: if self._value_estimator is not None: