diff --git a/test/test_cost.py b/test/test_cost.py
index 669c500facb..94cb1930eca 100644
--- a/test/test_cost.py
+++ b/test/test_cost.py
@@ -6191,7 +6191,7 @@ def zero_param(p):
             if isinstance(p, nn.Parameter):
                 p.data.zero_()
 
-        params.apply(zero_param)
+        params.apply(zero_param, filter_empty=True)
 
         # assert len(list(floss_fn.parameters())) == 0
         with params.to_module(loss_fn):
diff --git a/test/test_env.py b/test/test_env.py
index 802515e7850..e6ccbdcb5f1 100644
--- a/test/test_env.py
+++ b/test/test_env.py
@@ -446,22 +446,19 @@ def test_parallel_devices(
             env.shared_tensordict_parent.device.type == torch.device(edevice).type
         )
 
-    @pytest.mark.parametrize("start_method", [None, "fork"])
-    def test_serial_for_single(self, maybe_fork_ParallelEnv, start_method):
+    def test_serial_for_single(self, maybe_fork_ParallelEnv):
         env = ParallelEnv(
             1,
             ContinuousActionVecMockEnv,
             serial_for_single=True,
-            mp_start_method=start_method,
         )
         assert isinstance(env, SerialEnv)
-        env = ParallelEnv(1, ContinuousActionVecMockEnv, mp_start_method=start_method)
+        env = ParallelEnv(1, ContinuousActionVecMockEnv)
         assert isinstance(env, ParallelEnv)
         env = ParallelEnv(
             2,
             ContinuousActionVecMockEnv,
             serial_for_single=True,
-            mp_start_method=start_method,
         )
         assert isinstance(env, ParallelEnv)
 
diff --git a/test/test_transforms.py b/test/test_transforms.py
index 27f696a1dfc..029f717987b 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -2417,7 +2417,7 @@ def test_transform_rb(self, include_forward, rbclass):
         )
         rb.extend(td)
 
-        storage = rb._storage._storage[:]
+        storage = rb._storage[:]
 
         assert storage["action"].shape[-1] == 7
         td = rb.sample(10)
@@ -8734,7 +8734,7 @@ def test_transform_rb(self, create_copy, inverse, rbclass):
         rb.append_transform(t)
         rb.extend(tensordict)
 
-        assert "a" in rb._storage._storage.keys()
+        assert "a" in rb._storage[:].keys()
         sample = rb.sample(2)
         if create_copy:
             assert "a" in sample.keys()
diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py
index b29b701738c..d22dd95a247 100644
--- a/torchrl/envs/batched_envs.py
+++ b/torchrl/envs/batched_envs.py
@@ -917,7 +917,6 @@ def select_and_clone(name, tensor):
             nested_keys=True,
             filter_empty=True,
         )
-        del out["next"]
 
         if out.device != device:
             if device is None:
diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py
index d8675585abe..43ddc56ec07 100644
--- a/torchrl/envs/common.py
+++ b/torchrl/envs/common.py
@@ -108,7 +108,7 @@ def metadata_from_env(env) -> EnvMetaData:
     def fill_device_map(name, val, device_map=device_map):
         device_map[name] = val.device
 
-    tensordict.named_apply(fill_device_map, nested_keys=True)
+    tensordict.named_apply(fill_device_map, nested_keys=True, filter_empty=True)
     return EnvMetaData(
         tensordict, specs, batch_size, env_str, device, batch_locked, device_map
     )
@@ -2843,7 +2843,7 @@ def _sync_device(self):
         sync_func = self.__dict__.get("_sync_device_val", None)
         if sync_func is None:
             device = self.device
-            if device.type != "cuda":
+            if device is not None and device.type != "cuda":
                 if torch.cuda.is_available():
                     self._sync_device_val = torch.cuda.synchronize
                 elif torch.backends.mps.is_available():
diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py
index 2f543c2748f..3a334ba558f 100644
--- a/torchrl/envs/gym_like.py
+++ b/torchrl/envs/gym_like.py
@@ -322,8 +322,9 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
         tensordict_out = TensorDict(
             obs_dict, batch_size=tensordict.batch_size, _run_checks=False
         )
-        tensordict_out = tensordict_out.to(self.device, non_blocking=True)
-        self._sync_device()
+        if self.device is not None:
+            tensordict_out = tensordict_out.to(self.device, non_blocking=True)
+            self._sync_device()
 
         if self.info_dict_reader and (info_dict is not None):
             if not isinstance(info_dict, dict):
diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index 3c908f9166d..30d306e478e 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -3684,7 +3684,7 @@ def __init__(
             if torch.cuda.is_available():
                 self._sync_device = torch.cuda.synchronize
             elif torch.backends.mps.is_available():
-                self._sync_device = torch.cuda.synchronize
+                self._sync_device = torch.mps.synchronize
             elif device.type == "cpu":
                 self._sync_device = _do_nothing
             else:
@@ -3739,7 +3739,7 @@ def _sync_orig_device(self):
         if torch.cuda.is_available():
             self._sync_orig_device_val = torch.cuda.synchronize
         elif torch.backends.mps.is_available():
-            self._sync_orig_device_val = torch.cuda.synchronize
+            self._sync_orig_device_val = torch.mps.synchronize
         elif device.type == "cpu":
             self._sync_orig_device_val = _do_nothing
         else:
diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py
index 0499e110398..726ece08202 100644
--- a/torchrl/objectives/common.py
+++ b/torchrl/objectives/common.py
@@ -253,6 +253,7 @@ def _compare_and_expand(param):
                 _compare_and_expand,
                 batch_size=[expand_dim, *param.shape],
                 call_on_nested=True,
+                filter_empty=True,
             )
             if not isinstance(param, nn.Parameter):
                 buffer = param.expand(expand_dim, *param.shape).clone()
@@ -276,6 +277,7 @@ def _compare_and_expand(param):
                     _compare_and_expand,
                     batch_size=[expand_dim, *params.shape],
                     call_on_nested=True,
+                    filter_empty=True,
                 ),
                 no_convert=True,
             )
@@ -296,7 +298,9 @@ def _compare_and_expand(param):
         # set the functional module: we need to convert the params to non-differentiable params
        # otherwise they will appear twice in parameters
         with params.apply(
-            self._make_meta_params, device=torch.device("meta")
+            self._make_meta_params,
+            device=torch.device("meta"),
+            filter_empty=False,
         ).to_module(module):
             # avoid buffers and params being exposed
             self.__dict__[module_name] = deepcopy(module)
@@ -306,7 +310,9 @@ def _compare_and_expand(param):
         # if create_target_params:
         # we create a TensorDictParams to keep the target params as Buffer instances
         target_params = TensorDictParams(
-            params.apply(_make_target_param(clone=create_target_params)),
+            params.apply(
+                _make_target_param(clone=create_target_params),
+                filter_empty=True,
+            ),
             no_convert=True,
         )
         setattr(self, name_params_target + "_params", target_params)
diff --git a/torchrl/objectives/ddpg.py b/torchrl/objectives/ddpg.py
index b8ca8dfd9ed..f2dcb4f60c9 100644
--- a/torchrl/objectives/ddpg.py
+++ b/torchrl/objectives/ddpg.py
@@ -197,7 +197,9 @@ def __init__(
         actor_critic = ActorCriticWrapper(actor_network, value_network)
         params = TensorDict.from_module(actor_critic)
-        params_meta = params.apply(self._make_meta_params, device=torch.device("meta"))
+        params_meta = params.apply(
+            self._make_meta_params, device=torch.device("meta"), filter_empty=True
+        )
         with params_meta.to_module(actor_critic):
             self.__dict__["actor_critic"] = deepcopy(actor_critic)
diff --git a/torchrl/objectives/multiagent/qmixer.py b/torchrl/objectives/multiagent/qmixer.py
index f7b9307a962..02b142736c5 100644
--- a/torchrl/objectives/multiagent/qmixer.py
+++ b/torchrl/objectives/multiagent/qmixer.py
@@ -224,7 +224,9 @@ def __init__(
         global_value_network = SafeSequential(local_value_network, mixer_network)
         params = TensorDict.from_module(global_value_network)
         with params.apply(
-            self._make_meta_params, device=torch.device("meta")
+            self._make_meta_params,
+            device=torch.device("meta"),
+            filter_empty=True,
         ).to_module(global_value_network):
             self.__dict__["global_value_network"] = deepcopy(global_value_network)
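
---

Note for reviewers (illustration only, not part of the patch): the common thread above is that every `TensorDict.apply` / `named_apply` call now pins the `filter_empty` keyword explicitly instead of relying on the default. A minimal sketch of the distinction, assuming a tensordict version whose `apply` accepts `filter_empty`; the tensordict below is made up:

```python
import torch
from tensordict import TensorDict

# A tensordict with one leaf and one empty sub-tensordict.
td = TensorDict(
    {"a": torch.zeros(3), "sub": TensorDict({}, batch_size=[])},
    batch_size=[],
)

# filter_empty=True drops sub-tensordicts that come out empty. This suits
# side-effect-only calls such as fill_device_map above, where the function
# returns None for every leaf and no output structure is needed.
out = td.apply(lambda t: t, filter_empty=True)
assert "sub" not in out.keys()

# filter_empty=False preserves the full nesting, which matters when the
# result must mirror a module's structure, e.g. before .to_module(module)
# in LossModule.convert_to_functional.
out = td.apply(lambda t: t, filter_empty=False)
assert "sub" in out.keys()
```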