From bd5c17e3480ca0e4ca8240f15faf1bafe9c7b3c1 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 14 Jun 2024 17:30:00 +0100 Subject: [PATCH 1/3] init --- torchrl/envs/batched_envs.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index ee87df39ac1..6c53b5ebabe 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1392,7 +1392,7 @@ def _step_and_maybe_reset_no_buffers( self, tensordict: TensorDictBase ) -> Tuple[TensorDictBase, TensorDictBase]: - for i, _data in enumerate(tensordict.unbind(0)): + for i, _data in enumerate(tensordict.consolidate().unbind(0)): self.parent_channels[i].send(("step_and_maybe_reset", _data)) results = [None] * self.num_workers @@ -1489,7 +1489,7 @@ def step_and_maybe_reset( def _step_no_buffers( self, tensordict: TensorDictBase ) -> Tuple[TensorDictBase, TensorDictBase]: - for i, data in enumerate(tensordict.unbind(0)): + for i, data in enumerate(tensordict.consolidate().unbind(0)): self.parent_channels[i].send(("step", data)) out_tds = [] for i, channel in enumerate(self.parent_channels): @@ -1576,7 +1576,7 @@ def _reset_no_buffers( needs_resetting, ) -> Tuple[TensorDictBase, TensorDictBase]: tdunbound = ( - tensordict.unbind(0) + tensordict.consolidate().unbind(0) if is_tensor_collection(tensordict) else [None] * self.num_workers ) @@ -1895,10 +1895,11 @@ def look_for_cuda(tensor, has_cuda=has_cuda): i = 0 next_shared_tensordict = shared_tensordict.get("next") root_shared_tensordict = shared_tensordict.exclude("next") - if not (shared_tensordict.is_shared() or shared_tensordict.is_memmap()): - raise RuntimeError( - "tensordict must be placed in shared memory (share_memory_() or memmap_())" - ) + # TODO: restore this + # if not (shared_tensordict.is_shared() or shared_tensordict.is_memmap()): + # raise RuntimeError( + # "tensordict must be placed in shared memory (share_memory_() or memmap_())" + # ) shared_tensordict = shared_tensordict.clone(False).unlock_() initialized = True @@ -2130,7 +2131,7 @@ def _run_worker_pipe_direct( event.record() event.synchronize() mp_event.set() - child_pipe.send(cur_td) + child_pipe.send(cur_td.consolidate()) del cur_td elif cmd == "step": @@ -2142,7 +2143,7 @@ def _run_worker_pipe_direct( event.record() event.synchronize() mp_event.set() - child_pipe.send(next_td) + child_pipe.send(next_td.consolidate()) del next_td elif cmd == "step_and_maybe_reset": From 38943aa82825a0330d722ce757457dc223a8ef5f Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 21 Jun 2024 14:57:14 +0100 Subject: [PATCH 2/3] amend --- torchrl/envs/batched_envs.py | 29 +++++++++++++++++++++++------ 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 6c53b5ebabe..c4654caa20d 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1392,8 +1392,15 @@ def _step_and_maybe_reset_no_buffers( self, tensordict: TensorDictBase ) -> Tuple[TensorDictBase, TensorDictBase]: - for i, _data in enumerate(tensordict.consolidate().unbind(0)): - self.parent_channels[i].send(("step_and_maybe_reset", _data)) + td = tensordict.consolidate( + share_memory=True, inplace=True, num_threads=1 + ) + for i in range(td.shape[0]): + # We send the same td multiple times as it is in shared mem and we just need to index it + # in each process. 
+ # If we don't do this, we need to unbind it but then the custom pickler will require + # some extra metadata to be collected. + self.parent_channels[i].send(("step_and_maybe_reset", (td, i))) results = [None] * self.num_workers @@ -1489,7 +1496,11 @@ def step_and_maybe_reset( def _step_no_buffers( self, tensordict: TensorDictBase ) -> Tuple[TensorDictBase, TensorDictBase]: - for i, data in enumerate(tensordict.consolidate().unbind(0)): + for i, data in enumerate( + tensordict.consolidate( + share_memory=True, inplace=True, num_threads=1 + ).unbind(0) + ): self.parent_channels[i].send(("step", data)) out_tds = [] for i, channel in enumerate(self.parent_channels): @@ -1576,7 +1587,7 @@ def _reset_no_buffers( needs_resetting, ) -> Tuple[TensorDictBase, TensorDictBase]: tdunbound = ( - tensordict.consolidate().unbind(0) + tensordict.consolidate(share_memory=True, num_threads=1).unbind(0) if is_tensor_collection(tensordict) else [None] * self.num_workers ) @@ -2131,7 +2142,9 @@ def _run_worker_pipe_direct( event.record() event.synchronize() mp_event.set() - child_pipe.send(cur_td.consolidate()) + child_pipe.send( + cur_td.consolidate(share_memory=True, inplace=True, num_threads=1) + ) del cur_td elif cmd == "step": @@ -2143,13 +2156,17 @@ def _run_worker_pipe_direct( event.record() event.synchronize() mp_event.set() - child_pipe.send(next_td.consolidate()) + child_pipe.send( + next_td.consolidate(share_memory=True, inplace=True, num_threads=1) + ) del next_td elif cmd == "step_and_maybe_reset": if not initialized: raise RuntimeError("called 'init' before step") i += 1 + data, idx = data + data = data[idx] data._fast_apply( lambda x: x.clone() if x.device.type == "cuda" else x, out=data ) From 5fffe1e4a0c22a65ffa74e4b2e4d03afe24ea6ab Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 24 Jun 2024 13:19:54 +0100 Subject: [PATCH 3/3] amend --- torchrl/envs/batched_envs.py | 46 +++++++++++++++++++++--------------- 1 file changed, 27 insertions(+), 19 deletions(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index c4654caa20d..e7636245221 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1392,9 +1392,7 @@ def _step_and_maybe_reset_no_buffers( self, tensordict: TensorDictBase ) -> Tuple[TensorDictBase, TensorDictBase]: - td = tensordict.consolidate( - share_memory=True, inplace=True, num_threads=1 - ) + td = tensordict.consolidate(share_memory=True, inplace=True, num_threads=1) for i in range(td.shape[0]): # We send the same td multiple times as it is in shared mem and we just need to index it # in each process. 
@@ -1496,12 +1494,11 @@ def step_and_maybe_reset( def _step_no_buffers( self, tensordict: TensorDictBase ) -> Tuple[TensorDictBase, TensorDictBase]: - for i, data in enumerate( - tensordict.consolidate( - share_memory=True, inplace=True, num_threads=1 - ).unbind(0) - ): - self.parent_channels[i].send(("step", data)) + data = tensordict.consolidate(share_memory=True, inplace=True, num_threads=1) + for i, local_data in enumerate(data.unbind(0)): + self.parent_channels[i].send(("step", local_data)) + # for i in range(data.shape[0]): + # self.parent_channels[i].send(("step", (data, i))) out_tds = [] for i, channel in enumerate(self.parent_channels): self._events[i].wait() @@ -1586,17 +1583,24 @@ def _reset_no_buffers( reset_kwargs_list, needs_resetting, ) -> Tuple[TensorDictBase, TensorDictBase]: - tdunbound = ( - tensordict.consolidate(share_memory=True, num_threads=1).unbind(0) - if is_tensor_collection(tensordict) - else [None] * self.num_workers - ) + if is_tensor_collection(tensordict): + # tensordict = tensordict.consolidate(share_memory=True, num_threads=1) + tensordict = tensordict.consolidate( + share_memory=True, num_threads=1 + ).unbind(0) + else: + tensordict = [None] * self.num_workers out_tds = [None] * self.num_workers - for i, (data, reset_kwargs) in enumerate(zip(tdunbound, reset_kwargs_list)): + for i, (local_data, reset_kwargs) in enumerate( + zip(tensordict, reset_kwargs_list) + ): if not needs_resetting[i]: - out_tds[i] = tdunbound[i].exclude(*self.reset_keys) + localtd = local_data + if localtd is not None: + localtd = localtd.exclude(*self.reset_keys) + out_tds[i] = localtd continue - self.parent_channels[i].send(("reset", (data, reset_kwargs))) + self.parent_channels[i].send(("reset", (local_data, reset_kwargs))) for i, channel in enumerate(self.parent_channels): if not needs_resetting[i]: @@ -2129,6 +2133,8 @@ def _run_worker_pipe_direct( raise RuntimeError("call 'init' before resetting") # we use 'data' to pass the keys that we need to pass to reset, # because passing the entire buffer may have unwanted consequences + # data, idx, reset_kwargs = data + # data = data[idx] data, reset_kwargs = data if data is not None: data._fast_apply( @@ -2151,6 +2157,8 @@ def _run_worker_pipe_direct( if not initialized: raise RuntimeError("called 'init' before step") i += 1 + # data, idx = data + # data = data[idx] next_td = env._step(data) if event is not None: event.record() @@ -2165,8 +2173,8 @@ def _run_worker_pipe_direct( if not initialized: raise RuntimeError("called 'init' before step") i += 1 - data, idx = data - data = data[idx] + # data, idx = data + # data = data[idx] data._fast_apply( lambda x: x.clone() if x.device.type == "cuda" else x, out=data )