Update (base update)
[ghstack-poisoned]
vmoens committed Jul 24, 2024
1 parent f840a1a commit 9d2c561
Showing 10 changed files with 70 additions and 9 deletions.
16 changes: 16 additions & 0 deletions docs/source/reference/objectives.rst
@@ -38,6 +38,22 @@ The main characteristics of TorchRL losses are:
 If the parameters are modified out-of-place, :meth:`~torchrl.objectives.LossModule.from_stateful_net` can be
 used to reset the parameters in the loss to the new value.
 
+torch.vmap and randomness
+-------------------------
+
+TorchRL loss modules make plenty of calls to :func:`~torch.vmap` to amortize the cost of calling multiple similar models
+in a loop, vectorizing these operations instead. `vmap` needs to be told explicitly what to do when random numbers
+need to be generated within the call. To do this, a randomness mode needs to be set and must be one of `"error"` (the default,
+which errors when dealing with pseudo-random functions), `"same"` (replicates the results across the batch) or `"different"`
+(each element of the batch is treated separately).
+Relying on the default will typically result in an error such as this one:
+
+>>> RuntimeError: vmap: called random operation while in randomness error mode.
+
+Since the calls to `vmap` are buried inside the loss modules, TorchRL
+provides an interface to set the vmap mode from the outside through `loss.vmap_randomness = str_value`; see
+:meth:`~torchrl.objectives.LossModule.vmap_randomness` for more information.
+
 Training value functions
 ------------------------
 
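
A usage sketch of the attribute documented above (illustrative only, not part of this commit; it assumes an actor_network, a qvalue_network and a sampled data batch are already available):

    >>> from torchrl.objectives import SACLoss
    >>> loss = SACLoss(actor_network, qvalue_network)
    >>> loss.vmap_randomness = "different"  # each vmapped batch element draws its own random numbers
    >>> loss_vals = loss(data)
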
22 changes: 22 additions & 0 deletions torchrl/objectives/common.py
@@ -540,6 +540,18 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
 
     @property
     def vmap_randomness(self):
+        """Vmap randomness mode.
+
+        The vmap randomness mode controls what :func:`~torch.vmap` should do when dealing with
+        functions with a random outcome such as :func:`~torch.randn` and :func:`~torch.rand`.
+        If `"error"`, any random function will raise an exception indicating that `vmap` does not
+        know how to handle the random call.
+
+        If `"different"`, every element of the batch along which vmap is being called will
+        behave differently. If `"same"`, vmap will copy the same result across all elements.
+
+        This property supports setting its value.
+        """
         if self._vmap_randomness is None:
             do_break = False
             for val in self.__dict__.values():
@@ -558,7 +570,12 @@ def vmap_randomness(self):
         return self._vmap_randomness
 
     def set_vmap_randomness(self, value):
+        if value not in ("error", "same", "different"):
+            raise ValueError(
+                "Wrong vmap randomness, should be one of 'error', 'same' or 'different'."
+            )
         self._vmap_randomness = value
+        self._make_vmap()
 
     @staticmethod
     def _make_meta_params(param):
@@ -570,6 +587,11 @@ def _make_meta_params(param):
         pd = nn.Parameter(pd, requires_grad=False)
         return pd
 
+    def _make_vmap(self):
+        raise NotImplementedError(
+            f"_make_vmap has been called but is not implemented for loss of type {type(self).__name__}."
+        )
+
 
 class _make_target_param:
     def __init__(self, clone):
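
A minimal, self-contained analogue of the hook added above (plain torch only; the class and attribute names are illustrative, not TorchRL API): vmapped callables are built in _make_vmap() and rebuilt whenever the randomness mode changes, which is what set_vmap_randomness relies on.

    import torch


    class VmapHolder:
        """Illustrative stand-in for a LossModule subclass."""

        def __init__(self):
            self._vmap_randomness = "error"
            self._make_vmap()

        def _make_vmap(self):
            # (re)create the vmapped callable with the current randomness mode
            self._vmap_noisy = torch.vmap(
                lambda x: x + torch.randn(()), randomness=self._vmap_randomness
            )

        def set_vmap_randomness(self, value):
            if value not in ("error", "same", "different"):
                raise ValueError(
                    "Wrong vmap randomness, should be one of 'error', 'same' or 'different'."
                )
            self._vmap_randomness = value
            self._make_vmap()


    holder = VmapHolder()
    holder.set_vmap_randomness("different")
    print(holder._vmap_noisy(torch.zeros(4)))  # four independent random draws

Switching the mode to "same" instead would replicate a single draw across all four entries.
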
4 changes: 3 additions & 1 deletion torchrl/objectives/cql.py
@@ -374,13 +374,15 @@ def __init__(
             torch.nn.Parameter(torch.tensor(math.log(1.0), device=device)),
         )
 
+        self.reduction = reduction
+
+    def _make_vmap(self):
         self._vmap_qvalue_networkN0 = _vmap_func(
             self.qvalue_network, (None, 0), randomness=self.vmap_randomness
         )
         self._vmap_qvalue_network00 = _vmap_func(
             self.qvalue_network, randomness=self.vmap_randomness
         )
-        self.reduction = reduction
 
     @property
     def target_entropy(self):
5 changes: 4 additions & 1 deletion torchrl/objectives/crossq.py
@@ -338,10 +338,13 @@ def __init__(
 
         self._target_entropy = target_entropy
         self._action_spec = action_spec
+        self._make_vmap()
+        self.reduction = reduction
+
+    def _make_vmap(self):
         self._vmap_qnetworkN0 = _vmap_func(
             self.qvalue_network, (None, 0), randomness=self.vmap_randomness
         )
-        self.reduction = reduction
 
     @property
     def target_entropy_buffer(self):
5 changes: 3 additions & 2 deletions torchrl/objectives/deprecated.py
@@ -223,11 +223,12 @@ def __init__(
         self.gSDE = gSDE
         self.reduction = reduction
 
-        self._vmap_qvalue_networkN0 = _vmap_func(self.qvalue_network, (None, 0))
-
         if gamma is not None:
             raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
 
+    def _make_vmap(self):
+        self._vmap_qvalue_networkN0 = _vmap_func(self.qvalue_network, (None, 0))
+
     @property
     def target_entropy(self):
         target_entropy = self.target_entropy_buffer
5 changes: 4 additions & 1 deletion torchrl/objectives/iql.py
@@ -318,10 +318,13 @@ def __init__(
         self.loss_function = loss_function
         if gamma is not None:
             raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
+        self._make_vmap()
+        self.reduction = reduction
+
+    def _make_vmap(self):
         self._vmap_qvalue_networkN0 = _vmap_func(
             self.qvalue_network, (None, 0), randomness=self.vmap_randomness
         )
-        self.reduction = reduction
 
     @property
     def device(self) -> torch.device:
2 changes: 2 additions & 0 deletions torchrl/objectives/redq.py
@@ -336,7 +336,9 @@ def __init__(
         self.gSDE = gSDE
         if gamma is not None:
             raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
+        self._make_vmap()
 
+    def _make_vmap(self):
         self._vmap_qvalue_network00 = _vmap_func(
             self.qvalue_network, randomness=self.vmap_randomness
         )
10 changes: 8 additions & 2 deletions torchrl/objectives/sac.py
@@ -403,14 +403,17 @@ def __init__(
         )
         if gamma is not None:
             raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
+        self._make_vmap()
+        self.reduction = reduction
+
+    def _make_vmap(self):
         self._vmap_qnetworkN0 = _vmap_func(
             self.qvalue_network, (None, 0), randomness=self.vmap_randomness
         )
         if self._version == 1:
             self._vmap_qnetwork00 = _vmap_func(
                 qvalue_network, randomness=self.vmap_randomness
             )
-        self.reduction = reduction
 
     @property
     def target_entropy_buffer(self):
@@ -1101,10 +1104,13 @@ def __init__(
         self.register_buffer(
             "target_entropy", torch.tensor(target_entropy, device=device)
         )
+        self._make_vmap()
+        self.reduction = reduction
+
+    def _make_vmap(self):
         self._vmap_qnetworkN0 = _vmap_func(
             self.qvalue_network, (None, 0), randomness=self.vmap_randomness
         )
-        self.reduction = reduction
 
     def _forward_value_estimator_keys(self, **kwargs) -> None:
         if self._value_estimator is not None:
5 changes: 4 additions & 1 deletion torchrl/objectives/td3.py
@@ -317,13 +317,16 @@ def __init__(
         self.register_buffer("min_action", low)
         if gamma is not None:
             raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
+        self._make_vmap()
+        self.reduction = reduction
+
+    def _make_vmap(self):
         self._vmap_qvalue_network00 = _vmap_func(
             self.qvalue_network, randomness=self.vmap_randomness
         )
         self._vmap_actor_network00 = _vmap_func(
             self.actor_network, randomness=self.vmap_randomness
         )
-        self.reduction = reduction
 
     def _forward_value_estimator_keys(self, **kwargs) -> None:
         if self._value_estimator is not None:
5 changes: 4 additions & 1 deletion torchrl/objectives/td3_bc.py
@@ -331,13 +331,16 @@ def __init__(
         high = high.to(device)
         self.register_buffer("max_action", high)
         self.register_buffer("min_action", low)
+        self._make_vmap()
+        self.reduction = reduction
+
+    def _make_vmap(self):
         self._vmap_qvalue_network00 = _vmap_func(
             self.qvalue_network, randomness=self.vmap_randomness
         )
         self._vmap_actor_network00 = _vmap_func(
             self.actor_network, randomness=self.vmap_randomness
         )
-        self.reduction = reduction
 
     def _forward_value_estimator_keys(self, **kwargs) -> None:
         if self._value_estimator is not None:
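
A standalone torch-level sketch of the call shape used throughout these losses: the q-value network is vmapped over a leading stack of parameter sets with in_dims=(None, 0) while the observation batch is shared, and the randomness mode is forwarded to torch.vmap (the function and shapes below are illustrative only):

    import torch


    def q_value(obs, weights):
        # toy stand-in for a functional q-value network call
        return obs @ weights


    obs = torch.randn(8, 4)         # shared input batch (in_dim None)
    weights = torch.randn(2, 4, 1)  # two stacked parameter sets (in_dim 0)
    vmapped_q = torch.vmap(q_value, in_dims=(None, 0), randomness="different")
    print(vmapped_q(obs, weights).shape)  # torch.Size([2, 8, 1])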
