From 6da8255a581537e14b6c1cc8015e2512f6788011 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 24 Jul 2024 08:53:04 +0100 Subject: [PATCH] Update (base update) [ghstack-poisoned] --- test/test_cost.py | 3 +++ torchrl/objectives/common.py | 21 ++++++++++----------- torchrl/objectives/sac.py | 2 +- 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index 9ac2bb6b950..090b32ac8e5 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -260,6 +260,9 @@ def __init__(self): net = nn.Sequential(*layers).to(device) model = TensorDictModule(net, in_keys=["obs"], out_keys=["action"]) self.convert_to_functional(model, "model", expand_dim=4) + self._make_vmap() + + def _make_vmap(self): self.vmap_model = _vmap_func( self.model, (None, 0), diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py index 998eeaedc15..5036d35c96f 100644 --- a/torchrl/objectives/common.py +++ b/torchrl/objectives/common.py @@ -12,7 +12,6 @@ from dataclasses import dataclass from typing import Iterator, List, Optional, Tuple -import torch from tensordict import is_tensor_collection, TensorDict, TensorDictBase from tensordict.nn import TensorDictModule, TensorDictModuleBase, TensorDictParams @@ -553,16 +552,16 @@ def vmap_randomness(self): This property supports setting its value. """ if self._vmap_randomness is None: - do_break = False - for val in self.__dict__.values(): - if isinstance(val, torch.nn.Module): - for module in val.modules(): - if isinstance(module, RANDOM_MODULE_LIST): - self._vmap_randomness = "different" - do_break = True - break - if do_break: - # double break + main_modules = list(self.__dict__.values()) + list(self.children()) + modules = ( + module + for main_module in main_modules + if isinstance(main_module, nn.Module) + for module in main_module.modules() + ) + for val in modules: + if isinstance(val, RANDOM_MODULE_LIST): + self._vmap_randomness = "different" break else: self._vmap_randomness = "error" diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index d03014da7bd..65482a2b876 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -412,7 +412,7 @@ def _make_vmap(self): ) if self._version == 1: self._vmap_qnetwork00 = _vmap_func( - qvalue_network, randomness=self.vmap_randomness + self.qvalue_network, randomness=self.vmap_randomness ) @property