Skip to content

Commit

Permalink
Update (base update)
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Jul 24, 2024
1 parent 6da8255 commit ad39fc2
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 1 deletion.
1 change: 1 addition & 0 deletions torchrl/objectives/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,7 @@ def _make_meta_params(param):
return pd

def _make_vmap(self):
"""Caches the the vmap callers to reduce the overhead at runtime."""
raise NotImplementedError(
f"_make_vmap has been called but is not implemented for loss of type {type(self).__name__}."
)
Expand Down
2 changes: 1 addition & 1 deletion torchrl/objectives/cql.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ def __init__(
"log_alpha_prime",
torch.nn.Parameter(torch.tensor(math.log(1.0), device=device)),
)

self._make_vmap()
self.reduction = reduction

def _make_vmap(self):
Expand Down
1 change: 1 addition & 0 deletions torchrl/objectives/deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ def __init__(
self._action_spec = action_spec
self.target_entropy_buffer = None
self.gSDE = gSDE
self._make_vmap()
self.reduction = reduction

if gamma is not None:
Expand Down

0 comments on commit ad39fc2

Please sign in to comment.