Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Jul 24, 2024
2 parents b257cf0 + ad39fc2 commit 400eb94
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 1 deletion.
1 change: 1 addition & 0 deletions torchrl/objectives/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,7 @@ def _make_meta_params(param):
return pd

def _make_vmap(self):
"""Caches the the vmap callers to reduce the overhead at runtime."""
raise NotImplementedError(
f"_make_vmap has been called but is not implemented for loss of type {type(self).__name__}."
)
Expand Down
2 changes: 1 addition & 1 deletion torchrl/objectives/cql.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ def __init__(
"log_alpha_prime",
torch.nn.Parameter(torch.tensor(math.log(1.0), device=device)),
)

self._make_vmap()
self.reduction = reduction

def _make_vmap(self):
Expand Down
1 change: 1 addition & 0 deletions torchrl/objectives/deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ def __init__(
self._action_spec = action_spec
self.target_entropy_buffer = None
self.gSDE = gSDE
self._make_vmap()
self.reduction = reduction

if gamma is not None:
Expand Down

0 comments on commit 400eb94

Please sign in to comment.