Skip to content

Commit

Permalink
Update (base update)
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Jul 24, 2024
1 parent 9d2c561 commit 6da8255
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 12 deletions.
3 changes: 3 additions & 0 deletions test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,9 @@ def __init__(self):
net = nn.Sequential(*layers).to(device)
model = TensorDictModule(net, in_keys=["obs"], out_keys=["action"])
self.convert_to_functional(model, "model", expand_dim=4)
self._make_vmap()

def _make_vmap(self):
self.vmap_model = _vmap_func(
self.model,
(None, 0),
Expand Down
21 changes: 10 additions & 11 deletions torchrl/objectives/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from dataclasses import dataclass
from typing import Iterator, List, Optional, Tuple

import torch
from tensordict import is_tensor_collection, TensorDict, TensorDictBase

from tensordict.nn import TensorDictModule, TensorDictModuleBase, TensorDictParams
Expand Down Expand Up @@ -553,16 +552,16 @@ def vmap_randomness(self):
This property supports setting its value.
"""
if self._vmap_randomness is None:
do_break = False
for val in self.__dict__.values():
if isinstance(val, torch.nn.Module):
for module in val.modules():
if isinstance(module, RANDOM_MODULE_LIST):
self._vmap_randomness = "different"
do_break = True
break
if do_break:
# double break
main_modules = list(self.__dict__.values()) + list(self.children())
modules = (
module
for main_module in main_modules
if isinstance(main_module, nn.Module)
for module in main_module.modules()
)
for val in modules:
if isinstance(val, RANDOM_MODULE_LIST):
self._vmap_randomness = "different"
break
else:
self._vmap_randomness = "error"
Expand Down
2 changes: 1 addition & 1 deletion torchrl/objectives/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ def _make_vmap(self):
)
if self._version == 1:
self._vmap_qnetwork00 = _vmap_func(
qvalue_network, randomness=self.vmap_randomness
self.qvalue_network, randomness=self.vmap_randomness
)

@property
Expand Down

0 comments on commit 6da8255

Please sign in to comment.