From 19a920eed2ff055079dd03ccac3dbf32e11da2e8 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 5 Feb 2024 20:21:59 +0000 Subject: [PATCH] [BugFix] Fix update in serial / parallel env (#1866) --- test/mocking_classes.py | 19 +- test/test_collector.py | 18 +- test/test_env.py | 7 +- test/test_tensordictmodules.py | 35 ++- torchrl/collectors/collectors.py | 2 +- torchrl/data/replay_buffers/storages.py | 2 +- torchrl/data/rlhf/dataset.py | 2 +- torchrl/envs/batched_envs.py | 343 ++++++++++++++---------- torchrl/envs/common.py | 10 +- torchrl/envs/gym_like.py | 4 +- torchrl/envs/transforms/transforms.py | 10 +- torchrl/envs/utils.py | 2 +- torchrl/trainers/trainers.py | 2 +- 13 files changed, 278 insertions(+), 178 deletions(-) diff --git a/test/mocking_classes.py b/test/mocking_classes.py index 7a32c9a38ef..d68c7f30aa3 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -1072,7 +1072,7 @@ def _step( tensordict: TensorDictBase, ) -> TensorDictBase: action = tensordict.get(self.action_key) - self.count += action.to(torch.int).to(self.device) + self.count += action.to(dtype=torch.int, device=self.device) tensordict = TensorDict( source={ "observation": self.count.clone(), @@ -1426,10 +1426,12 @@ def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs): 3, ) ), + device=self.device, ) self.unbatched_action_spec = CompositeSpec( lazy=action_specs, + device=self.device, ) self.unbatched_reward_spec = CompositeSpec( { @@ -1441,7 +1443,8 @@ def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs): }, shape=(self.n_nested_dim,), ) - } + }, + device=self.device, ) self.unbatched_done_spec = CompositeSpec( { @@ -1455,7 +1458,8 @@ def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs): }, shape=(self.n_nested_dim,), ) - } + }, + device=self.device, ) self.action_spec = self.unbatched_action_spec.expand( @@ -1488,7 +1492,8 @@ def get_agent_obs_spec(self, i): "lidar": lidar, "vector": vector_3d, "tensor_0": tensor_0, - } + }, + device=self.device, ) elif i == 1: return CompositeSpec( @@ -1497,7 +1502,8 @@ def get_agent_obs_spec(self, i): "lidar": lidar, "vector": vector_2d, "tensor_1": tensor_1, - } + }, + device=self.device, ) elif i == 2: return CompositeSpec( @@ -1505,7 +1511,8 @@ def get_agent_obs_spec(self, i): "camera": camera, "vector": vector_2d, "tensor_2": tensor_2, - } + }, + device=self.device, ) else: raise ValueError(f"Index {i} undefined for index 3") diff --git a/test/test_collector.py b/test/test_collector.py index b5afe7f35d7..09c6ee293c3 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -1675,8 +1675,12 @@ def test_maxframes_error(): @pytest.mark.parametrize("policy_device", [None, *get_available_devices()]) @pytest.mark.parametrize("env_device", [None, *get_available_devices()]) @pytest.mark.parametrize("storing_device", [None, *get_available_devices()]) +@pytest.mark.parametrize("parallel", [False, True]) def test_reset_heterogeneous_envs( - policy_device: torch.device, env_device: torch.device, storing_device: torch.device + policy_device: torch.device, + env_device: torch.device, + storing_device: torch.device, + parallel, ): if ( policy_device is not None @@ -1686,9 +1690,13 @@ def test_reset_heterogeneous_envs( env_device = torch.device("cpu") # explicit mapping elif env_device is not None and env_device.type == "cuda" and policy_device is None: policy_device = torch.device("cpu") - env1 = lambda: TransformedEnv(CountingEnv(), StepCounter(2)) - env2 = lambda: TransformedEnv(CountingEnv(), StepCounter(3)) - env = SerialEnv(2, [env1, env2], device=env_device) + env1 = lambda: TransformedEnv(CountingEnv(device="cpu"), StepCounter(2)) + env2 = lambda: TransformedEnv(CountingEnv(device="cpu"), StepCounter(3)) + if parallel: + cls = ParallelEnv + else: + cls = SerialEnv + env = cls(2, [env1, env2], device=env_device) collector = SyncDataCollector( env, RandomPolicy(env.action_spec), @@ -1705,7 +1713,7 @@ def test_reset_heterogeneous_envs( assert ( data[0]["next", "truncated"].squeeze() == torch.tensor([False, True], device=data_device).repeat(25)[:50] - ).all(), data[0]["next", "truncated"][:10] + ).all(), data[0]["next", "truncated"] assert ( data[1]["next", "truncated"].squeeze() == torch.tensor([False, False, True], device=data_device).repeat(17)[:50] diff --git a/test/test_env.py b/test/test_env.py index 22918c390df..e316e1ae10f 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -2095,7 +2095,10 @@ def test_rollout_policy(self, batch_size, rollout_steps, count): @pytest.mark.parametrize("batch_size", [(1, 2)]) @pytest.mark.parametrize("env_type", ["serial", "parallel"]) - def test_vec_env(self, batch_size, env_type, rollout_steps=4, n_workers=2): + @pytest.mark.parametrize("break_when_any_done", [False, True]) + def test_vec_env( + self, batch_size, env_type, break_when_any_done, rollout_steps=4, n_workers=2 + ): env_fun = lambda: HeterogeneousCountingEnv(batch_size=batch_size) if env_type == "serial": vec_env = SerialEnv(n_workers, env_fun) @@ -2109,7 +2112,7 @@ def test_vec_env(self, batch_size, env_type, rollout_steps=4, n_workers=2): rollout_steps, policy=policy, return_contiguous=False, - break_when_any_done=False, + break_when_any_done=break_when_any_done, ) td = dense_stack_tds(td) for i in range(env_fun().n_nested_dim): diff --git a/test/test_tensordictmodules.py b/test/test_tensordictmodules.py index 83a283e4e56..c2df40be012 100644 --- a/test/test_tensordictmodules.py +++ b/test/test_tensordictmodules.py @@ -21,6 +21,7 @@ CompositeSpec, UnboundedContinuousTensorSpec, ) +from torchrl.envs import EnvCreator, SerialEnv from torchrl.envs.utils import set_exploration_type, step_mdp from torchrl.modules import ( AdditiveGaussianWrapper, @@ -1782,9 +1783,12 @@ def test_multi_consecutive(self, shape, python_based): ) @pytest.mark.parametrize("python_based", [True, False]) - def test_lstm_parallel_env(self, python_based): + @pytest.mark.parametrize("parallel", [True, False]) + @pytest.mark.parametrize("heterogeneous", [True, False]) + def test_lstm_parallel_env(self, python_based, parallel, heterogeneous): from torchrl.envs import InitTracker, ParallelEnv, TransformedEnv + torch.manual_seed(0) device = "cuda" if torch.cuda.device_count() else "cpu" # tests that hidden states are carried over with parallel envs lstm_module = LSTMModule( @@ -1796,6 +1800,10 @@ def test_lstm_parallel_env(self, python_based): device=device, python_based=python_based, ) + if parallel: + cls = ParallelEnv + else: + cls = SerialEnv def create_transformed_env(): primer = lstm_module.make_tensordict_primer() @@ -1807,7 +1815,12 @@ def create_transformed_env(): env.append_transform(primer) return env - env = ParallelEnv( + if heterogeneous: + create_transformed_env = [ + EnvCreator(create_transformed_env), + EnvCreator(create_transformed_env), + ] + env = cls( create_env_fn=create_transformed_env, num_workers=2, ) @@ -2109,9 +2122,13 @@ def test_multi_consecutive(self, shape, python_based): ) @pytest.mark.parametrize("python_based", [True, False]) - def test_gru_parallel_env(self, python_based): + @pytest.mark.parametrize("parallel", [True, False]) + @pytest.mark.parametrize("heterogeneous", [True, False]) + def test_gru_parallel_env(self, python_based, parallel, heterogeneous): from torchrl.envs import InitTracker, ParallelEnv, TransformedEnv + torch.manual_seed(0) + device = "cuda" if torch.cuda.device_count() else "cpu" # tests that hidden states are carried over with parallel envs gru_module = GRUModule( @@ -2134,7 +2151,17 @@ def create_transformed_env(): env.append_transform(primer) return env - env = ParallelEnv( + if parallel: + cls = ParallelEnv + else: + cls = SerialEnv + if heterogeneous: + create_transformed_env = [ + EnvCreator(create_transformed_env), + EnvCreator(create_transformed_env), + ] + + env = cls( create_env_fn=create_transformed_env, num_workers=2, ) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index eff2434d487..bea46bb6cd4 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -1077,7 +1077,7 @@ def rollout(self) -> TensorDictBase: if self.storing_device is not None: tensordicts.append( - self._shuttle.to(self.storing_device, non_blocking=False) + self._shuttle.to(self.storing_device, non_blocking=True) ) else: tensordicts.append(self._shuttle) diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index a1ad94c21fe..55c57a6a6b4 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -894,7 +894,7 @@ def get(self, index: Union[int, Sequence[int], slice]) -> Any: # to be deprecated in v0.4 def map_device(tensor): if tensor.device != self.device: - return tensor.to(self.device, non_blocking=False) + return tensor.to(self.device, non_blocking=True) return tensor if is_tensor_collection(result): diff --git a/torchrl/data/rlhf/dataset.py b/torchrl/data/rlhf/dataset.py index 19090d3f4c5..8f039b317fc 100644 --- a/torchrl/data/rlhf/dataset.py +++ b/torchrl/data/rlhf/dataset.py @@ -394,7 +394,7 @@ def get_dataloader( ) out = TensorDictReplayBuffer( storage=TensorStorage(data), - collate_fn=lambda x: x.as_tensor().to(device, non_blocking=False), + collate_fn=lambda x: x.as_tensor().to(device, non_blocking=True), sampler=SamplerWithoutReplacement(drop_last=True), batch_size=batch_size, prefetch=prefetch, diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 5e88cf4e86d..9669963cb33 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -8,6 +8,7 @@ import gc import os +import weakref from collections import OrderedDict from copy import deepcopy from functools import wraps @@ -19,7 +20,7 @@ import torch from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase -from tensordict._tensordict import _unravel_key_to_tuple, unravel_key +from tensordict._tensordict import unravel_key from torch import multiprocessing as mp from torchrl._utils import ( _check_for_faulty_process, @@ -40,7 +41,6 @@ from torchrl.envs.utils import ( _aggregate_end_of_traj, - _set_single_key, _sort_keys, _update_during_reset, clear_mpi_env_vars, @@ -419,7 +419,13 @@ def _check_for_empty_spec(specs: CompositeSpec): def map_device(key, value, device_map=device_map): return value.to(device_map[key]) - self._env_tensordict.named_apply(map_device, nested_keys=True) + # self._env_tensordict.named_apply( + # map_device, nested_keys=True, filter_empty=True + # ) + self._env_tensordict.named_apply( + map_device, + nested_keys=True, + ) self._batch_locked = meta_data.batch_locked else: @@ -535,22 +541,17 @@ def _create_td(self) -> None: self._selected_keys = self._selected_keys.union(reset_keys) # input keys - self._selected_input_keys = { - _unravel_key_to_tuple(key) for key in self._env_input_keys - } + self._selected_input_keys = {unravel_key(key) for key in self._env_input_keys} # output keys after reset self._selected_reset_keys = { - _unravel_key_to_tuple(key) - for key in self._env_obs_keys + self.done_keys + reset_keys + unravel_key(key) for key in self._env_obs_keys + self.done_keys + reset_keys } # output keys after reset, filtered self._selected_reset_keys_filt = { unravel_key(key) for key in self._env_obs_keys + self.done_keys } # output keys after step - self._selected_step_keys = { - _unravel_key_to_tuple(key) for key in self._env_output_keys - } + self._selected_step_keys = {unravel_key(key) for key in self._env_output_keys} if self._single_task: shared_tensordict_parent = shared_tensordict_parent.select( @@ -689,11 +690,27 @@ def _start_workers(self) -> None: _num_workers = self.num_workers self._envs = [] - + weakref_set = set() for idx in range(_num_workers): env = self.create_env_fn[idx](**self.create_env_kwargs[idx]) - if self.device is not None: - env = env.to(self.device) + # We want to avoid having the same env multiple times + # so we try to deepcopy it if needed. If we can't, we make + # the user aware that this isn't a very good idea + wr = weakref.ref(env) + if wr in weakref_set: + try: + env = deepcopy(env) + except Exception: + warn( + "Deepcopying the env failed within SerialEnv " + "but more than one copy of the same env was found. " + "This is a dangerous situation if your env keeps track " + "of some variables (e.g., state) in-place. " + "We'll use the same copy of the environment be beaware that " + "this may have important, unwanted issues for stateful " + "environments!" + ) + weakref_set.add(wr) self._envs.append(env) self.is_closed = False @@ -755,8 +772,8 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: tensordict_ = None else: env_device = _env.device - if env_device != self.device: - tensordict_ = tensordict_.to(env_device) + if env_device != self.device and env_device is not None: + tensordict_ = tensordict_.to(env_device, non_blocking=True) else: tensordict_ = tensordict_.clone(False) else: @@ -764,30 +781,33 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: _td = _env.reset(tensordict=tensordict_, **kwargs) self.shared_tensordicts[i].update_( - _td.select(*self._selected_reset_keys_filt, strict=False) + _td, + keys_to_update=list(self._selected_reset_keys_filt), ) selected_output_keys = self._selected_reset_keys_filt device = self.device - if self._single_task: - # select + clone creates 2 tds, but we can create one only - out = TensorDict( - {}, batch_size=self.shared_tensordict_parent.shape, device=device - ) - for key in selected_output_keys: - _set_single_key( - self.shared_tensordict_parent, out, key, clone=True, device=device - ) - else: - out = self.shared_tensordict_parent.select( - *selected_output_keys, - strict=False, - ) - if out.device == device: - out = out.clone() - elif device is None: - out = out.clone().clear_device_() + + # select + clone creates 2 tds, but we can create one only + def select_and_clone(name, tensor): + if name in selected_output_keys: + return tensor.clone() + + # out = self.shared_tensordict_parent.named_apply( + # select_and_clone, + # nested_keys=True, + # filter_empty=True, + # ) + out = self.shared_tensordict_parent.named_apply( + select_and_clone, + nested_keys=True, + ) + del out["next"] + + if out.device != device: + if device is None: + out = out.clear_device_() else: - out = out.to(device, non_blocking=False) + out = out.to(device, non_blocking=True) return out def _reset_proc_data(self, tensordict, tensordict_reset): @@ -807,30 +827,29 @@ def _step( # shared_tensordicts are locked, and we need to select the keys since we update in-place. # There may be unexpected keys, such as "_reset", that we should comfortably ignore here. env_device = self._envs[i].device - if env_device != self.device: - data_in = tensordict_in[i].to(env_device, non_blocking=False) + if env_device != self.device and env_device is not None: + data_in = tensordict_in[i].to(env_device, non_blocking=True) else: data_in = tensordict_in[i] out_td = self._envs[i]._step(data_in) - next_td[i].update_(out_td.select(*self._env_output_keys, strict=False)) + next_td[i].update_(out_td, keys_to_update=list(self._env_output_keys)) + # We must pass a clone of the tensordict, as the values of this tensordict # will be modified in-place at further steps device = self.device - if self._single_task: - out = TensorDict( - {}, batch_size=self.shared_tensordict_parent.shape, device=device - ) - for key in self._selected_step_keys: - _set_single_key(next_td, out, key, clone=True, device=device) - else: - # strict=False ensures that non-homogeneous keys are still there - out = next_td.select(*self._selected_step_keys, strict=False) - if out.device == device: - out = out.clone() - elif device is None: - out = out.clone().clear_device_() - else: - out = out.to(device, non_blocking=False) + + def select_and_clone(name, tensor): + if name in self._selected_step_keys: + return tensor.clone() + + # out = next_td.named_apply(select_and_clone, nested_keys=True, filter_empty=True) + out = next_td.named_apply(select_and_clone, nested_keys=True) + + if out.device != device: + if device is None: + out = out.clear_device_() + elif out.device != device: + out = out.to(device, non_blocking=True) return out def __getattr__(self, attr: str) -> Any: @@ -1040,6 +1059,7 @@ def _start_workers(self) -> None: def look_for_cuda(tensor, has_cuda=has_cuda): has_cuda[0] = has_cuda[0] or tensor.is_cuda + # self.shared_tensordict_parent.apply(look_for_cuda, filter_empty=True) self.shared_tensordict_parent.apply(look_for_cuda) has_cuda = has_cuda[0] if has_cuda: @@ -1119,32 +1139,29 @@ def load_state_dict(self, state_dict: OrderedDict) -> None: def step_and_maybe_reset( self, tensordict: TensorDictBase ) -> Tuple[TensorDictBase, TensorDictBase]: - if self._single_task and not self.has_lazy_inputs: - # We must use the in_keys and nothing else for the following reasons: - # - efficiency: copying all the keys will in practice mean doing a lot - # of writing operations since the input tensordict may (and often will) - # contain all the previous output data. - # - value mismatch: if the batched env is placed within a transform - # and this transform overrides an observation key (eg, CatFrames) - # the shape, dtype or device may not necessarily match and writing - # the value in-place will fail. - for key in self._env_input_keys: - self.shared_tensordict_parent.set_(key, tensordict.get(key)) - next_td = tensordict.get("next", None) - if next_td is not None: - # we copy the input keys as well as the keys in the 'next' td, if any - # as this mechanism can be used by a policy to set anticipatively the - # keys of the next call (eg, with recurrent nets) - for key in next_td.keys(True, True): - key = unravel_key(("next", key)) - if key in self.shared_tensordict_parent.keys(True, True): - self.shared_tensordict_parent.set_(key, next_td.get(key[1:])) + # We must use the in_keys and nothing else for the following reasons: + # - efficiency: copying all the keys will in practice mean doing a lot + # of writing operations since the input tensordict may (and often will) + # contain all the previous output data. + # - value mismatch: if the batched env is placed within a transform + # and this transform overrides an observation key (eg, CatFrames) + # the shape, dtype or device may not necessarily match and writing + # the value in-place will fail. + self.shared_tensordict_parent.update_( + tensordict, keys_to_update=self._env_input_keys + ) + next_td_passthrough = tensordict.get("next", None) + if next_td_passthrough is not None: + # if we have input "next" data (eg, RNNs which pass the next state) + # the sub-envs will need to process them through step_and_maybe_reset. + # We keep track of which keys are present to let the worker know what + # should be passd to the env (we don't want to pass done states for instance) + next_td_keys = list(next_td_passthrough.keys(True, True)) + self.shared_tensordict_parent.get("next").update_(next_td_passthrough) else: - self.shared_tensordict_parent.update_( - tensordict.select(*self._env_input_keys, "next", strict=False) - ) + next_td_keys = None for i in range(self.num_workers): - self.parent_channels[i].send(("step_and_maybe_reset", None)) + self.parent_channels[i].send(("step_and_maybe_reset", next_td_keys)) for i in range(self.num_workers): event = self._events[i] @@ -1160,8 +1177,20 @@ def step_and_maybe_reset( next_td = next_td.clone() tensordict_ = tensordict_.clone() elif device is not None: - next_td = next_td.to(device, non_blocking=False) - tensordict_ = tensordict_.to(device, non_blocking=False) + next_td = next_td._fast_apply( + lambda x: x.to(device, non_blocking=True) + if x.device != device + else x.clone(), + device=device, + # filter_empty=True, + ) + tensordict_ = tensordict_._fast_apply( + lambda x: x.to(device, non_blocking=True) + if x.device != device + else x.clone(), + device=device, + # filter_empty=True, + ) else: next_td = next_td.clone().clear_device_() tensordict_ = tensordict_.clone().clear_device_() @@ -1170,35 +1199,33 @@ def step_and_maybe_reset( @_check_start def _step(self, tensordict: TensorDictBase) -> TensorDictBase: - if self._single_task and not self.has_lazy_inputs: - # We must use the in_keys and nothing else for the following reasons: - # - efficiency: copying all the keys will in practice mean doing a lot - # of writing operations since the input tensordict may (and often will) - # contain all the previous output data. - # - value mismatch: if the batched env is placed within a transform - # and this transform overrides an observation key (eg, CatFrames) - # the shape, dtype or device may not necessarily match and writing - # the value in-place will fail. - for key in tensordict.keys(True, True): - # we copy the input keys as well as the keys in the 'next' td, if any - # as this mechanism can be used by a policy to set anticipatively the - # keys of the next call (eg, with recurrent nets) - if key in self._env_input_keys or ( - isinstance(key, tuple) - and key[0] == "next" - and key in self.shared_tensordict_parent.keys(True, True) - ): - val = tensordict.get(key) - self.shared_tensordict_parent.set_(key, val) + # We must use the in_keys and nothing else for the following reasons: + # - efficiency: copying all the keys will in practice mean doing a lot + # of writing operations since the input tensordict may (and often will) + # contain all the previous output data. + # - value mismatch: if the batched env is placed within a transform + # and this transform overrides an observation key (eg, CatFrames) + # the shape, dtype or device may not necessarily match and writing + # the value in-place will fail. + self.shared_tensordict_parent.update_( + tensordict, keys_to_update=list(self._env_input_keys) + ) + next_td_passthrough = tensordict.get("next", None) + if next_td_passthrough is not None: + # if we have input "next" data (eg, RNNs which pass the next state) + # the sub-envs will need to process them through step_and_maybe_reset. + # We keep track of which keys are present to let the worker know what + # should be passd to the env (we don't want to pass done states for instance) + next_td_keys = list(next_td_passthrough.keys(True, True)) + self.shared_tensordict_parent.get("next").update_(next_td_passthrough) else: - self.shared_tensordict_parent.update_( - tensordict.select(*self._env_input_keys, "next", strict=False) - ) + next_td_keys = None + if self.event is not None: self.event.record() self.event.synchronize() for i in range(self.num_workers): - self.parent_channels[i].send(("step", None)) + self.parent_channels[i].send(("step", next_td_keys)) for i in range(self.num_workers): event = self._events[i] @@ -1209,19 +1236,21 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: # will be modified in-place at further steps next_td = self.shared_tensordict_parent.get("next") device = self.device - if self._single_task: - out = TensorDict( - {}, batch_size=self.shared_tensordict_parent.shape, device=device - ) - for key in self._selected_step_keys: - _set_single_key(next_td, out, key, clone=True, device=device) - else: - # strict=False ensures that non-homogeneous keys are still there - out = next_td.select(*self._selected_step_keys, strict=False) - if out.device == device: - out = out.clone() + + def select_and_clone(name, tensor): + if name in self._selected_step_keys: + return tensor.clone() + + out = next_td.named_apply( + select_and_clone, + nested_keys=True, + # filter_empty=True, + ) + if out.device != device: + if device is None: + out.clear_device_() else: - out = out.to(device, non_blocking=False) + out = out.to(device, non_blocking=True) return out @_check_start @@ -1258,13 +1287,12 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: # step at the root (since the shared_tensordict did not go through # step_mdp). self.shared_tensordicts[i].update_( - self.shared_tensordicts[i] - .get("next") - .select(*self._selected_reset_keys, strict=False) + self.shared_tensordicts[i].get("next"), + keys_to_update=list(self._selected_reset_keys), ) if tensordict_ is not None: self.shared_tensordicts[i].update_( - tensordict_.select(*self._selected_reset_keys, strict=False) + tensordict_, keys_to_update=list(self._selected_reset_keys) ) continue out = ("reset", tensordict_) @@ -1278,26 +1306,23 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: selected_output_keys = self._selected_reset_keys_filt device = self.device - if self._single_task: - # select + clone creates 2 tds, but we can create one only - out = TensorDict( - {}, batch_size=self.shared_tensordict_parent.shape, device=device - ) - for key in selected_output_keys: - _set_single_key( - self.shared_tensordict_parent, out, key, clone=True, device=device - ) - else: - out = self.shared_tensordict_parent.select( - *selected_output_keys, - strict=False, - ) - if out.device == device: - out = out.clone() - elif device is None: - out = out.clear_device_().clone() + + def select_and_clone(name, tensor): + if name in selected_output_keys: + return tensor.clone() + + out = self.shared_tensordict_parent.named_apply( + select_and_clone, + nested_keys=True, + # filter_empty=True, + ) + del out["next"] + + if out.device != device: + if device is None: + out.clear_device_() else: - out = out.to(device, non_blocking=False) + out = out.to(device, non_blocking=True) return out @_check_start @@ -1427,6 +1452,7 @@ def _run_worker_pipe_shared_mem( def look_for_cuda(tensor, has_cuda=has_cuda): has_cuda[0] = has_cuda[0] or tensor.is_cuda + # shared_tensordict.apply(look_for_cuda, filter_empty=True) shared_tensordict.apply(look_for_cuda) has_cuda = has_cuda[0] else: @@ -1498,7 +1524,8 @@ def look_for_cuda(tensor, has_cuda=has_cuda): raise RuntimeError("call 'init' before resetting") cur_td = env.reset(tensordict=data) shared_tensordict.update_( - cur_td.select(*_selected_reset_keys, strict=False) + cur_td, + keys_to_update=list(_selected_reset_keys), ) if event is not None: event.record() @@ -1510,7 +1537,15 @@ def look_for_cuda(tensor, has_cuda=has_cuda): if not initialized: raise RuntimeError("called 'init' before step") i += 1 - next_td = env._step(shared_tensordict) + # No need to copy here since we don't write in-place + if data: + next_td_passthrough_keys = data + input = root_shared_tensordict.set( + "next", next_shared_tensordict.select(*next_td_passthrough_keys) + ) + else: + input = root_shared_tensordict + next_td = env._step(input) next_shared_tensordict.update_(next_td) if event is not None: event.record() @@ -1522,9 +1557,25 @@ def look_for_cuda(tensor, has_cuda=has_cuda): if not initialized: raise RuntimeError("called 'init' before step") i += 1 - td, root_next_td = env.step_and_maybe_reset(shared_tensordict) - next_shared_tensordict.update_(td.get("next")) + # We must copy the root shared td here, or at least get rid of done: + # if we don't `td is root_shared_tensordict` + # which means that root_shared_tensordict will carry the content of next + # in the next iteration. When using StepCounter, it will look for an + # existing done state, find it and consider the env as done by input (not + # by output) of the step! + # Caveat: for RNN we may need some keys of the "next" TD so we pass the list + # through data + if data: + next_td_passthrough_keys = data + input = root_shared_tensordict.set( + "next", next_shared_tensordict.select(*next_td_passthrough_keys) + ) + else: + input = root_shared_tensordict + td, root_next_td = env.step_and_maybe_reset(input) + next_shared_tensordict.update_(td.pop("next")) root_shared_tensordict.update_(root_next_td) + if event is not None: event.record() event.synchronize() @@ -1588,5 +1639,9 @@ def look_for_cuda(tensor, has_cuda=has_cuda): def _update_cuda(t_dest, t_source): if t_source is None: return - t_dest.copy_(t_source.pin_memory(), non_blocking=False) + t_dest.copy_(t_source.pin_memory(), non_blocking=True) return + + +def _filter_empty(tensordict): + return tensordict.select(*tensordict.keys(True, True)) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index b2b201922e1..61cd211b6ae 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -2056,7 +2056,7 @@ def reset( tensordict_reset = self._reset(tensordict, **kwargs) # We assume that this is done properly # if tensordict_reset.device != self.device: - # tensordict_reset = tensordict_reset.to(self.device, non_blocking=False) + # tensordict_reset = tensordict_reset.to(self.device, non_blocking=True) if tensordict_reset is tensordict: raise RuntimeError( "EnvBase._reset should return outplace changes to the input " @@ -2418,13 +2418,13 @@ def _rollout_stop_early( for i in range(max_steps): if auto_cast_to_device: if policy_device is not None: - tensordict = tensordict.to(policy_device, non_blocking=False) + tensordict = tensordict.to(policy_device, non_blocking=True) else: tensordict.clear_device_() tensordict = policy(tensordict) if auto_cast_to_device: if env_device is not None: - tensordict = tensordict.to(env_device, non_blocking=False) + tensordict = tensordict.to(env_device, non_blocking=True) else: tensordict.clear_device_() tensordict = self.step(tensordict) @@ -2472,13 +2472,13 @@ def _rollout_nonstop( for i in range(max_steps): if auto_cast_to_device: if policy_device is not None: - tensordict_ = tensordict_.to(policy_device, non_blocking=False) + tensordict_ = tensordict_.to(policy_device, non_blocking=True) else: tensordict_.clear_device_() tensordict_ = policy(tensordict_) if auto_cast_to_device: if env_device is not None: - tensordict_ = tensordict_.to(env_device, non_blocking=False) + tensordict_ = tensordict_.to(env_device, non_blocking=True) else: tensordict_.clear_device_() tensordict, tensordict_ = self.step_and_maybe_reset(tensordict_) diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index 38995a07a6b..d3b3dfd659c 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -322,7 +322,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict_out = TensorDict( obs_dict, batch_size=tensordict.batch_size, _run_checks=False ) - tensordict_out = tensordict_out.to(self.device, non_blocking=False) + tensordict_out = tensordict_out.to(self.device, non_blocking=True) if self.info_dict_reader and (info_dict is not None): if not isinstance(info_dict, dict): @@ -366,7 +366,7 @@ def _reset( for key, item in self.observation_spec.items(True, True): if key not in tensordict_out.keys(True, True): tensordict_out[key] = item.zero() - tensordict_out = tensordict_out.to(self.device, non_blocking=False) + tensordict_out = tensordict_out.to(self.device, non_blocking=True) return tensordict_out @abc.abstractmethod diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index a661b152d39..efa59e25c26 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -3612,10 +3612,10 @@ def set_container(self, container: Union[Transform, EnvBase]) -> None: @dispatch(source="in_keys", dest="out_keys") def forward(self, tensordict: TensorDictBase) -> TensorDictBase: - return tensordict.to(self.device, non_blocking=False) + return tensordict.to(self.device, non_blocking=True) def _call(self, tensordict: TensorDictBase) -> TensorDictBase: - return tensordict.to(self.device, non_blocking=False) + return tensordict.to(self.device, non_blocking=True) def _reset( self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase @@ -3628,8 +3628,8 @@ def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: if parent is None: if self.orig_device is None: return tensordict - return tensordict.to(self.orig_device, non_blocking=False) - return tensordict.to(parent.device, non_blocking=False) + return tensordict.to(self.orig_device, non_blocking=True) + return tensordict.to(parent.device, non_blocking=True) def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec: return input_spec.to(self.device) @@ -5146,7 +5146,7 @@ def _reset( if step_count is None: step_count = self.container.observation_spec[step_count_key].zero() if step_count.device != reset.device: - step_count = step_count.to(reset.device, non_blocking=False) + step_count = step_count.to(reset.device, non_blocking=True) # zero the step count if reset is needed step_count = torch.where(~expand_as_right(reset, step_count), step_count, 0) diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index ebb9100655c..46c923ccfec 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -268,7 +268,7 @@ def _set_single_key( dest = new_val else: if device is not None and val.device != device: - val = val.to(device, non_blocking=False) + val = val.to(device, non_blocking=True) elif clone: val = val.clone() dest._set_str(k, val, inplace=False, validated=True) diff --git a/torchrl/trainers/trainers.py b/torchrl/trainers/trainers.py index f844613432c..03a7be37573 100644 --- a/torchrl/trainers/trainers.py +++ b/torchrl/trainers/trainers.py @@ -708,7 +708,7 @@ def extend(self, batch: TensorDictBase) -> TensorDictBase: def sample(self, batch: TensorDictBase) -> TensorDictBase: sample = self.replay_buffer.sample(batch_size=self.batch_size) - return sample.to(self.device, non_blocking=False) + return sample.to(self.device, non_blocking=True) def update_priority(self, batch: TensorDictBase) -> None: self.replay_buffer.update_tensordict_priority(batch)